Fix policy update test #365
Merged
Changes from all commits (2 commits)
File changed: tests/integration_tests/test_policy_update.py
```diff
@@ -6,18 +6,22 @@
 import asyncio
 import logging
-from tempfile import TemporaryDirectory
+import shutil
+from pathlib import Path
 
 import pytest
+import pytest_asyncio
 
 import torch
 import torchstore as ts
 from forge.actors.generator import Generator
 
 from forge.actors.trainer import RLTrainer
 from forge.cli.config import resolve_hf_hub_paths
+from forge.controller.provisioner import init_provisioner
 
 from forge.controller.service.service import uuid
+from forge.types import LauncherConfig, ProvisionerConfig
 from monarch.actor import endpoint
 
 from omegaconf import DictConfig, OmegaConf
```
```diff
@@ -35,13 +39,16 @@
 """
 Run tests:
 
-pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \
-    --config tests/integration_tests/artifacts/qwen3_1_7b_tp.yaml --use_dcp=false
+PYTHONPATH=. pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \
+    --config tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml --use_dcp=false
 
-pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \
+PYTHONPATH=. pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \
     --config apps/grpo/qwen3_8b.yaml
 """
 
+# Temp directory won't work for multi-node because NFS does not cover the tmp path
+TEST_DCP_DIR = "test_dcp_tmp"
+
 
 class MockRLTrainer(RLTrainer):
     @endpoint
```
```diff
@@ -58,13 +65,27 @@ async def zero_out_model_states(self):
             sd[k] *= 0.0
 
 
-# exceptions sometimes are not propogated in monarch, do it manually
-def validate_fn(prev_params, curr_model, logger) -> Exception | None:
+def _load_config(config_path: str) -> DictConfig:
+    cfg = None
+    try:
+        cfg = OmegaConf.load(config_path)
+    except Exception as e:
+        pytest.fail(f"Failed to load config file {config_path}: {e}")
+
+    assert isinstance(cfg, DictConfig)
+
+    cfg = resolve_hf_hub_paths(cfg)
+    return cfg
+
+
+def _test_validate_params_unchanged(
+    prev_params, curr_model, logger
+) -> Exception | None:
     """Validate that current parameters are the same as prev_params."""
     verified = set()
     skipped = set()
     logger.info(
         f"Validating model params, all named_parameters() = {curr_model.named_parameters()}"
     )
     errs = []
     for name, param in curr_model.named_parameters():
```
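The comment deleted above explains why these validators return an `Exception | None` instead of raising: exceptions raised inside a Monarch actor are sometimes not propagated back to the caller, so the test collects them as values and asserts afterwards. Below is a minimal sketch of that return-instead-of-raise pattern; `fanout_validate` is a hypothetical stand-in for a Monarch endpoint fanout, not the real API.

```python
# Sketch of the return-instead-of-raise pattern used by the validators.
def check_positive(x: int) -> Exception | None:
    """Return the error as a value so it survives the actor boundary."""
    try:
        assert x > 0, f"x = {x} is not positive"
        return None
    except Exception as e:
        return e  # hand the exception back instead of raising it


def fanout_validate(fn, values):
    # Pretend each value is checked on a different actor; collect results.
    return [fn(v) for v in values]


errs = [e for e in fanout_validate(check_positive, [1, 2, -3]) if e is not None]
assert not errs, f"Validation failed: {errs}"
```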
```diff
@@ -83,7 +104,6 @@ def validate_fn(prev_params, curr_model, logger) -> Exception | None:
             )
             verified.add(name)
         except Exception as e:
-            # logger.error(f"Validation failed with exception: {e}")
             errs.append((name, e))
     logger.info(f"Verified params = {verified}")
     logger.info(f"Skipped params = {skipped}")
```
@@ -94,14 +114,15 @@ def validate_fn(prev_params, curr_model, logger) -> Exception | None: | |
return AssertionError(f"Validation failed: {errs}") | ||
|
||
|
||
# exceptions sometimes are not propogated in monarch, do it manually | ||
def validate_fn_all_zeros(prev_params, curr_model, logger) -> Exception | None: | ||
def _test_validate_params_all_zeros( | ||
prev_params, curr_model, logger | ||
) -> Exception | None: | ||
"""Validate all parameters are set to zero. prev_params is actually not used.""" | ||
_ = prev_params | ||
verified = set() | ||
skipped = set() | ||
logger.info( | ||
f"Validating model params, all named_parameters() = {curr_model.named_parameters()}" | ||
f"Validating model params, all named_parameters() = {curr_model.named_parameters()}" | ||
) | ||
errs = [] | ||
for name, param in curr_model.named_parameters(): | ||
|
```diff
@@ -113,10 +134,9 @@ def validate_fn_all_zeros(prev_params, curr_model, logger) -> Exception | None:
             param = param.cpu()
             assert torch.allclose(
                 torch.zeros_like(param), param, atol=1e-4, rtol=1e-3
-            ), "param {name} is not zero."
+            ), f"param {name} is not zero."
             verified.add(name)
         except Exception as e:
-            # logger.error(f"Validation failed with exception: {e}")
             errs.append((name, e))
     logger.info(f"Verified params = {verified}")
     logger.info(f"Skipped params = {skipped}")
```
```diff
@@ -127,24 +147,93 @@
     return AssertionError(f"Validation failed: {errs}")
 
 
-class TestWeightSync:
-    """Tests for weight sync between trainer and policy."""
+@pytest_asyncio.fixture(autouse=True)
+async def _setup_and_teardown(request):
+    # ---- setup ---- #
+    config_path = request.config.getoption("--config", default=None)
+    if not config_path:
+        pytest.skip(
+            "No config file provided. Use --config <path> to specify a YAML config file"
+        )
 
-    def _load_config(self, config_path: str) -> DictConfig:
-        cfg = None
-        try:
-            cfg = OmegaConf.load(config_path)
-        except Exception as e:
-            pytest.fail(f"Failed to load config file {config_path}: {e}")
+    use_dcp_override = request.config.getoption("--use_dcp")
+    cfg = _load_config(config_path=config_path)
 
+    trainer_proc_size = cfg.actors.trainer.procs
+    policy_tp_size = cfg.policy.engine_args.tensor_parallel_size
+
+    if policy_tp_size != cfg.services.policy.procs:
+        pytest.fail(
+            f"Expect policy proc = {cfg.services.policy.procs} to be equal to tensor parallel size = {policy_tp_size}"
+        )
+
+    model_card = cfg.model
+    logger.info(f"Running sanity check with config: {config_path}")
+    logger.info(f"Model name: {model_card}")
+    logger.info(f"Trainer proc size: {trainer_proc_size}")
+    logger.info(f"Policy tensor parallel size: {policy_tp_size}")
+
+    logger.info("Downloading model checkpoint from HuggingFace Hub")
+    cached_dir = snapshot_download(repo_id=model_card)
+    logger.info("Finished downloading model checkpoint from HuggingFace Hub")
+
+    services_policy_cfg = cfg.services.policy
+    services_policy_cfg.num_replicas = 1
+
+    trainer_cfg = cfg.trainer
+    trainer_cfg.dcp_path = TEST_DCP_DIR
+    trainer_cfg.checkpoint = {
+        "enable": True,
+        "folder": "/tmp/saved_checkpoints",
+        "initial_load_path": cached_dir,
+        "initial_load_in_hf": True,
+    }
+
+    if use_dcp_override is not None:
+        trainer_cfg["use_dcp"] = use_dcp_override
+        logger.info(f"`trainer.use_dcp` is overriden to {use_dcp_override}")
+
+    if cfg.get("provisioner", None) is not None:
+        await init_provisioner(
+            ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
+        )
+    await ts.initialize(strategy=ts.ControllerStorageVolumes())
+
+    policy, rl_trainer = await asyncio.gather(
+        *[
+            Generator.options(**services_policy_cfg).as_service(**cfg.policy),
+            MockRLTrainer.options(**cfg.actors.trainer).as_actor(**trainer_cfg),
+        ]
+    )
+
+    yield policy, rl_trainer
+
+    # ---- teardown ---- #
+    logger.info("Shutting down services and cleaning up DCP directory..")
+
+    await asyncio.gather(
+        policy.shutdown(),
+        ts.shutdown(),
+        RLTrainer.shutdown(rl_trainer),
+    )
+
+    # Cleanup DCP directory
+    path = Path(TEST_DCP_DIR)
```
Inline review on this line:

Reviewer: You can use `TemporaryDirectory`, which handles this automatically.

Author: I have to change this because NFS does not work for the tmp directory.
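For context on the exchange above, the two approaches differ in where the directory lives and who deletes it. A rough sketch of both is below, assuming `test_dcp_tmp` resolves to NFS-backed storage visible to all nodes; the `...` bodies are placeholders, not code from this PR.

```python
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory

# Reviewer's suggestion: cleanup is automatic, but the directory defaults to
# the node-local tmp path, which other nodes in a multi-node run cannot see.
with TemporaryDirectory() as tmpdir:
    ...  # write DCP checkpoints to tmpdir; removed when the block exits

# Author's approach: a fixed directory on a shared path, removed explicitly
# in the fixture's teardown so every node reads and writes the same files.
dcp_dir = Path("test_dcp_tmp")  # assumed to live on shared (NFS) storage
dcp_dir.mkdir(exist_ok=True)
try:
    ...  # write DCP checkpoints to dcp_dir
finally:
    shutil.rmtree(dcp_dir, ignore_errors=True)
```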
```diff
+    if not path.exists() or not path.is_dir():
+        return
+    try:
+        shutil.rmtree(path)
+        logger.info(f"Successfully removed {TEST_DCP_DIR}")
+    except Exception as e:
+        logger.error(f"Failed to remove {TEST_DCP_DIR}: {e}")
 
-        assert isinstance(cfg, DictConfig)
-
-        cfg = resolve_hf_hub_paths(cfg)
-        return cfg
+
+class TestWeightSync:
+    """Tests for weight sync between trainer and policy."""
 
     @pytest.mark.asyncio
     @requires_cuda
-    async def test_sanity_check(self, request):
+    async def test_sanity_check(self, _setup_and_teardown):
         """
         Sanity check for weight sync sharding between RLTrainer and Policy for a given model config.
```
```diff
@@ -155,89 +244,41 @@ async def test_sanity_check(self, request):
         - Load weights v1 and check the policy has all the weights back
 
         """
-        # Test setup
-        config_path = request.config.getoption("--config", default=None)
-        if not config_path:
-            pytest.skip(
-                "No config file provided. Use --config <path> to specify a YAML config file"
-            )
-
-        use_dcp_override = request.config.getoption("--use_dcp")
-        cfg = self._load_config(config_path=config_path)
+        policy, rl_trainer = _setup_and_teardown
 
-        trainer_proc_size = cfg.actors.trainer.procs
-        policy_tp_size = cfg.policy.engine_args.tensor_parallel_size
+        v0 = uuid.uuid4().int
+        v1 = v0 + 1
 
-        if policy_tp_size != cfg.services.policy.procs:
-            pytest.fail(
-                f"Expect policy proc = {cfg.services.policy.procs} to be equal to tensor parallel size = {policy_tp_size}"
-            )
+        await rl_trainer.push_weights.call(policy_version=v0)
+        # Setting everything to zero
+        await rl_trainer.zero_out_model_states.call()
+        await rl_trainer.push_weights.call(policy_version=v1)
+        await policy._test_save_model_params.fanout()
 
-        model_card = cfg.model
-
-        logger.info(f"Running sanity check with config: {config_path}")
-        logger.info(f"Model name: {model_card}")
-        logger.info(f"Trainer proc size: {trainer_proc_size}")
-        logger.info(f"Policy tensor parallel size: {policy_tp_size}")
-
-        logger.info("Downloading model checkpoint from HuggingFace Hub")
-        cached_dir = snapshot_download(repo_id=model_card)
-        logger.info("Finished downloading model checkpoint from HuggingFace Hub")
-
-        await ts.initialize()
-        services_policy_cfg = cfg.services.policy
-        services_policy_cfg.num_replicas = 1
-
-        trainer_cfg = cfg.trainer
-        trainer_cfg.checkpoint = {
-            "enable": True,
-            "folder": "/tmp/saved_checkpoints",
-            "initial_load_path": cached_dir,
-            "initial_load_in_hf": True,
-        }
-        if use_dcp_override is not None:
-            trainer_cfg["use_dcp"] = use_dcp_override
-            logger.info(f"`trainer.use_dcp` is overriden to {use_dcp_override}")
-
-        with TemporaryDirectory(dir="/dev/shm/") as tmpdir:
-            trainer_cfg["dcp_path"] = tmpdir
-            policy, rl_trainer = await asyncio.gather(
-                *[
-                    Generator.options(**services_policy_cfg).as_service(**cfg.policy),
-                    MockRLTrainer.options(**cfg.actors.trainer).as_actor(**trainer_cfg),
-                ]
-            )
+        # Sanity check that before update all the tests pass
+        all_errs = await policy._test_validate_model_params.fanout(
+            _test_validate_params_unchanged
+        )
+        for errs in all_errs:
+            for _, e in errs.items():
+                assert not e, f"Validation failed with exception: {e}"
 
-            # Main logic begins here
-            v0 = uuid.uuid4().int
-            v1 = v0 + 1
-
-            await rl_trainer.push_weights.call(policy_version=v0)
-            # Setting everything to zero
-            await rl_trainer.zero_out_model_states.call()
-            await rl_trainer.push_weights.call(policy_version=v1)
-            await policy._test_save_model_params.fanout()
-
-            # Sanity check that before update all the tests pass
-            all_errs = await policy._test_validate_model_params.fanout(validate_fn)
-            for errs in all_errs:
-                for _, e in errs.items():
-                    assert not e, f"Validation failed with exception: {e}"
-
-            await policy.update_weights.fanout(version=v1)
-            all_errs = await policy._test_validate_model_params.fanout(
-                validate_fn_all_zeros
-            )
-            for errs in all_errs:
-                for _, e in errs.items():
-                    assert not e, f"Validation failed with exception: {e}"
-
-            # Reloading v0, getting back original weights
-            await policy.update_weights.fanout(version=v0)
-            all_errs = await policy._test_validate_model_params.fanout(validate_fn)
-            for errs in all_errs:
-                for _, e in errs.items():
-                    assert not e, f"Validation failed with exception: {e}"
-
-            logger.info("✅ Weight sharding sanity check passed!")
-            await ts.shutdown()
+        await policy.update_weights.fanout(version=v1)
+        all_errs = await policy._test_validate_model_params.fanout(
+            _test_validate_params_all_zeros
+        )
+        for errs in all_errs:
+            for _, e in errs.items():
+                assert not e, f"Validation failed with exception: {e}"
 
+        # Reloading v0, getting back original weights
+        await policy.update_weights.fanout(version=v0)
+        all_errs = await policy._test_validate_model_params.fanout(
+            _test_validate_params_unchanged
+        )
+        for errs in all_errs:
+            for _, e in errs.items():
+                assert not e, f"Validation failed with exception: {e}"
 
+        logger.info("✅ Weight sharding sanity check passed!")
```