From 917051ca8b8bf1993a86623d6a24d642ce65654c Mon Sep 17 00:00:00 2001
From: "Jiyue (Jennifer) Wang"
Date: Wed, 15 Oct 2025 09:21:37 -0400
Subject: [PATCH 1/2] rebase + squash

---
 src/forge/actors/generator.py                 |   4 +
 src/forge/actors/trainer.py                   |  11 +
 .../fixtures/qwen3_1_7b_no_tp.yaml            |   2 +
 .../fixtures/qwen3_1_7b_tp.yaml               |   2 +
 tests/integration_tests/test_policy_update.py | 253 ++++++++++--------
 5 files changed, 166 insertions(+), 106 deletions(-)

diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py
index 8f8cf8fc7..f7e0cfe10 100644
--- a/src/forge/actors/generator.py
+++ b/src/forge/actors/generator.py
@@ -50,10 +50,14 @@
 )

 from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
+<<<<<<< HEAD:src/forge/actors/generator.py
 from forge.data_models.completion import Completion
 from forge.data_models.prompt import to_prompt
 from forge.env import TORCHSTORE_USE_RDMA
 from forge.interfaces import Policy as GeneratorInterface
+from forge.data.sharding import VLLMSharding
+from forge.data_models.completion import Completion
+from forge.data_models.prompt import to_prompt
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 from forge.types import ProcessConfig

diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index dd85b3c82..83229a993 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -18,6 +18,17 @@
 import torch.distributed.checkpoint as dcp
 import torchstore as ts

+from forge.actors._torchstore_utils import (
+    DcpHandle,
+    get_dcp_whole_state_dict_key,
+    get_param_key,
+)
+
+from forge.controller import ForgeActor
+from forge.data.utils import batch_to_device
+from forge.observability.metrics import record_metric, Reduce
+from forge.observability.perf_tracker import Tracer
+
 from monarch.actor import current_rank, current_size, endpoint
 from torch import Tensor
 from torch.distributed.checkpoint._nested_dict import flatten_state_dict

diff --git a/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml b/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml
index 8b64b83ca..64588b6d4 100644
--- a/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml
+++ b/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml
@@ -66,6 +66,8 @@ services:
     procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 1
     with_gpus: true
+
+actors:
   trainer:
     procs: 1
     num_replicas: 1

diff --git a/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml b/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml
index 5d754c3ad..cf3d7dc80 100644
--- a/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml
+++ b/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml
@@ -68,6 +68,8 @@ services:
     procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 1
     with_gpus: true
+
+actors:
   trainer:
     procs: 2
     num_replicas: 1

diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py
index 7edf0fcf3..614d012f8 100644
--- a/tests/integration_tests/test_policy_update.py
+++ b/tests/integration_tests/test_policy_update.py
@@ -6,9 +6,11 @@

 import asyncio
 import logging
-from tempfile import TemporaryDirectory
+import shutil
+from pathlib import Path

 import pytest
+import pytest_asyncio
 import torch
 import torchstore as ts

@@ -16,8 +18,10 @@
 from forge.actors.trainer import RLTrainer
 from forge.cli.config import resolve_hf_hub_paths
+from forge.controller.provisioner import init_provisioner
 from forge.controller.service.service import uuid
+from forge.types import LauncherConfig, ProvisionerConfig
 from monarch.actor import endpoint
 from omegaconf import DictConfig, OmegaConf
@@ -35,13 +39,16 @@
 """
 Run tests:

-pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \
-    --config tests/integration_tests/artifacts/qwen3_1_7b_tp.yaml --use_dcp=false
+PYTHONPATH=. pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \
+    --config tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml --use_dcp=false

-pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \
+PYTHONPATH=. pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \
     --config apps/grpo/qwen3_8b.yaml
 """

+# Temp directory won't work for multi-node because NFS does not cover the tmp path
+TEST_DCP_DIR = "test_dcp_tmp"
+

 class MockRLTrainer(RLTrainer):
     @endpoint
@@ -58,13 +65,27 @@ async def zero_out_model_states(self):
             sd[k] *= 0.0


-# exceptions sometimes are not propogated in monarch, do it manually
-def validate_fn(prev_params, curr_model, logger) -> Exception | None:
+def _load_config(config_path: str) -> DictConfig:
+    cfg = None
+    try:
+        cfg = OmegaConf.load(config_path)
+    except Exception as e:
+        pytest.fail(f"Failed to load config file {config_path}: {e}")
+
+    assert isinstance(cfg, DictConfig)
+
+    cfg = resolve_hf_hub_paths(cfg)
+    return cfg
+
+
+def _test_validate_params_unchanged(
+    prev_params, curr_model, logger
+) -> Exception | None:
     """Validate that current parameters are the same as prev_params."""
     verified = set()
     skipped = set()
     logger.info(
-        f"Validating model params, all named_parameters() = {curr_model.named_parameters()}"
+        f"Validating model params, all named_parameters() = {curr_model.named_parameters()}"
     )
     errs = []
     for name, param in curr_model.named_parameters():
@@ -83,7 +104,6 @@ def validate_fn(prev_params, curr_model, logger) -> Exception | None:
                 )
                 verified.add(name)
             except Exception as e:
-                # logger.error(f"Validation failed with exception: {e}")
                 errs.append((name, e))
     logger.info(f"Verified params = {verified}")
     logger.info(f"Skipped params = {skipped}")
@@ -94,14 +114,15 @@
         return AssertionError(f"Validation failed: {errs}")


-# exceptions sometimes are not propogated in monarch, do it manually
-def validate_fn_all_zeros(prev_params, curr_model, logger) -> Exception | None:
+def _test_validate_params_all_zeros(
+    prev_params, curr_model, logger
+) -> Exception | None:
     """Validate all parameters are set to zero. prev_params is actually not used."""
     _ = prev_params
     verified = set()
     skipped = set()
     logger.info(
-        f"Validating model params, all named_parameters() = {curr_model.named_parameters()}"
+        f"Validating model params, all named_parameters() = {curr_model.named_parameters()}"
     )
     errs = []
     for name, param in curr_model.named_parameters():
@@ -113,10 +134,9 @@
             param = param.cpu()
             assert torch.allclose(
                 torch.zeros_like(param), param, atol=1e-4, rtol=1e-3
-            ), "param {name} is not zero."
+            ), f"param {name} is not zero."
             verified.add(name)
         except Exception as e:
-            # logger.error(f"Validation failed with exception: {e}")
             errs.append((name, e))
     logger.info(f"Verified params = {verified}")
     logger.info(f"Skipped params = {skipped}")
@@ -127,24 +147,93 @@ def validate_fn_all_zeros(prev_params, curr_model, logger) -> Exception | None:
         return AssertionError(f"Validation failed: {errs}")


-class TestWeightSync:
-    """Tests for weight sync between trainer and policy."""
+@pytest_asyncio.fixture(autouse=True)
+async def _setup_and_teardown(request):
+    # ---- setup ---- #
+    config_path = request.config.getoption("--config", default=None)
+    if not config_path:
+        pytest.skip(
+            "No config file provided. Use --config to specify a YAML config file"
+        )

-    def _load_config(self, config_path: str) -> DictConfig:
-        cfg = None
-        try:
-            cfg = OmegaConf.load(config_path)
-        except Exception as e:
-            pytest.fail(f"Failed to load config file {config_path}: {e}")
+    use_dcp_override = request.config.getoption("--use_dcp")
+    cfg = _load_config(config_path=config_path)
+
+    trainer_proc_size = cfg.actors.trainer.procs
+    policy_tp_size = cfg.policy.engine_args.tensor_parallel_size
+
+    if policy_tp_size != cfg.services.policy.procs:
+        pytest.fail(
+            f"Expect policy proc = {cfg.services.policy.procs} to be equal to tensor parallel size = {policy_tp_size}"
+        )
+
+    model_card = cfg.model
+    logger.info(f"Running sanity check with config: {config_path}")
+    logger.info(f"Model name: {model_card}")
+    logger.info(f"Trainer proc size: {trainer_proc_size}")
+    logger.info(f"Policy tensor parallel size: {policy_tp_size}")
+
+    logger.info("Downloading model checkpoint from HuggingFace Hub")
+    cached_dir = snapshot_download(repo_id=model_card)
+    logger.info("Finished downloading model checkpoint from HuggingFace Hub")
+
+    services_policy_cfg = cfg.services.policy
+    services_policy_cfg.num_replicas = 1
+
+    trainer_cfg = cfg.trainer
+    trainer_cfg.dcp_path = TEST_DCP_DIR
+    trainer_cfg.checkpoint = {
+        "enable": True,
+        "folder": "/tmp/saved_checkpoints",
+        "initial_load_path": cached_dir,
+        "initial_load_in_hf": True,
+    }
+
+    if use_dcp_override is not None:
+        trainer_cfg["use_dcp"] = use_dcp_override
+        logger.info(f"`trainer.use_dcp` is overridden to {use_dcp_override}")
+
+    if cfg.get("provisioner", None) is not None:
+        await init_provisioner(
+            ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
+        )
+    await ts.initialize(strategy=ts.ControllerStorageVolumes())
+
+    policy, rl_trainer = await asyncio.gather(
+        *[
+            Generator.options(**services_policy_cfg).as_service(**cfg.policy),
+            MockRLTrainer.options(**cfg.actors.trainer).as_actor(**trainer_cfg),
+        ]
+    )
+
+    yield policy, rl_trainer
+
+    # ---- teardown ---- #
+    logger.info("Shutting down services and cleaning up DCP directory...")
+
+    await asyncio.gather(
+        policy.shutdown(),
+        ts.shutdown(),
+        RLTrainer.shutdown(rl_trainer),
+    )
+
+    # Cleanup DCP directory
+    path = Path(TEST_DCP_DIR)
+    if not path.exists() or not path.is_dir():
+        return
+    try:
+        shutil.rmtree(path)
+        logger.info(f"Successfully removed {TEST_DCP_DIR}")
+    except Exception as e:
+        logger.error(f"Failed to remove {TEST_DCP_DIR}: {e}")

-        assert isinstance(cfg, DictConfig)

-        cfg = resolve_hf_hub_paths(cfg)
-        return cfg
+class TestWeightSync:
+    """Tests for weight sync between trainer and policy."""

     @pytest.mark.asyncio
     @requires_cuda
-    async def test_sanity_check(self, request):
+    async def test_sanity_check(self, _setup_and_teardown):
         """
         Sanity check for weight sync sharding between RLTrainer and Policy for a given model
         config.
@@ -155,89 +244,41 @@ async def test_sanity_check(self, request):
         - Load weights v1 and check the policy has all the weights back
         """
-        # Test setup
-        config_path = request.config.getoption("--config", default=None)
-        if not config_path:
-            pytest.skip(
-                "No config file provided. Use --config to specify a YAML config file"
-            )
-        use_dcp_override = request.config.getoption("--use_dcp")
-        cfg = self._load_config(config_path=config_path)
+        policy, rl_trainer = _setup_and_teardown

-        trainer_proc_size = cfg.actors.trainer.procs
-        policy_tp_size = cfg.policy.engine_args.tensor_parallel_size
+        v0 = uuid.uuid4().int
+        v1 = v0 + 1

-        if policy_tp_size != cfg.services.policy.procs:
-            pytest.fail(
-                f"Expect policy proc = {cfg.services.policy.procs} to be equal to tensor parallel size = {policy_tp_size}"
-            )
+        await rl_trainer.push_weights.call(policy_version=v0)
+        # Setting everything to zero
+        await rl_trainer.zero_out_model_states.call()
+        await rl_trainer.push_weights.call(policy_version=v1)
+        await policy._test_save_model_params.fanout()

-        model_card = cfg.model
-
-        logger.info(f"Running sanity check with config: {config_path}")
-        logger.info(f"Model name: {model_card}")
-        logger.info(f"Trainer proc size: {trainer_proc_size}")
-        logger.info(f"Policy tensor parallel size: {policy_tp_size}")
-
-        logger.info("Downloading model checkpoint from HuggingFace Hub")
-        cached_dir = snapshot_download(repo_id=model_card)
-        logger.info("Finished downloading model checkpoint from HuggingFace Hub")
-
-        await ts.initialize()
-        services_policy_cfg = cfg.services.policy
-        services_policy_cfg.num_replicas = 1
-
-        trainer_cfg = cfg.trainer
-        trainer_cfg.checkpoint = {
-            "enable": True,
-            "folder": "/tmp/saved_checkpoints",
-            "initial_load_path": cached_dir,
-            "initial_load_in_hf": True,
-        }
-        if use_dcp_override is not None:
-            trainer_cfg["use_dcp"] = use_dcp_override
-            logger.info(f"`trainer.use_dcp` is overriden to {use_dcp_override}")
-
-        with TemporaryDirectory(dir="/dev/shm/") as tmpdir:
-            trainer_cfg["dcp_path"] = tmpdir
-            policy, rl_trainer = await asyncio.gather(
-                *[
-                    Generator.options(**services_policy_cfg).as_service(**cfg.policy),
-                    MockRLTrainer.options(**cfg.actors.trainer).as_actor(**trainer_cfg),
-                ]
-            )
+        # Sanity check that before update all the tests pass
+        all_errs = await policy._test_validate_model_params.fanout(
+            _test_validate_params_unchanged
+        )
+        for errs in all_errs:
+            for _, e in errs.items():
+                assert not e, f"Validation failed with exception: {e}"

-            # Main logic begins here
-            v0 = uuid.uuid4().int
-            v1 = v0 + 1
-
-            await rl_trainer.push_weights.call(policy_version=v0)
-            # Setting everything to zero
-            await rl_trainer.zero_out_model_states.call()
-            await rl_trainer.push_weights.call(policy_version=v1)
-            await policy._test_save_model_params.fanout()
-
-            # Sanity check that before update all the tests pass
-            all_errs = await policy._test_validate_model_params.fanout(validate_fn)
-            for errs in all_errs:
-                for _, e in errs.items():
-                    assert not e, f"Validation failed with exception: {e}"
-
-            await policy.update_weights.fanout(version=v1)
-            all_errs = await policy._test_validate_model_params.fanout(
-                validate_fn_all_zeros
-            )
-            for errs in all_errs:
-                for _, e in errs.items():
-                    assert not e, f"Validation failed with exception: {e}"
-
-            # Reloading v0, getting back original weights
-            await policy.update_weights.fanout(version=v0)
-            all_errs = await policy._test_validate_model_params.fanout(validate_fn)
-            for errs in all_errs:
-                for _, e in errs.items():
-                    assert not e, f"Validation failed with exception: {e}"
f"Validation failed with exception: {e}" - - logger.info("✅ Weight sharding sanity check passed!") - await ts.shutdown() + await policy.update_weights.fanout(policy_version=v1) + all_errs = await policy._test_validate_model_params.fanout( + _test_validate_params_all_zeros + ) + for errs in all_errs: + for _, e in errs.items(): + assert not e, f"Validation failed with exception: {e}" + + # Reloading v0, getting back original weights + await policy.update_weights.fanout(policy_version=v0) + all_errs = await policy._test_validate_model_params.fanout( + _test_validate_params_unchanged + ) + for errs in all_errs: + for _, e in errs.items(): + assert not e, f"Validation failed with exception: {e}" + + logger.info("✅ Weight sharding sanity check passed!") From bf0289ea31fa75fef1cd2c0672cc2b0a45bdda18 Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Wed, 15 Oct 2025 09:47:49 -0400 Subject: [PATCH 2/2] new api --- src/forge/actors/generator.py | 4 ---- src/forge/actors/trainer.py | 11 ----------- tests/integration_tests/test_policy_update.py | 4 ++-- 3 files changed, 2 insertions(+), 17 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index f7e0cfe10..8f8cf8fc7 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -50,14 +50,10 @@ ) from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh -<<<<<<< HEAD:src/forge/actors/generator.py from forge.data_models.completion import Completion from forge.data_models.prompt import to_prompt from forge.env import TORCHSTORE_USE_RDMA from forge.interfaces import Policy as GeneratorInterface -from forge.data.sharding import VLLMSharding -from forge.data_models.completion import Completion -from forge.data_models.prompt import to_prompt from forge.observability.metrics import record_metric, Reduce from forge.observability.perf_tracker import Tracer from forge.types import ProcessConfig diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index 83229a993..dd85b3c82 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -18,17 +18,6 @@ import torch.distributed.checkpoint as dcp import torchstore as ts -from forge.actors._torchstore_utils import ( - DcpHandle, - get_dcp_whole_state_dict_key, - get_param_key, -) - -from forge.controller import ForgeActor -from forge.data.utils import batch_to_device -from forge.observability.metrics import record_metric, Reduce -from forge.observability.perf_tracker import Tracer - from monarch.actor import current_rank, current_size, endpoint from torch import Tensor from torch.distributed.checkpoint._nested_dict import flatten_state_dict diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py index 614d012f8..10b2852b7 100644 --- a/tests/integration_tests/test_policy_update.py +++ b/tests/integration_tests/test_policy_update.py @@ -264,7 +264,7 @@ async def test_sanity_check(self, _setup_and_teardown): for _, e in errs.items(): assert not e, f"Validation failed with exception: {e}" - await policy.update_weights.fanout(policy_version=v1) + await policy.update_weights.fanout(version=v1) all_errs = await policy._test_validate_model_params.fanout( _test_validate_params_all_zeros ) @@ -273,7 +273,7 @@ async def test_sanity_check(self, _setup_and_teardown): assert not e, f"Validation failed with exception: {e}" # Reloading v0, getting back original weights - await policy.update_weights.fanout(policy_version=v0) + await 
+        await policy.update_weights.fanout(version=v0)
         all_errs = await policy._test_validate_model_params.fanout(
             _test_validate_params_unchanged
         )
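
The validators in PATCH 1/2 collect failures and return an exception object for the caller to assert on, because (as the deleted comment noted) exceptions are not always propagated out of monarch actors. A minimal standalone sketch of that pattern, using a plain torch.nn.Module in place of the Generator workers; the names below are illustrative, not part of the patch:

    import torch
    import torch.nn as nn

    def validate_params_unchanged(prev_params, curr_model):
        # Collect errors instead of raising, so the result can be
        # returned across the actor boundary and asserted on centrally.
        errs = []
        for name, param in curr_model.named_parameters():
            try:
                assert torch.allclose(
                    prev_params[name], param.cpu(), atol=1e-4, rtol=1e-3
                ), f"param {name} changed"
            except Exception as e:
                errs.append((name, e))
        return AssertionError(f"Validation failed: {errs}") if errs else None

    model = nn.Linear(4, 4)
    snapshot = {n: p.detach().cpu().clone() for n, p in model.named_parameters()}
    assert validate_params_unchanged(snapshot, model) is None
    with torch.no_grad():
        model.weight.mul_(0.0)  # corrupt the weights
    assert validate_params_unchanged(snapshot, model) is not None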
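PATCH 1/2 also moves all setup and teardown out of the test body into an autouse pytest_asyncio fixture: everything before the yield provisions the policy service and trainer actor, and everything after it runs at teardown, so shutdown and DCP-directory cleanup happen even when the test body fails. A stripped-down sketch of that shape, with a dict standing in for the real actors:

    import asyncio

    import pytest
    import pytest_asyncio

    @pytest_asyncio.fixture
    async def _setup_and_teardown():
        # ---- setup ---- #
        resource = {"ready": True}
        await asyncio.sleep(0)  # stand-in for async provisioning
        yield resource
        # ---- teardown: runs even if the test body raised ---- #
        resource["ready"] = False

    @pytest.mark.asyncio
    async def test_uses_resource(_setup_and_teardown):
        assert _setup_and_teardown["ready"]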
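The test itself exercises a versioned round trip: push weights as v0, zero out the trainer's model and push again as v1, then have the policy pull v1 (expect all zeros) and pull v0 again (expect the original weights back). The same invariant can be stated against a simple in-memory version store; this sketch is only an analogy for the torchstore-backed push_weights/update_weights path, not Forge's implementation:

    import torch
    import torch.nn as nn

    store = {}  # version -> state dict; stand-in for the weight store

    def push_weights(model, version):
        store[version] = {
            n: p.detach().cpu().clone() for n, p in model.named_parameters()
        }

    def update_weights(model, version):
        with torch.no_grad():
            for n, p in model.named_parameters():
                p.copy_(store[version][n])

    model = nn.Linear(4, 4)
    original = {n: p.detach().clone() for n, p in model.named_parameters()}
    push_weights(model, version=0)
    with torch.no_grad():
        for p in model.parameters():
            p.zero_()
    push_weights(model, version=1)

    update_weights(model, version=1)  # expect all zeros
    assert all(torch.equal(p, torch.zeros_like(p)) for p in model.parameters())
    update_weights(model, version=0)  # expect original weights back
    assert all(torch.equal(p, original[n]) for n, p in model.named_parameters())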