From 17e0c051bb2f7d9f51b4bb591d694c8c786467ab Mon Sep 17 00:00:00 2001 From: Yuxuan Hu <57782783+casteryh@users.noreply.github.com> Date: Thu, 18 Sep 2025 19:39:11 -0700 Subject: [PATCH 01/11] use vllm load_weights() in GRPO --- apps/grpo/main.py | 40 +++++++++++------- src/forge/actors/policy.py | 63 ++++++++++++---------------- src/forge/actors/torchstore_utils.py | 19 +++++++++ 3 files changed, 69 insertions(+), 53 deletions(-) create mode 100644 src/forge/actors/torchstore_utils.py diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 416ecbe05..aedae35d6 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -19,6 +19,7 @@ from forge.actors.policy import Policy from forge.actors.reference_model import ReferenceModel # noqa: F401 from forge.actors.replay_buffer import ReplayBuffer +from forge.actors.torchstore_utils import get_param_key from forge.actors.trainer import _qwen3_hf_to_vllm from forge.cli.config import parse from forge.controller.actor import ForgeActor @@ -185,14 +186,13 @@ async def train_step(self, batch: list[list[Episode]]): @endpoint async def push_weights(self, version: int): """Update policy model weights with trainer's current weights.""" - key = f"{self.state_dict_key}{DELIM}{version}" # Use version as unique id - new_sd = _qwen3_hf_to_vllm(self.model.state_dict(), num_layers=28) - start_time = time.time() - await ts.put_state_dict(new_sd, key) - end_time = time.time() - self.logger.debug( - f"Pushed weights to {key} in {end_time - start_time:.2f} seconds" - ) + start_time = time.perf_counter() + hf_state_dict = self.model.state_dict() + for name, param in hf_state_dict.items(): + key = get_param_key(version, name) + await ts.put(key, param) + end_time = time.perf_counter() + self.logger.debug(f"Pushed weights in {end_time - start_time:.2f} seconds") @dataclass @@ -318,7 +318,7 @@ async def main(cfg: DictConfig): mlogger = get_metric_logger( "wandb", freq=1, - project="grpo-training", + project="yuxuanh-grpo-training-debug", ) # ---- Setup services ---- # @@ -397,20 +397,28 @@ async def continuous_rollouts(): async def continuous_training(): training_step = 0 - policy_version = 0 while True: - batch = await replay_buffer.sample.choose( - curr_policy_version=policy_version - ) + batch = await replay_buffer.sample.choose(curr_policy_version=training_step) if batch is None: await asyncio.sleep(0.1) else: loss = await trainer.train_step.choose(batch) training_step += 1 mlogger.log("loss/training_step", loss, training_step) - await trainer.push_weights.call(policy_version) - policy_version += 1 - await policy.update_weights.call() + start_time = time.perf_counter() + await trainer.push_weights.call(training_step) + mlogger.log( + "push_weights_time/training_step", + time.perf_counter() - start_time, + training_step, + ) + start_time = time.perf_counter() + await policy.update_weights.call(training_step) + mlogger.log( + "update_weights_time/training_step", + time.perf_counter() - start_time, + training_step, + ) print("Starting GRPO training loops...") # TODO: Start multiple rollouts once all serivces support it diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 50c277bec..d1f1c9d48 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -17,7 +17,6 @@ import torch import torchstore as ts from monarch.actor import current_rank, endpoint, ProcMesh -from torchstore.state_dict_utils import DELIM from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs @@ -40,11 +39,17 @@ from 
vllm.v1.structured_output import StructuredOutputManager from vllm.worker.worker_base import WorkerWrapperBase -from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh +from forge.actors.torchstore_utils import ( + extract_param_name, + get_param_key, + get_param_prefix, +) +from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh from forge.data.sharding import VLLMSharding from forge.interfaces import Policy as PolicyInterface from forge.types import ProcessConfig +from forge.util.async_utils import make_sync_generator @dataclass @@ -364,7 +369,7 @@ async def run(self): fut.set_result(request_output) @endpoint - async def update_weights(self): + async def update_weights(self, policy_version: int): # TODO: If generating long sequences, this might be long and will block policy weight updates curr_requests = [fut for _, fut in self.requests.values()] if curr_requests: @@ -372,8 +377,8 @@ async def update_weights(self): await asyncio.gather(*curr_requests) self.logger.debug(f"Starting weight update on {self.__class__.__name__}") - await self.policy_worker.update.call(version=self.weights_version) - self.weights_version += 1 + await self.policy_worker.update.call(version=policy_version) + self.weights_version = policy_version self.logger.info(f"Weight update completed (now v{self.weights_version})") @endpoint @@ -395,7 +400,6 @@ async def stop(self): @dataclass class PolicyWorker(ForgeActor): vllm_config: VllmConfig - state_dict_key: str = "model_state_dict" @endpoint async def setup(self): @@ -407,41 +411,26 @@ async def setup(self): async def execute_model(self, schedule: SchedulerOutput): return self.worker.execute_model(schedule) - async def _load_tensor_parallel_state_dict( - self, current_state_dict: dict, version: int - ): - """ - Load full state dict from torchstore into tensor parallel model with deterministic sharding. - """ - sharding = VLLMSharding( - self.vllm_config.parallel_config.tensor_parallel_size, self.rank - ) - - for param_name in current_state_dict.keys(): - current_tensor = current_state_dict[param_name] - - # Load the full tensor from torchstore - # TODO: only get the part of the tensor that is needed - stored_tensor = await ts.get( - f"{self.state_dict_key}{DELIM}{version}{DELIM}{param_name}" - ) - sharding.load_from_source_to_target( - param_name, - stored_tensor, - current_tensor, - ) - @endpoint async def update(self, version: int): """Update model weights by reading state dict from torchstore""" - key = f"{self.state_dict_key}{DELIM}{version}" model = self.worker.model_runner.model - current_state_dict = model.state_dict() - start = time.time() - await self._load_tensor_parallel_state_dict(current_state_dict, version) - self.logger.debug( - f"Loaded state dict from {key} in {time.time() - start} seconds" - ) + prefix = get_param_prefix(version) + self.logger.debug(f"{prefix=}") + matching_keys = await ts.keys(prefix) + self.logger.debug(f"{matching_keys=}") + # TODO: find a way to save the original huggingface parameter names. + hf_names = [extract_param_name(key) for key in matching_keys] + self.logger.debug(f"{hf_names=}") + loaded_weights = set() + # We can't pass a generator since vllm load_weights is not async. + # Instead, we just call load_weights with one parameter at a time. 
+ for name in hf_names: + param = await ts.get(get_param_key(version, name)) + loaded = model.load_weights([(name, param)]) + del param + loaded_weights.update(loaded) + self.logger.info(f"Updated {len(loaded_weights)} parameters") @endpoint async def setup_kv_cache(self): diff --git a/src/forge/actors/torchstore_utils.py b/src/forge/actors/torchstore_utils.py new file mode 100644 index 000000000..b8835039b --- /dev/null +++ b/src/forge/actors/torchstore_utils.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +KEY_DELIM = "." + + +def get_param_prefix(policy_version: int) -> str: + return f"policy_ver_{policy_version}" + + +def get_param_key(policy_version: int, name: str) -> str: + return f"policy_ver_{policy_version}{KEY_DELIM}{name}" + + +def extract_param_name(key: str) -> str: + return KEY_DELIM.join(key.split(KEY_DELIM)[1:]) From e89a33905e724f3bbbeace1fd7a8289ae9b07ede Mon Sep 17 00:00:00 2001 From: Yuxuan Hu <57782783+casteryh@users.noreply.github.com> Date: Mon, 22 Sep 2025 15:21:08 -0700 Subject: [PATCH 02/11] integration test for weight sync that actually tests behavior --- .gitignore | 3 + src/forge/actors/policy.py | 62 ++-- tests/integration_tests/__init__.py | 5 + tests/integration_tests/test_policy_update.py | 279 ++++++------------ 4 files changed, 141 insertions(+), 208 deletions(-) create mode 100644 tests/integration_tests/__init__.py diff --git a/.gitignore b/.gitignore index 8acaa6101..4a769a61a 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,9 @@ cython_debug/ slogs/ slurm-* +# DCP checkpoints +model_state_dict/ + # Celery stuff celerybeat-schedule celerybeat.pid diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index bc03a3e49..8e9f63d16 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -7,6 +7,7 @@ from __future__ import annotations import asyncio +import logging import os import sys import time @@ -17,6 +18,14 @@ import torch import torch.distributed.checkpoint as dcp import torchstore as ts +from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh + +from forge.data.sharding import VLLMSharding +from forge.data_models.completion import Completion +from forge.data_models.prompt import to_prompt + +from forge.interfaces import Policy as PolicyInterface +from forge.types import ProcessConfig from monarch.actor import current_rank, endpoint, ProcMesh from torchstore.state_dict_utils import DELIM from vllm.config import VllmConfig @@ -41,14 +50,7 @@ from vllm.v1.structured_output import StructuredOutputManager from vllm.worker.worker_base import WorkerWrapperBase -from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh - -from forge.data.sharding import VLLMSharding -from forge.data_models.completion import Completion -from forge.data_models.prompt import to_prompt - -from forge.interfaces import Policy as PolicyInterface -from forge.types import ProcessConfig +logger: logging.Logger = logging.getLogger(__name__) @dataclass @@ -382,15 +384,6 @@ async def update_weights(self, policy_version: int): self.policy_version = policy_version self.logger.info(f"Weight update completed (now v{self.policy_version})") - @endpoint - async def _get_model_params(self) -> dict[str, torch.Tensor]: - """Get the current model parameters. 
Only for testing purposes."""
-        val_mesh = await self.policy_worker._get_model_params.call()
-        sharded_state_dicts = {}
-        for idx, val in val_mesh.items():
-            sharded_state_dicts[idx["gpus"]] = val
-        return sharded_state_dicts
-
     @endpoint
     async def get_version(self) -> int:
         """Get the current policy version."""
@@ -400,6 +393,18 @@ async def stop(self):
         self.running = False
 
+    @endpoint
+    async def _test_save_model_params(self):
+        """Save model parameters before weight update, used for testing purposes only."""
+        logger.info("[Policy] start saving model parameters before update for testing")
+        await self.policy_worker._test_save_model_params.call()
+
+    @endpoint
+    async def _test_validate_model_params(self, validate_fn):
+        """Validate updated model params using validate_fn."""
+        logger.info("[Policy] start validating model parameters post update")
+        return await self.policy_worker._test_validate_model_params.call(validate_fn)
+
     def _to_completions(self, request_output: RequestOutput) -> list[Completion]:
         """Convert a RequestOutput to a list of Completion objects."""
         completions = []
@@ -443,6 +448,9 @@ class PolicyWorker(ForgeActor):
     state_dict_key: str = "model_state_dict"
     use_dcp: bool = True
 
+    # used for testing purposes only
+    _test_prev_params = {}
+
     @endpoint
     async def setup(self):
         # TODO: remove ["gpus"] when monarch implements a flat rank
@@ -532,15 +540,19 @@ async def setup_kv_cache(self):
         return kv_cache_config
 
     @endpoint
-    async def _get_model_params(self) -> dict[str, torch.Tensor]:
-        model = self.worker.model_runner.model
-        state_dict = {}
+    async def _test_save_model_params(self):
+        """Save model parameters before weight update, used for testing purposes only."""
+        logger.info(
+            "[PolicyWorker] start saving model parameters before update for testing"
+        )
+        for name, param in self.worker.model_runner.model.named_parameters():
+            self._test_prev_params[name] = param.detach().cpu()
 
-        for name, param in model.named_parameters():
-            if "layers.0" not in name:
-                continue
-            state_dict[name] = param.cpu().detach()
-        return state_dict
+    @endpoint
+    async def _test_validate_model_params(self, validate_fn):
+        """Validate updated model params using validate_fn."""
+        logger.info("[PolicyWorker] start validating model parameters post update")
+        return validate_fn(self._test_prev_params, self.worker.model_runner.model, logger)
 
     def setup_worker(self):
         """Build and Instantiate vLLM worker"""
diff --git a/tests/integration_tests/__init__.py b/tests/integration_tests/__init__.py
new file mode 100644
index 000000000..2e41cd717
--- /dev/null
+++ b/tests/integration_tests/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
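
The test file added next leans on a pattern that is easy to miss: because exceptions raised inside a monarch endpoint are not always propagated back to the caller, validate_fn runs on the worker, returns any Exception as a value, and the test re-raises it on the caller side. Below is a minimal, self-contained sketch of that pattern, assuming nothing about monarch itself; run_endpoint is a hypothetical stand-in for an endpoint call.

from typing import Callable, Optional


def validate_fn(prev_params: dict, curr_params: dict) -> Optional[Exception]:
    """Runs on the worker: report failure by returning the exception."""
    try:
        for name, expected in prev_params.items():
            assert name in curr_params, f"param {name} missing"
            assert curr_params[name] == expected, f"param {name} mismatch"
    except Exception as e:
        return e  # returned as a value so the caller can re-raise it
    return None


def run_endpoint(fn: Callable, *args):
    # Hypothetical stand-in for a monarch endpoint call, which may not
    # propagate a raised exception back to the test process.
    return fn(*args)


# Caller (test) side: re-raise/assert on whatever the worker returned.
err = run_endpoint(validate_fn, {"w": 1.0}, {"w": 1.0})
assert err is None, f"Validation failed with exception: {err}"
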
diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py index 5fdce0b6a..7f8a42c2f 100644 --- a/tests/integration_tests/test_policy_update.py +++ b/tests/integration_tests/test_policy_update.py @@ -12,159 +12,29 @@ import torch import torchstore as ts + from forge.actors.policy import EngineConfig, Policy, SamplingConfig from forge.actors.trainer import RLTrainer from forge.controller.service import ServiceConfig from forge.data.sharding import VLLMSharding -from transformers import AutoModelForCausalLM +from monarch.actor import endpoint +from torch.distributed.checkpoint._nested_dict import flatten_state_dict + requires_cuda = pytest.mark.skipif( not torch.cuda.is_available(), reason="CUDA not available", ) -from forge.actors.trainer import _qwen3_hf_to_vllm + from huggingface_hub import snapshot_download logger: logging.Logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -# Run tests: pytest tests/integration_tests/test_policy_update.py::TestWeightSync:: - - -def convert_state_dict(saved_sd): - """ - Convert transformers state dict to vLLM format. - - Key conversions: - 1. Copy over directly mapped keys (down_proj, input_layernorm, etc.) - 2. Fuse QKV projections: combine q_proj, k_proj, v_proj into qkv_proj - 3. Fuse MLP projections: combine gate_proj and up_proj into gate_up_proj - """ - load_sd = {} - num_layers = 32 # For Llama-8B-3.1 - - # Copy over directly mapped keys - for k in saved_sd: - if any( - x in k - for x in [ - "down_proj", - "input_layernorm", - "post_attention_layernorm", - "o_proj", - "norm.weight", - "embed_tokens.weight", - "lm_head.weight", - ] - ): - load_sd[k] = saved_sd[k] - - # Fuse QKV and gate_up_proj - for i in range(num_layers): - prefix = f"model.layers.{i}." - - # QKV fusion - q = saved_sd[prefix + "self_attn.q_proj.weight"] - k = saved_sd[prefix + "self_attn.k_proj.weight"] - v = saved_sd[prefix + "self_attn.v_proj.weight"] - load_sd[prefix + "self_attn.qkv_proj.weight"] = torch.cat([q, k, v], dim=0) - - # MLP gate_up_proj fusion - gate = saved_sd[prefix + "mlp.gate_proj.weight"] - up = saved_sd[prefix + "mlp.up_proj.weight"] - load_sd[prefix + "mlp.gate_up_proj.weight"] = torch.cat([gate, up], dim=0) - - return load_sd - - -def calculate_expected_shard( - full_tensor: torch.Tensor, - param_name: str, - tensor_parallel_size: int, - rank: int, -) -> torch.Tensor: - """ - Calculate the expected shard of a full tensor for comparison with loaded tensor. - This is mainly used for validation in tests. - - Args: - full_tensor: The full tensor to shard - param_name: Name of the parameter (used to determine sharding strategy) - expected_shape: Expected shape of the sharded tensor - tensor_parallel_size: Number of tensor parallel ranks - rank: Current rank - - Returns: - torch.Tensor: The expected sharded tensor for this rank - """ - - sharding = VLLMSharding(tensor_parallel_size, rank) - shard_dim, is_sharded = sharding._get_tensor_parallel_sharding_strategy(param_name) - - if not is_sharded: - return full_tensor - - sharded_tensor = sharding._calculate_tensor_shard( - full_tensor, shard_dim, tensor_parallel_size, rank - ) - return sharded_tensor - - -def validate_loaded_tensors_equals_original( - loaded_state_dict: dict[str, torch.Tensor], - original_state_dict: dict[str, torch.Tensor], - tensor_parallel_size: int, - rank: int, -): - """ - Shared validation function to verify that every tensor loaded by the policy - equals the original tensor. 
- - For tensor parallel cases, instead of gathering sharded tensors, we shard - the original tensor and compare it with the loaded shard. - """ - for param_name, loaded_tensor in loaded_state_dict.items(): - if param_name in original_state_dict: - original_tensor = original_state_dict[param_name] - - if tensor_parallel_size > 1: - original_shard = calculate_expected_shard( - original_tensor, - param_name, - tensor_parallel_size, - rank, - ) - tensor_to_compare = original_shard.cpu().float() - else: - tensor_to_compare = original_tensor.cpu().float() - - # Training trainer emitted and loaded tensors are of type bfloat16, - # where as a HF model loaded(expected) tensor has type float16. - if not torch.allclose( - loaded_tensor.float(), - tensor_to_compare, - rtol=1e-2, - atol=1e-3, - ): - logger.warning( - f"Loaded tensor {param_name} does not equal original.\n" - f"dtype = {loaded_tensor.dtype} vs {original_tensor.dtype}\n" - f"shape= {loaded_tensor.shape} vs {original_tensor.shape}\n," - f"values = {loaded_tensor} vs {original_tensor}" - ) - raise ValueError( - f"Loaded tensor {param_name} does not equal original " - f"(shapes: loaded={loaded_tensor.shape}, expected={tensor_to_compare.shape})" - ) - else: - print(f"Loaded tensor {param_name} correctly validated") - - print( - f"Successfully validated that all {len(loaded_state_dict)} loaded tensors equal original" - ) +# Run tests: pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync:: def get_configs( @@ -188,15 +58,25 @@ def get_configs( return policy_config, service_config +class MockRLTrainer(RLTrainer): + @endpoint + async def mock_train_step(self): + """Mock train step. This simply multiplies the model weights by 0.1.""" + sd = self.engine.checkpointer.states["model"].state_dict() + sd, _ = flatten_state_dict(sd) + for name, param in sd.items(): + if not torch.is_floating_point(param): + continue + param.copy_(sd[name] * 0.1) + + class TestWeightSync: """Tests for weight sync between trainer and policy. Currently hardcoded to Qwen3-1.7B.""" model = "Qwen/Qwen3-1.7B" - to_vllm_fn: Callable = _qwen3_hf_to_vllm - num_layers = 28 @pytest_asyncio.fixture - def trainer_cfg(self): + async def trainer_cfg(self): cached_dir = snapshot_download(repo_id=self.model) return { "model": { @@ -212,7 +92,7 @@ def trainer_cfg(self): } @pytest_asyncio.fixture - def trainer_cfg_tp(self): + async def trainer_cfg_tp(self): # NB: TP size is set to 2. cached_dir = snapshot_download(repo_id=self.model) return { @@ -229,34 +109,30 @@ def trainer_cfg_tp(self): }, } - @pytest_asyncio.fixture - def expected_sd(self): - model = AutoModelForCausalLM.from_pretrained( - self.model, - dtype=torch.bfloat16, - trust_remote_code=True, - ) - original_state_dict = model.state_dict() - # Hack to access through class without passing in self param - return self.__class__.to_vllm_fn(original_state_dict, self.num_layers) - @pytest.mark.asyncio @requires_cuda - async def test_policy_update_single(self, expected_sd, trainer_cfg): + async def test_policy_update_single(self, trainer_cfg): """ - 1. Loads weights from HF model into in-memory state-dict (source of truth) - 2. Initializes RLTrainer, make the weights available in torchstore. - 3. Initializes Policy, and calls update_weights() to load weights from torchstore. - 4. Validate the policy weights against source of truth. + Test the weight synchronization process between RLTrainer and Policy. + + This test performs the following steps: + 1. 
Loads weights from a HuggingFace (HF) model into an in-memory state dictionary, serving as the source of truth. + 2. Initializes RLTrainer and applies a mock training step that multiplies all model weights by 0.1. + 3. Pushes the updated weights to torchstore. + 4. Initializes a Policy instance and calls update_weights() to load weights from torchstore. + 5. Validates that the policy's weights match the expected values (original weights multiplied by 0.1). """ worker_size = 1 # 1. Initialize TS await ts.initialize() # 2. Trainer push - rl_trainer = await RLTrainer.options( + rl_trainer = await MockRLTrainer.options( procs=worker_size, with_gpus=True, num_replicas=1 ).as_service(**trainer_cfg) + # Mock train step multiplies everything by 0.1 + await rl_trainer.mock_train_step.call() + await rl_trainer.push_weights.choose(policy_version=0) # 3. Policy pull weights policy_config, service_config = get_configs( @@ -265,21 +141,39 @@ async def test_policy_update_single(self, expected_sd, trainer_cfg): policy = await Policy.options(service_config=service_config).as_service( **policy_config ) - await policy.update_weights.call() - # 4. Validate weights - loaded_state_dict = await policy._get_model_params.choose() - validate_loaded_tensors_equals_original( - loaded_state_dict, expected_sd, tensor_parallel_size=1, rank=0 - ) + await policy._test_save_model_params.call() + await policy.update_weights.call(policy_version=0) + + # exceptions sometimes are not propogated in monarch, do it manually + def validate_fn(prev_params, curr_model) -> Exception | None: + try: + for name, param in curr_model.named_parameters(): + if not torch.is_floating_point(param): + continue + assert name in prev_params + assert torch.allclose(prev_params[name] * 0.1, param.cpu()) + except Exception as e: + return e + finally: + return None + + all_errs = await policy._test_validate_model_params.call(validate_fn) + for errs in all_errs: + for _, e in errs.items(): + if e: + raise e @pytest.mark.asyncio @requires_cuda - async def test_policy_update_tp(self, expected_sd, trainer_cfg_tp): + async def test_policy_update_tp(self, trainer_cfg_tp): """ - 1. Init RLTrainer over multiple workers with TP parallelism strategy. - 2. Push weights in to torchstore. - 3. Initializes Policy over multiple workers, and calls update_weights() to load weights from torchstore. - 4. Validate the policy weights against manually loaded origina HF weights. + Test the weight synchronization process between RLTrainer and Policy. + + This test performs the following steps: + - Initializes RLTrainer and applies a mock training step that multiplies all model weights by 0.1. + - Pushes the updated weights to torchstore. + - Initializes a Policy instance and calls update_weights() to load weights from torchstore. + - Validates that the policy's weights match the expected values (original weights multiplied by 0.1). """ # test configs/paralleism trainer_worker_size = 2 @@ -290,31 +184,50 @@ async def test_policy_update_tp(self, expected_sd, trainer_cfg_tp): pytest.skip( f"Only {torch.cuda.device_count()} GPU(s) available, need 2+ for tensor parallel" ) - # 1. Initialize TS + await ts.initialize() - # 2. Trainer push - rl_trainer = await RLTrainer.options( + + rl_trainer = await MockRLTrainer.options( procs=trainer_worker_size, with_gpus=True, num_replicas=1 ).as_service(**trainer_cfg_tp) await rl_trainer.push_weights.call(policy_version=0) - # 3. 
Policy pull weights + policy_config, service_config = get_configs( worker_size=policy_worker_size, tp_size=tp_size, model_name=self.model ) policy = await Policy.options(service_config=service_config).as_service( **policy_config ) - await policy.update_weights.call() - # validate loaded shard of each worker againt manually calculated shard (expected shard). + # Mock train step multiplies everything by 0.1 + await rl_trainer.mock_train_step.call() - # 4. Validate weight shards. We compare vLLM loades shard content with - # Directly loaded HF shard content. - sharded_state_dicts = await policy._get_model_params.call() - validate_loaded_tensors_equals_original( - sharded_state_dicts[0][0], expected_sd, tensor_parallel_size=tp_size, rank=0 - ) - validate_loaded_tensors_equals_original( - sharded_state_dicts[0][1], expected_sd, tensor_parallel_size=tp_size, rank=1 - ) + await rl_trainer.push_weights.choose(policy_version=0) + + await policy._test_save_model_params.call() + await policy.update_weights.call(policy_version=0) + + # exceptions sometimes are not propogated in monarch, do it manually + def validate_fn(prev_params, curr_model, logger) -> Exception | None: + verified = set() + try: + for name, param in curr_model.named_parameters(): + if not torch.is_floating_point(param): + continue + assert name in prev_params + assert torch.allclose(prev_params[name] * 0.1, param.cpu()) + verified.add(name) + except Exception as e: + return e + finally: + logger.info( + f"Successfully verified {len(verified)} parameters: {verified}" + ) + return None + + all_errs = await policy._test_validate_model_params.call(validate_fn) + for errs in all_errs: + for _, e in errs.items(): + if e: + raise e From d5ba56fd793b1f678506b2608d70d31d6d278a1f Mon Sep 17 00:00:00 2001 From: Yuxuan Hu <57782783+casteryh@users.noreply.github.com> Date: Mon, 22 Sep 2025 17:06:26 -0700 Subject: [PATCH 03/11] logging --- tests/integration_tests/test_policy_update.py | 67 ++++++++++--------- 1 file changed, 34 insertions(+), 33 deletions(-) diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py index 7f8a42c2f..fb809f434 100644 --- a/tests/integration_tests/test_policy_update.py +++ b/tests/integration_tests/test_policy_update.py @@ -17,6 +17,8 @@ from forge.actors.trainer import RLTrainer from forge.controller.service import ServiceConfig + +from forge.controller.service.service import uuid from forge.data.sharding import VLLMSharding from monarch.actor import endpoint @@ -64,10 +66,36 @@ async def mock_train_step(self): """Mock train step. 
This simply multiplies the model weights by 0.1.""" sd = self.engine.checkpointer.states["model"].state_dict() sd, _ = flatten_state_dict(sd) - for name, param in sd.items(): + for _, param in sd.items(): if not torch.is_floating_point(param): continue - param.copy_(sd[name] * 0.1) + param.mul_(0.1) + + +# exceptions sometimes are not propogated in monarch, do it manually +def validate_fn(prev_params, curr_model, logger) -> Exception | None: + verified = set() + skipped = set() + try: + params = curr_model.named_parameters() + logger.info( + f"Validating model params, # of params {len(params)}, all params {params}" + ) + for name, param in curr_model.named_parameters(): + if not torch.is_floating_point(param): + logger.info(f"Skipping non-float param {name}") + skipped.add(name) + continue + assert name in prev_params + assert torch.allclose(prev_params[name] * 0.1, param.cpu()) + verified.add(name) + except Exception as e: + return e + finally: + logger.info( + f"Skipped non-float parameters: {skipped}. Successfully verified {len(verified)} parameters: {verified}" + ) + return None class TestWeightSync: @@ -144,25 +172,14 @@ async def test_policy_update_single(self, trainer_cfg): await policy._test_save_model_params.call() await policy.update_weights.call(policy_version=0) - # exceptions sometimes are not propogated in monarch, do it manually - def validate_fn(prev_params, curr_model) -> Exception | None: - try: - for name, param in curr_model.named_parameters(): - if not torch.is_floating_point(param): - continue - assert name in prev_params - assert torch.allclose(prev_params[name] * 0.1, param.cpu()) - except Exception as e: - return e - finally: - return None - all_errs = await policy._test_validate_model_params.call(validate_fn) for errs in all_errs: for _, e in errs.items(): if e: raise e + await ts.shutdown() + @pytest.mark.asyncio @requires_cuda async def test_policy_update_tp(self, trainer_cfg_tp): @@ -208,26 +225,10 @@ async def test_policy_update_tp(self, trainer_cfg_tp): await policy._test_save_model_params.call() await policy.update_weights.call(policy_version=0) - # exceptions sometimes are not propogated in monarch, do it manually - def validate_fn(prev_params, curr_model, logger) -> Exception | None: - verified = set() - try: - for name, param in curr_model.named_parameters(): - if not torch.is_floating_point(param): - continue - assert name in prev_params - assert torch.allclose(prev_params[name] * 0.1, param.cpu()) - verified.add(name) - except Exception as e: - return e - finally: - logger.info( - f"Successfully verified {len(verified)} parameters: {verified}" - ) - return None - all_errs = await policy._test_validate_model_params.call(validate_fn) for errs in all_errs: for _, e in errs.items(): if e: raise e + + await ts.shutdown() From ee7eea192c8743308a6ae559f176f17ff245a416 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu <57782783+casteryh@users.noreply.github.com> Date: Mon, 22 Sep 2025 18:30:16 -0700 Subject: [PATCH 04/11] a --- src/forge/actors/policy.py | 25 +++-- tests/integration_tests/test_policy_update.py | 91 ++++++++++--------- 2 files changed, 66 insertions(+), 50 deletions(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 8e9f63d16..3f4bc468b 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -18,14 +18,6 @@ import torch import torch.distributed.checkpoint as dcp import torchstore as ts -from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh - -from forge.data.sharding import VLLMSharding 
-from forge.data_models.completion import Completion -from forge.data_models.prompt import to_prompt - -from forge.interfaces import Policy as PolicyInterface -from forge.types import ProcessConfig from monarch.actor import current_rank, endpoint, ProcMesh from torchstore.state_dict_utils import DELIM from vllm.config import VllmConfig @@ -50,6 +42,15 @@ from vllm.v1.structured_output import StructuredOutputManager from vllm.worker.worker_base import WorkerWrapperBase +from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh + +from forge.data.sharding import VLLMSharding +from forge.data_models.completion import Completion +from forge.data_models.prompt import to_prompt + +from forge.interfaces import Policy as PolicyInterface +from forge.types import ProcessConfig + logger: logging.Logger = logging.getLogger(__name__) @@ -547,12 +548,18 @@ async def _test_save_model_params(self): ) for name, param in self.worker.model_runner.model.named_parameters(): self._test_prev_params[name] = param.detach().cpu() + logger.info( + "[PolicyWorker] finished saving model parameters, len = %d", + len(self._test_prev_params), + ) @endpoint async def _test_validate_model_params(self, validate_fn): """Validate updated model params using validate_fn.""" logger.info("[PolicyWorker] start validating model parameters post update") - return validate_fn(self._test_prev_params, self.worker.model_runner.model, logger) + return validate_fn( + self._test_prev_params, self.worker.model_runner.model, logger + ) def setup_worker(self): """Build and Instantiate vLLM worker""" diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py index fb809f434..b1ba6a403 100644 --- a/tests/integration_tests/test_policy_update.py +++ b/tests/integration_tests/test_policy_update.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import asyncio import logging from typing import Callable @@ -66,8 +67,12 @@ async def mock_train_step(self): """Mock train step. This simply multiplies the model weights by 0.1.""" sd = self.engine.checkpointer.states["model"].state_dict() sd, _ = flatten_state_dict(sd) + logger.info(f"[MockRLTrainer] mock_train_step(): sd = {sd}") for _, param in sd.items(): if not torch.is_floating_point(param): + logger.info( + f"[MockRLTrainer] mock_train_step(): skipping non-float param {param}" + ) continue param.mul_(0.1) @@ -76,26 +81,23 @@ async def mock_train_step(self): def validate_fn(prev_params, curr_model, logger) -> Exception | None: verified = set() skipped = set() + logger.info( + f"Validating model params, all named_parameters() = {curr_model.named_parameters()}" + ) try: - params = curr_model.named_parameters() - logger.info( - f"Validating model params, # of params {len(params)}, all params {params}" - ) for name, param in curr_model.named_parameters(): if not torch.is_floating_point(param): logger.info(f"Skipping non-float param {name}") skipped.add(name) continue - assert name in prev_params - assert torch.allclose(prev_params[name] * 0.1, param.cpu()) + assert name in prev_params, f"Param {name} not found in prev_params" + assert torch.allclose( + prev_params[name] * 0.1, param.cpu() + ), f"current param {name} does not match expected value" verified.add(name) except Exception as e: + logger.error(f"Validation failed with exception: {e}") return e - finally: - logger.info( - f"Skipped non-float parameters: {skipped}. 
Successfully verified {len(verified)} parameters: {verified}" - ) - return None class TestWeightSync: @@ -150,33 +152,39 @@ async def test_policy_update_single(self, trainer_cfg): 4. Initializes a Policy instance and calls update_weights() to load weights from torchstore. 5. Validates that the policy's weights match the expected values (original weights multiplied by 0.1). """ - worker_size = 1 - # 1. Initialize TS - await ts.initialize() - # 2. Trainer push - rl_trainer = await MockRLTrainer.options( - procs=worker_size, with_gpus=True, num_replicas=1 - ).as_service(**trainer_cfg) + trainer_worker_size = 2 + policy_worker_size = 1 + tp_size = 1 - # Mock train step multiplies everything by 0.1 - await rl_trainer.mock_train_step.call() + await ts.initialize() - await rl_trainer.push_weights.choose(policy_version=0) - # 3. Policy pull weights policy_config, service_config = get_configs( - worker_size=worker_size, tp_size=worker_size, model_name=self.model + worker_size=policy_worker_size, tp_size=tp_size, model_name=self.model ) - policy = await Policy.options(service_config=service_config).as_service( - **policy_config + policy, rl_trainer = await asyncio.gather( + *[ + Policy.options(service_config=service_config).as_service( + **policy_config + ), + MockRLTrainer.options( + procs=trainer_worker_size, with_gpus=True, num_replicas=1 + ).as_service(**trainer_cfg), + ] ) + + policy_version = uuid.uuid4().int + + # Mock train step multiplies everything by 0.1 + await rl_trainer.mock_train_step.call() + + await rl_trainer.push_weights.call(policy_version=policy_version) await policy._test_save_model_params.call() - await policy.update_weights.call(policy_version=0) + await policy.update_weights.call(policy_version=policy_version) all_errs = await policy._test_validate_model_params.call(validate_fn) for errs in all_errs: for _, e in errs.items(): - if e: - raise e + assert not e, f"Validation failed with exception: {e}" await ts.shutdown() @@ -204,31 +212,32 @@ async def test_policy_update_tp(self, trainer_cfg_tp): await ts.initialize() - rl_trainer = await MockRLTrainer.options( - procs=trainer_worker_size, with_gpus=True, num_replicas=1 - ).as_service(**trainer_cfg_tp) - - await rl_trainer.push_weights.call(policy_version=0) - policy_config, service_config = get_configs( worker_size=policy_worker_size, tp_size=tp_size, model_name=self.model ) - policy = await Policy.options(service_config=service_config).as_service( - **policy_config + policy, rl_trainer = await asyncio.gather( + *[ + Policy.options(service_config=service_config).as_service( + **policy_config + ), + MockRLTrainer.options( + procs=trainer_worker_size, with_gpus=True, num_replicas=1 + ).as_service(**trainer_cfg_tp), + ] ) + policy_version = uuid.uuid4().int + # Mock train step multiplies everything by 0.1 await rl_trainer.mock_train_step.call() - await rl_trainer.push_weights.choose(policy_version=0) - + await rl_trainer.push_weights.call(policy_version=policy_version) await policy._test_save_model_params.call() - await policy.update_weights.call(policy_version=0) + await policy.update_weights.call(policy_version=policy_version) all_errs = await policy._test_validate_model_params.call(validate_fn) for errs in all_errs: for _, e in errs.items(): - if e: - raise e + assert not e, f"Validation failed with exception: {e}" await ts.shutdown() From 368286bc4f8480722d45a4fb9a70a60b12e57f30 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu <57782783+casteryh@users.noreply.github.com> Date: Mon, 22 Sep 2025 19:46:17 -0700 Subject: [PATCH 05/11] 
better diagnostics --- tests/integration_tests/test_policy_update.py | 62 +++++++++++++------ 1 file changed, 43 insertions(+), 19 deletions(-) diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py index b1ba6a403..558349a0e 100644 --- a/tests/integration_tests/test_policy_update.py +++ b/tests/integration_tests/test_policy_update.py @@ -39,6 +39,8 @@ # Run tests: pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync:: +TEST_MULTIPLIER = 1.5 + def get_configs( worker_size: int, tp_size: int, model_name: str @@ -64,7 +66,17 @@ def get_configs( class MockRLTrainer(RLTrainer): @endpoint async def mock_train_step(self): - """Mock train step. This simply multiplies the model weights by 0.1.""" + """Mock train step. This simply multiplies the model weights by TEST_MULTIPLIER""" + self.engine.optimizers.step() + self.engine.optimizers.zero_grad() + self.engine.lr_schedulers.step() + + self.current_step += 1 + self.engine.checkpointer.save( + curr_step=self.current_step, + last_step=self.current_step == self.num_training_steps, + ) + sd = self.engine.checkpointer.states["model"].state_dict() sd, _ = flatten_state_dict(sd) logger.info(f"[MockRLTrainer] mock_train_step(): sd = {sd}") @@ -74,7 +86,7 @@ async def mock_train_step(self): f"[MockRLTrainer] mock_train_step(): skipping non-float param {param}" ) continue - param.mul_(0.1) + param *= 1.5 # exceptions sometimes are not propogated in monarch, do it manually @@ -84,20 +96,32 @@ def validate_fn(prev_params, curr_model, logger) -> Exception | None: logger.info( f"Validating model params, all named_parameters() = {curr_model.named_parameters()}" ) - try: - for name, param in curr_model.named_parameters(): - if not torch.is_floating_point(param): - logger.info(f"Skipping non-float param {name}") - skipped.add(name) - continue + errs = [] + for name, param in curr_model.named_parameters(): + if not torch.is_floating_point(param): + logger.info(f"Skipping non-float param {name}") + skipped.add(name) + continue + try: assert name in prev_params, f"Param {name} not found in prev_params" assert torch.allclose( - prev_params[name] * 0.1, param.cpu() - ), f"current param {name} does not match expected value" + prev_params[name] * TEST_MULTIPLIER, param.cpu(), atol=1e-3, rtol=1e-2 + ), ( + f"current param {name} does not match expected value; " + f"previous param ({prev_params[name].size()})= {prev_params[name]}; " + f"expected = {prev_params[name] * TEST_MULTIPLIER} vs got = {param.cpu().size()} {param.cpu()}" + ) verified.add(name) - except Exception as e: - logger.error(f"Validation failed with exception: {e}") - return e + except Exception as e: + # logger.error(f"Validation failed with exception: {e}") + errs.append((name, e)) + logger.info(f"Verified params = {verified}") + logger.info(f"Skipped params = {skipped}") + if errs: + logger.error( + f"Validation failed for the following params: {[e[0] for e in errs]}" + ) + return AssertionError(f"Validation failed: {errs}") class TestWeightSync: @@ -147,10 +171,10 @@ async def test_policy_update_single(self, trainer_cfg): This test performs the following steps: 1. Loads weights from a HuggingFace (HF) model into an in-memory state dictionary, serving as the source of truth. - 2. Initializes RLTrainer and applies a mock training step that multiplies all model weights by 0.1. + 2. Initializes RLTrainer and applies a mock training step that multiplies all model weights by TEST_MULTIPLIER. 3. Pushes the updated weights to torchstore. 4. 
Initializes a Policy instance and calls update_weights() to load weights from torchstore. - 5. Validates that the policy's weights match the expected values (original weights multiplied by 0.1). + 5. Validates that the policy's weights match the expected values (original weights multiplied by TEST_MULTIPLIER). """ trainer_worker_size = 2 policy_worker_size = 1 @@ -174,7 +198,7 @@ async def test_policy_update_single(self, trainer_cfg): policy_version = uuid.uuid4().int - # Mock train step multiplies everything by 0.1 + # Mock train step multiplies everything by TEST_MULTIPLIER await rl_trainer.mock_train_step.call() await rl_trainer.push_weights.call(policy_version=policy_version) @@ -195,10 +219,10 @@ async def test_policy_update_tp(self, trainer_cfg_tp): Test the weight synchronization process between RLTrainer and Policy. This test performs the following steps: - - Initializes RLTrainer and applies a mock training step that multiplies all model weights by 0.1. + - Initializes RLTrainer and applies a mock training step that multiplies all model weights by TEST_MULTIPLIER. - Pushes the updated weights to torchstore. - Initializes a Policy instance and calls update_weights() to load weights from torchstore. - - Validates that the policy's weights match the expected values (original weights multiplied by 0.1). + - Validates that the policy's weights match the expected values (original weights multiplied by TEST_MULTIPLIER). """ # test configs/paralleism trainer_worker_size = 2 @@ -228,7 +252,7 @@ async def test_policy_update_tp(self, trainer_cfg_tp): policy_version = uuid.uuid4().int - # Mock train step multiplies everything by 0.1 + # Mock train step multiplies everything by TEST_MULTIPLIER await rl_trainer.mock_train_step.call() await rl_trainer.push_weights.call(policy_version=policy_version) From 0f2d8a1f5c83835608c31e85219efe8abe2f26d1 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu <57782783+casteryh@users.noreply.github.com> Date: Tue, 23 Sep 2025 03:21:10 -0700 Subject: [PATCH 06/11] avoid dtensor yoga, setting everything to zero --- src/forge/actors/trainer.py | 6 +- tests/integration_tests/test_policy_update.py | 131 +++++++++++------- 2 files changed, 84 insertions(+), 53 deletions(-) diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index 5dffbac1e..01da7d1cf 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -15,6 +15,9 @@ import torch.distributed.checkpoint as dcp import torchstore as ts +from forge.controller import ForgeActor +from forge.data.utils import batch_to_device + from monarch.actor import current_rank, current_size, endpoint from torch import Tensor from torch.distributed.checkpoint._nested_dict import flatten_state_dict @@ -34,9 +37,6 @@ from torchtitan.experiments.forge.engine import ForgeEngine from torchtitan.experiments.forge.job_config import ForgeJobConfig -from forge.controller import ForgeActor -from forge.data.utils import batch_to_device - @dataclass class RLTrainer(ForgeActor): diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py index 558349a0e..52c9f96e3 100644 --- a/tests/integration_tests/test_policy_update.py +++ b/tests/integration_tests/test_policy_update.py @@ -13,17 +13,16 @@ import torch import torchstore as ts - from forge.actors.policy import EngineConfig, Policy, SamplingConfig from forge.actors.trainer import RLTrainer from forge.controller.service import ServiceConfig from forge.controller.service.service import uuid -from forge.data.sharding 
import VLLMSharding -from monarch.actor import endpoint +from monarch.actor import current_rank, endpoint from torch.distributed.checkpoint._nested_dict import flatten_state_dict +from torch.distributed.tensor import DTensor, Replicate requires_cuda = pytest.mark.skipif( @@ -39,8 +38,6 @@ # Run tests: pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync:: -TEST_MULTIPLIER = 1.5 - def get_configs( worker_size: int, tp_size: int, model_name: str @@ -66,31 +63,21 @@ def get_configs( class MockRLTrainer(RLTrainer): @endpoint async def mock_train_step(self): - """Mock train step. This simply multiplies the model weights by TEST_MULTIPLIER""" - self.engine.optimizers.step() - self.engine.optimizers.zero_grad() - self.engine.lr_schedulers.step() - - self.current_step += 1 - self.engine.checkpointer.save( - curr_step=self.current_step, - last_step=self.current_step == self.num_training_steps, - ) - - sd = self.engine.checkpointer.states["model"].state_dict() - sd, _ = flatten_state_dict(sd) - logger.info(f"[MockRLTrainer] mock_train_step(): sd = {sd}") - for _, param in sd.items(): - if not torch.is_floating_point(param): - logger.info( - f"[MockRLTrainer] mock_train_step(): skipping non-float param {param}" - ) - continue - param *= 1.5 + """Mock train step. This simply sets all model weights to zero.""" + for model_part in self.engine.model_parts: + sd = model_part.state_dict() + for k in sd.keys(): + if not torch.is_floating_point(sd[k]): + logger.info( + f"[MockRLTrainer] mock_train_step(): skipping non-float param {k}" + ) + continue + sd[k] *= 0.0 # exceptions sometimes are not propogated in monarch, do it manually def validate_fn(prev_params, curr_model, logger) -> Exception | None: + """Validate that current parameters are the same as prev_params.""" verified = set() skipped = set() logger.info( @@ -105,11 +92,11 @@ def validate_fn(prev_params, curr_model, logger) -> Exception | None: try: assert name in prev_params, f"Param {name} not found in prev_params" assert torch.allclose( - prev_params[name] * TEST_MULTIPLIER, param.cpu(), atol=1e-3, rtol=1e-2 + prev_params[name], param.cpu(), atol=1e-3, rtol=1e-2 ), ( f"current param {name} does not match expected value; " f"previous param ({prev_params[name].size()})= {prev_params[name]}; " - f"expected = {prev_params[name] * TEST_MULTIPLIER} vs got = {param.cpu().size()} {param.cpu()}" + f"expected = {prev_params[name]} vs got = {param.cpu().size()} {param.cpu()}" ) verified.add(name) except Exception as e: @@ -124,6 +111,39 @@ def validate_fn(prev_params, curr_model, logger) -> Exception | None: return AssertionError(f"Validation failed: {errs}") +# exceptions sometimes are not propogated in monarch, do it manually +def validate_fn_all_zeros(prev_params, curr_model, logger) -> Exception | None: + """Validate all parameters are set to zero. prev_params is actually not used.""" + _ = prev_params + verified = set() + skipped = set() + logger.info( + f"Validating model params, all named_parameters() = {curr_model.named_parameters()}" + ) + errs = [] + for name, param in curr_model.named_parameters(): + if not torch.is_floating_point(param): + logger.info(f"Skipping non-float param {name}") + skipped.add(name) + continue + try: + param = param.cpu() + assert torch.allclose( + torch.zeros_like(param), param, atol=1e-4, rtol=1e-3 + ), "param {name} is not zero." 
+ verified.add(name) + except Exception as e: + # logger.error(f"Validation failed with exception: {e}") + errs.append((name, e)) + logger.info(f"Verified params = {verified}") + logger.info(f"Skipped params = {skipped}") + if errs: + logger.error( + f"Validation failed for the following params: {[e[0] for e in errs]}" + ) + return AssertionError(f"Validation failed: {errs}") + + class TestWeightSync: """Tests for weight sync between trainer and policy. Currently hardcoded to Qwen3-1.7B.""" @@ -170,13 +190,12 @@ async def test_policy_update_single(self, trainer_cfg): Test the weight synchronization process between RLTrainer and Policy. This test performs the following steps: - 1. Loads weights from a HuggingFace (HF) model into an in-memory state dictionary, serving as the source of truth. - 2. Initializes RLTrainer and applies a mock training step that multiplies all model weights by TEST_MULTIPLIER. - 3. Pushes the updated weights to torchstore. - 4. Initializes a Policy instance and calls update_weights() to load weights from torchstore. - 5. Validates that the policy's weights match the expected values (original weights multiplied by TEST_MULTIPLIER). + - Initialize trainer and push weights v0 (original huggingface ckpt) + - Step the trainer, setting all weights to zero and push weights v1 + - Load weights v0 and check the policy has all zero weights + - Load weights v1 and check the policy has all the weights back """ - trainer_worker_size = 2 + trainer_worker_size = 1 policy_worker_size = 1 tp_size = 1 @@ -196,15 +215,21 @@ async def test_policy_update_single(self, trainer_cfg): ] ) - policy_version = uuid.uuid4().int + v0 = uuid.uuid4().int + v1 = v0 + 1 - # Mock train step multiplies everything by TEST_MULTIPLIER + await rl_trainer.push_weights.call(policy_version=v0) + # Setting everything to zero await rl_trainer.mock_train_step.call() - - await rl_trainer.push_weights.call(policy_version=policy_version) + await rl_trainer.push_weights.call(policy_version=v1) await policy._test_save_model_params.call() - await policy.update_weights.call(policy_version=policy_version) - + await policy.update_weights.call(policy_version=v1) + all_errs = await policy._test_validate_model_params.call(validate_fn_all_zeros) + for errs in all_errs: + for _, e in errs.items(): + assert not e, f"Validation failed with exception: {e}" + # Reloading v0, getting back original weights + await policy.update_weights.call(policy_version=v0) all_errs = await policy._test_validate_model_params.call(validate_fn) for errs in all_errs: for _, e in errs.items(): @@ -219,10 +244,10 @@ async def test_policy_update_tp(self, trainer_cfg_tp): Test the weight synchronization process between RLTrainer and Policy. This test performs the following steps: - - Initializes RLTrainer and applies a mock training step that multiplies all model weights by TEST_MULTIPLIER. - - Pushes the updated weights to torchstore. - - Initializes a Policy instance and calls update_weights() to load weights from torchstore. - - Validates that the policy's weights match the expected values (original weights multiplied by TEST_MULTIPLIER). 
+ - Initialize trainer and push weights v0 (original huggingface ckpt) + - Step the trainer, setting all weights to zero and push weights v1 + - Load weights v0 and check the policy has all zero weights + - Load weights v1 and check the policy has all the weights back """ # test configs/paralleism trainer_worker_size = 2 @@ -250,15 +275,21 @@ async def test_policy_update_tp(self, trainer_cfg_tp): ] ) - policy_version = uuid.uuid4().int + v0 = uuid.uuid4().int + v1 = v0 + 1 - # Mock train step multiplies everything by TEST_MULTIPLIER + await rl_trainer.push_weights.call(policy_version=v0) + # Setting everything to zero await rl_trainer.mock_train_step.call() - - await rl_trainer.push_weights.call(policy_version=policy_version) + await rl_trainer.push_weights.call(policy_version=v1) await policy._test_save_model_params.call() - await policy.update_weights.call(policy_version=policy_version) - + await policy.update_weights.call(policy_version=v1) + all_errs = await policy._test_validate_model_params.call(validate_fn_all_zeros) + for errs in all_errs: + for _, e in errs.items(): + assert not e, f"Validation failed with exception: {e}" + # Reloading v0, getting back original weights + await policy.update_weights.call(policy_version=v0) all_errs = await policy._test_validate_model_params.call(validate_fn) for errs in all_errs: for _, e in errs.items(): From 092bcba3dce5b9d0466374fa711f6ca5e6caaa89 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu <57782783+casteryh@users.noreply.github.com> Date: Tue, 23 Sep 2025 10:59:42 -0700 Subject: [PATCH 07/11] add sanity check --- tests/integration_tests/test_policy_update.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py index 52c9f96e3..99ba8a2b7 100644 --- a/tests/integration_tests/test_policy_update.py +++ b/tests/integration_tests/test_policy_update.py @@ -223,11 +223,19 @@ async def test_policy_update_single(self, trainer_cfg): await rl_trainer.mock_train_step.call() await rl_trainer.push_weights.call(policy_version=v1) await policy._test_save_model_params.call() + + # Sanity check that before update all the tests pass + all_errs = await policy._test_validate_model_params.call(validate_fn) + for errs in all_errs: + for _, e in errs.items(): + assert not e, f"Validation failed with exception: {e}" + await policy.update_weights.call(policy_version=v1) all_errs = await policy._test_validate_model_params.call(validate_fn_all_zeros) for errs in all_errs: for _, e in errs.items(): assert not e, f"Validation failed with exception: {e}" + # Reloading v0, getting back original weights await policy.update_weights.call(policy_version=v0) all_errs = await policy._test_validate_model_params.call(validate_fn) @@ -283,6 +291,13 @@ async def test_policy_update_tp(self, trainer_cfg_tp): await rl_trainer.mock_train_step.call() await rl_trainer.push_weights.call(policy_version=v1) await policy._test_save_model_params.call() + + # Sanity check that before update all the tests pass + all_errs = await policy._test_validate_model_params.call(validate_fn) + for errs in all_errs: + for _, e in errs.items(): + assert not e, f"Validation failed with exception: {e}" + await policy.update_weights.call(policy_version=v1) all_errs = await policy._test_validate_model_params.call(validate_fn_all_zeros) for errs in all_errs: From 6b67be91f2ff43a280329504289ed9f441294194 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu <57782783+casteryh@users.noreply.github.com> Date: Tue, 23 Sep 2025 
12:44:04 -0700 Subject: [PATCH 08/11] modify trainer.py --- src/forge/actors/trainer.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index f6ffe3af7..ecd466e70 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -16,6 +16,14 @@ import torch import torch.distributed.checkpoint as dcp import torchstore as ts +from forge.actors.torchstore_utils import ( + extract_param_name, + get_param_key, + get_param_prefix, +) + +from forge.controller import ForgeActor +from forge.data.utils import batch_to_device from monarch.actor import current_rank, current_size, endpoint from torch import Tensor @@ -36,9 +44,6 @@ from torchtitan.experiments.forge.engine import ForgeEngine from torchtitan.experiments.forge.job_config import ForgeJobConfig -from forge.controller import ForgeActor -from forge.data.utils import batch_to_device - logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -290,6 +295,15 @@ async def push_weights(self, policy_version: int) -> None: logger.debug(f"Pushed weights to {key} in {end_time - start_time:.2f} seconds") + @endpoint + async def push_weights_hf_nonsharded( + self, hf_state_dict, policy_version: int + ) -> None: + """Push weights to torchstore in HF format, non-sharded.""" + for name, param in hf_state_dict.items(): + key = get_param_key(policy_version, name) + await ts.put(key, param) + @endpoint async def cleanup(self) -> None: if self.engine.checkpointer: From a6c7aef9da82a7d32418f8398396a3382c86812b Mon Sep 17 00:00:00 2001 From: Yuxuan Hu <57782783+casteryh@users.noreply.github.com> Date: Tue, 23 Sep 2025 12:44:52 -0700 Subject: [PATCH 09/11] modify trainer.py --- src/forge/actors/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index ecd466e70..523e33ef4 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -16,6 +16,7 @@ import torch import torch.distributed.checkpoint as dcp import torchstore as ts + from forge.actors.torchstore_utils import ( extract_param_name, get_param_key, From d46e18961dedc6fc0e21a07ee2f8077ac1f29126 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu <57782783+casteryh@users.noreply.github.com> Date: Tue, 23 Sep 2025 13:12:28 -0700 Subject: [PATCH 10/11] fix merge conflicts --- src/forge/actors/policy.py | 35 ++++++++----------- tests/integration_tests/test_policy_update.py | 8 ++--- 2 files changed, 18 insertions(+), 25 deletions(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index e026e778e..c63bc5939 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -7,10 +7,6 @@ from __future__ import annotations import asyncio -<<<<<<< HEAD -======= - ->>>>>>> main import logging import os import sys @@ -22,6 +18,20 @@ import torch import torch.distributed.checkpoint as dcp import torchstore as ts + +from forge.actors.torchstore_utils import ( + extract_param_name, + get_param_key, + get_param_prefix, +) + +from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh +from forge.data.sharding import VLLMSharding +from forge.data_models.completion import Completion +from forge.data_models.prompt import to_prompt + +from forge.interfaces import Policy as PolicyInterface +from forge.types import ProcessConfig from monarch.actor import current_rank, endpoint, ProcMesh from vllm.config import VllmConfig @@ -45,26 +55,9 @@ from vllm.v1.structured_output import 
StructuredOutputManager from vllm.worker.worker_base import WorkerWrapperBase -from forge.actors.torchstore_utils import ( - extract_param_name, - get_param_key, - get_param_prefix, -) - -from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh -from forge.data.sharding import VLLMSharding -from forge.data_models.completion import Completion -from forge.data_models.prompt import to_prompt - -from forge.interfaces import Policy as PolicyInterface -from forge.types import ProcessConfig - -<<<<<<< HEAD logger: logging.Logger = logging.getLogger(__name__) -======= logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) ->>>>>>> main @dataclass diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py index 99ba8a2b7..d38520714 100644 --- a/tests/integration_tests/test_policy_update.py +++ b/tests/integration_tests/test_policy_update.py @@ -218,10 +218,10 @@ async def test_policy_update_single(self, trainer_cfg): v0 = uuid.uuid4().int v1 = v0 + 1 - await rl_trainer.push_weights.call(policy_version=v0) + await rl_trainer.push_weights_hf_nonsharded.call(policy_version=v0) # Setting everything to zero await rl_trainer.mock_train_step.call() - await rl_trainer.push_weights.call(policy_version=v1) + await rl_trainer.push_weights_hf_nonsharded.call(policy_version=v1) await policy._test_save_model_params.call() # Sanity check that before update all the tests pass @@ -286,10 +286,10 @@ async def test_policy_update_tp(self, trainer_cfg_tp): v0 = uuid.uuid4().int v1 = v0 + 1 - await rl_trainer.push_weights.call(policy_version=v0) + await rl_trainer.push_weights_hf_nonsharded.call(policy_version=v0) # Setting everything to zero await rl_trainer.mock_train_step.call() - await rl_trainer.push_weights.call(policy_version=v1) + await rl_trainer.push_weights_hf_nonsharded.call(policy_version=v1) await policy._test_save_model_params.call() # Sanity check that before update all the tests pass From 7ec461b84ed58365971a615906839be49591e5c4 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu <57782783+casteryh@users.noreply.github.com> Date: Tue, 23 Sep 2025 13:17:50 -0700 Subject: [PATCH 11/11] fix --- src/forge/actors/trainer.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index 523e33ef4..014d94037 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -297,10 +297,18 @@ async def push_weights(self, policy_version: int) -> None: logger.debug(f"Pushed weights to {key} in {end_time - start_time:.2f} seconds") @endpoint - async def push_weights_hf_nonsharded( - self, hf_state_dict, policy_version: int - ) -> None: + async def push_weights_hf_nonsharded(self, policy_version: int) -> None: """Push weights to torchstore in HF format, non-sharded.""" + if "model" not in self.engine.checkpointer.states: + raise RuntimeError("Model state not found in checkpointer state") + + sd = self.engine.checkpointer.states["model"].state_dict() + flattened_state_dict, _ = flatten_state_dict(sd) + if self.engine.checkpointer.sd_adapter is None: + raise RuntimeError( + "Trying to save checkpoint in HF safetensors format, but sd_adapter is not provided." + ) + hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict) for name, param in hf_state_dict.items(): key = get_param_key(policy_version, name) await ts.put(key, param)
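
Net effect of the series: the monolithic put_state_dict/manual-sharding path is replaced by a per-parameter round trip through torchstore, keyed by policy version and HF parameter name, with vLLM's load_weights consuming one (name, tensor) pair at a time. Below is a minimal sketch of that flow, with a plain dict standing in for the torchstore put/get/keys API and a stub in place of the vLLM model (both hypothetical stand-ins); the key helpers are copied verbatim from src/forge/actors/torchstore_utils.py.

# Key helpers copied verbatim from src/forge/actors/torchstore_utils.py.
KEY_DELIM = "."


def get_param_prefix(policy_version: int) -> str:
    return f"policy_ver_{policy_version}"


def get_param_key(policy_version: int, name: str) -> str:
    return f"policy_ver_{policy_version}{KEY_DELIM}{name}"


def extract_param_name(key: str) -> str:
    # Works even though HF names contain dots: the version prefix is the
    # first delimited field, and everything after it is rejoined.
    return KEY_DELIM.join(key.split(KEY_DELIM)[1:])


# Hypothetical in-memory stand-in for torchstore's ts.put/ts.get/ts.keys.
store: dict[str, object] = {}


def push_weights(hf_state_dict: dict, version: int) -> None:
    # Trainer side: one put per parameter, keyed by version + HF name.
    for name, param in hf_state_dict.items():
        store[get_param_key(version, name)] = param


def update_weights(load_weights, version: int) -> int:
    # Policy-worker side: list keys by prefix, then feed (name, tensor)
    # pairs to load_weights one at a time, since it is not async.
    prefix = get_param_prefix(version)
    matching_keys = [k for k in store if k.startswith(prefix + KEY_DELIM)]
    loaded = set()
    for key in matching_keys:
        name = extract_param_name(key)
        loaded.update(load_weights([(name, store[key])]))
    return len(loaded)


def fake_load_weights(pairs):
    # Stub with the same call shape as vLLM's model.load_weights.
    return [name for name, _ in pairs]


push_weights({"model.embed_tokens.weight": 0.0, "lm_head.weight": 0.0}, version=1)
assert update_weights(fake_load_weights, version=1) == 2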