From 15bf741c4a187fc0b58469e977b016308712a96b Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Wed, 24 Sep 2025 14:50:29 -0400 Subject: [PATCH 01/12] demonstrate the sharding bug when trainer export the weights --- src/forge/actors/trainer.py | 35 ++++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index f0f5d328a..de8a3e164 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -17,6 +17,9 @@ import torch.distributed.checkpoint as dcp import torchstore as ts +from forge.controller import ForgeActor +from forge.data.utils import batch_to_device + from monarch.actor import current_rank, current_size, endpoint from torch import Tensor from torch.distributed.checkpoint._nested_dict import flatten_state_dict @@ -36,9 +39,6 @@ from torchtitan.experiments.forge.engine import ForgeEngine from torchtitan.experiments.forge.job_config import ForgeJobConfig -from forge.controller import ForgeActor -from forge.data.utils import batch_to_device - logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -343,13 +343,29 @@ def unwrap(t): ): load_sd[k] = sd[k] + # Suppose tp = 4 for policy for illustration. + policy_tp = 4 + for i in range(num_layers): prefix = f"model.layers.{i}." # QKV fusion q = sd[prefix + "self_attn.q_proj.weight"] k = sd[prefix + "self_attn.k_proj.weight"] v = sd[prefix + "self_attn.v_proj.weight"] - load_sd[prefix + "self_attn.qkv_proj.weight"] = torch.cat([q, k, v], dim=0) + + q_shards = torch.chunk(q, policy_tp, dim=0) + k_shards = torch.chunk(k, policy_tp, dim=0) + v_shards = torch.chunk(v, policy_tp, dim=0) + + # Concatenate each corresponding shard (q_shard_i, k_shard_i, v_shard_i) + combined_shards = [] + for i in range(policy_tp): + combined_shard = torch.cat([q_shards[i], k_shards[i], v_shards[i]], dim=0) + combined_shards.append(combined_shard) + + load_sd[prefix + "self_attn.qkv_proj.weight"] = torch.cat( + combined_shards, dim=0 + ) # QKV fusion - handle bias if present q_bias_key = prefix + "self_attn.q_proj.bias" @@ -360,6 +376,7 @@ def unwrap(t): q_bias = sd[q_bias_key] k_bias = sd[k_bias_key] v_bias = sd[v_bias_key] + # Same sharding has to happen here load_sd[prefix + "self_attn.qkv_proj.bias"] = torch.cat( [q_bias, k_bias, v_bias], dim=0 ) @@ -367,7 +384,14 @@ def unwrap(t): # MLP gate_up_proj fusion gate = sd[prefix + "mlp.gate_proj.weight"] up = sd[prefix + "mlp.up_proj.weight"] - load_sd[prefix + "mlp.gate_up_proj.weight"] = torch.cat([gate, up], dim=0) + gate_shards = torch.chunk(gate, policy_tp, dim=0) + up_shards = torch.chunk(up, policy_tp, dim=0) + + combined_shards = [] + for i in range(policy_tp): + combined_shard = torch.cat([gate_shards[i], up_shards[i]], dim=0) + combined_shards.append(combined_shard) + load_sd[prefix + "mlp.gate_up_proj.weight"] = torch.cat(combined_shards, dim=0) # MLP gate_up_proj fusion - handle bias if present gate_bias_key = prefix + "mlp.gate_proj.bias" @@ -376,6 +400,7 @@ def unwrap(t): if all(key in sd for key in [gate_bias_key, up_bias_key]): gate_bias = sd[gate_bias_key] up_bias = sd[up_bias_key] + # Same sharding has to happen here load_sd[prefix + "mlp.gate_up_proj.bias"] = torch.cat( [gate_bias, up_bias], dim=0 ) From 656aca373eabf22ce6bc708e49672b0782857d5c Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Thu, 25 Sep 2025 22:26:28 -0400 Subject: [PATCH 02/12] fix attempt --- src/forge/actors/trainer.py | 72 ++++++++++--------- 
tests/integration_tests/test_policy_update.py | 8 +-- 2 files changed, 43 insertions(+), 37 deletions(-) diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index de8a3e164..bc0b81e48 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -100,6 +100,7 @@ class RLTrainer(ForgeActor): state_dict_key: str = "model_state_dict" use_dcp: bool = True + def __post_init__(self): """Initializes config types and env variables. @@ -252,7 +253,7 @@ def train_step( return loss.item() @endpoint - async def push_weights(self, policy_version: int) -> None: + async def push_weights(self, policy_version: int, tp_DEPRECATED: int = 1) -> None: # Save to torchstore. Hacking in to the Checkpointer's prepped state-dict for now. start_time = time.perf_counter() # TODO: @@ -271,7 +272,8 @@ async def push_weights(self, policy_version: int) -> None: hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict) # TODO: Figure out how to gracefully handle which model to-vLLM conversion is needed vllm_ready_hf_sd = _qwen3_hf_to_vllm( - sd=hf_state_dict, num_layers=self.engine.model_args.n_layers + sd=hf_state_dict, num_layers=self.engine.model_args.n_layers, + tp=tp_DEPRECATED ) conversion_time = time.perf_counter() key = f"{self.state_dict_key}{DELIM}{policy_version}" @@ -307,7 +309,29 @@ async def cleanup(self) -> None: self.engine.checkpointer.close() -def _qwen3_hf_to_vllm(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]: +def _shard_and_concat(sources: list[torch.Tensor], dim: int, tp: int) -> torch.Tensor: + """Shard and concatenate tensors along a given dimension. + + Args: + source (list[torch.Tensor]): List of tensors to shard and concatenate. + dim (int): Dimension along which to shard and concatenate. + tp (int): Number of tensor parallel groups. + + Returns: + torch.Tensor: Concatenated tensor. + """ + sharded_sources = [] + for source in sources: + sharded_sources.append(torch.chunk(source, tp, dim=dim)) + + combined_shards = [] + for shard_idx in range(tp): + combined= torch.cat([s[shard_idx] for s in sharded_sources], dim=dim) + combined_shards.append(combined) + return torch.cat(combined_shards, dim=dim) + + +def _qwen3_hf_to_vllm(sd: dict[str, torch.Tensor], num_layers: int, tp: int) -> dict[str, torch.Tensor]: """Convert transformers state dict to vLLM format. Specifically, this fuses QKV projection and MLP gate_up_proj layers. @@ -343,9 +367,6 @@ def unwrap(t): ): load_sd[k] = sd[k] - # Suppose tp = 4 for policy for illustration. - policy_tp = 4 - for i in range(num_layers): prefix = f"model.layers.{i}." 
# QKV fusion @@ -353,21 +374,11 @@ def unwrap(t): k = sd[prefix + "self_attn.k_proj.weight"] v = sd[prefix + "self_attn.v_proj.weight"] - q_shards = torch.chunk(q, policy_tp, dim=0) - k_shards = torch.chunk(k, policy_tp, dim=0) - v_shards = torch.chunk(v, policy_tp, dim=0) - - # Concatenate each corresponding shard (q_shard_i, k_shard_i, v_shard_i) - combined_shards = [] - for i in range(policy_tp): - combined_shard = torch.cat([q_shards[i], k_shards[i], v_shards[i]], dim=0) - combined_shards.append(combined_shard) - - load_sd[prefix + "self_attn.qkv_proj.weight"] = torch.cat( - combined_shards, dim=0 + load_sd[prefix + "self_attn.qkv_proj.weight"] = _shard_and_concat( + [q, k, v], dim=0, tp=tp ) - # QKV fusion - handle bias if present + # Untested: QKV fusion - handle bias if present q_bias_key = prefix + "self_attn.q_proj.bias" k_bias_key = prefix + "self_attn.k_proj.bias" v_bias_key = prefix + "self_attn.v_proj.bias" @@ -376,24 +387,18 @@ def unwrap(t): q_bias = sd[q_bias_key] k_bias = sd[k_bias_key] v_bias = sd[v_bias_key] - # Same sharding has to happen here - load_sd[prefix + "self_attn.qkv_proj.bias"] = torch.cat( - [q_bias, k_bias, v_bias], dim=0 + load_sd[prefix + "self_attn.qkv_proj.bias"] = _shard_and_concat( + [q_bias, k_bias, v_bias], dim=0, tp=tp ) # MLP gate_up_proj fusion gate = sd[prefix + "mlp.gate_proj.weight"] up = sd[prefix + "mlp.up_proj.weight"] - gate_shards = torch.chunk(gate, policy_tp, dim=0) - up_shards = torch.chunk(up, policy_tp, dim=0) - - combined_shards = [] - for i in range(policy_tp): - combined_shard = torch.cat([gate_shards[i], up_shards[i]], dim=0) - combined_shards.append(combined_shard) - load_sd[prefix + "mlp.gate_up_proj.weight"] = torch.cat(combined_shards, dim=0) + load_sd[prefix + "mlp.gate_up_proj.weight"] = _shard_and_concat( + [gate, up], dim=0, tp=tp + ) - # MLP gate_up_proj fusion - handle bias if present + # Untested: MLP gate_up_proj fusion - handle bias if present gate_bias_key = prefix + "mlp.gate_proj.bias" up_bias_key = prefix + "mlp.up_proj.bias" @@ -401,8 +406,9 @@ def unwrap(t): gate_bias = sd[gate_bias_key] up_bias = sd[up_bias_key] # Same sharding has to happen here - load_sd[prefix + "mlp.gate_up_proj.bias"] = torch.cat( - [gate_bias, up_bias], dim=0 + load_sd[prefix + "mlp.gate_up_proj.bias"] = _shard_and_concat( + [gate_bias, up_bias], dim=0, tp=tp ) + return load_sd diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py index 543956877..28192e5de 100644 --- a/tests/integration_tests/test_policy_update.py +++ b/tests/integration_tests/test_policy_update.py @@ -213,10 +213,10 @@ async def test_policy_update_single(self, trainer_cfg): v0 = uuid.uuid4().int v1 = v0 + 1 - await rl_trainer.push_weights.fanout(policy_version=v0) + await rl_trainer.push_weights.fanout(policy_version=v0, tp_DEPRECATED=tp_size) # Setting everything to zero await rl_trainer.zero_out_model_states.fanout() - await rl_trainer.push_weights.fanout(policy_version=v1) + await rl_trainer.push_weights.fanout(policy_version=v1, tp_DEPRECATED=tp_size) await policy._test_save_model_params.fanout() # Sanity check that before update all the tests pass @@ -281,10 +281,10 @@ async def test_policy_update_tp(self, trainer_cfg_tp): v0 = uuid.uuid4().int v1 = v0 + 1 - await rl_trainer.push_weights.fanout(policy_version=v0) + await rl_trainer.push_weights.fanout(policy_version=v0, tp_DEPRECATED=tp_size) # Setting everything to zero await rl_trainer.zero_out_model_states.fanout() - await 
rl_trainer.push_weights.fanout(policy_version=v1) + await rl_trainer.push_weights.fanout(policy_version=v1, tp_DEPRECATED=tp_size) await policy._test_save_model_params.fanout() # Sanity check that before update all the tests pass From a7bc8c3ca605874cec681e69ba579ce5ed085aa8 Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Thu, 25 Sep 2025 23:01:04 -0400 Subject: [PATCH 03/12] update main grpo app --- apps/grpo/main.py | 3 ++- src/forge/actors/trainer.py | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 05922b644..fb8e92812 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -243,6 +243,7 @@ async def main(cfg: DictConfig): group_size = cfg.group_size max_req_tokens = cfg.max_req_tokens max_res_tokens = cfg.max_res_tokens + policy_tp_size = cfg.policy.engine_config.tensor_parallel_size mlogger = get_metric_logger( "wandb", freq=1, @@ -356,7 +357,7 @@ async def continuous_training(): loss = await trainer.train_step.route(inputs, targets) training_step += 1 mlogger.log("loss/training_step", loss, training_step) - await trainer.push_weights.fanout(training_step) + await trainer.push_weights.fanout(training_step, tp_DEPRECATED=policy_tp_size) await policy.update_weights.fanout(training_step) print("Starting GRPO training loops...") diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index bc0b81e48..1ed8299d3 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -100,7 +100,6 @@ class RLTrainer(ForgeActor): state_dict_key: str = "model_state_dict" use_dcp: bool = True - def __post_init__(self): """Initializes config types and env variables. From 60caca7b4c20561e22f1f770aa017740605279c0 Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Fri, 26 Sep 2025 08:59:44 -0400 Subject: [PATCH 04/12] nit --- apps/grpo/qwen3_1_7b.yaml | 1 + src/forge/actors/trainer.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index 5c0481528..b121e4eaa 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -111,6 +111,7 @@ services: trainer: procs: 1 num_replicas: 1 + hosts: 1 with_gpus: true replay_buffer: procs: 1 diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index 1ed8299d3..6d2d66622 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -299,7 +299,7 @@ async def push_weights(self, policy_version: int, tp_DEPRECATED: int = 1) -> Non end_time = time.perf_counter() logger.info( f"Completed weights push to {key} in {end_time - start_time:.2f} seconds " - f"(to_vllm: {conversion_time - start_time:.2f}s, tranport time: {end_time - conversion_time:.2f})" + f"(hg to vllm conversion: {conversion_time - start_time:.2f}s, tranport time: {end_time - conversion_time:.2f})" ) @endpoint From 9547532039800c36f9c13b1d0707b469d0b55a36 Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Fri, 26 Sep 2025 09:13:59 -0400 Subject: [PATCH 05/12] format --- src/forge/actors/trainer.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index 6d2d66622..4d38fe8d4 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -17,9 +17,6 @@ import torch.distributed.checkpoint as dcp import torchstore as ts -from forge.controller import ForgeActor -from forge.data.utils import batch_to_device - from monarch.actor import current_rank, current_size, 
endpoint from torch import Tensor from torch.distributed.checkpoint._nested_dict import flatten_state_dict @@ -39,6 +36,9 @@ from torchtitan.experiments.forge.engine import ForgeEngine from torchtitan.experiments.forge.job_config import ForgeJobConfig +from forge.controller import ForgeActor +from forge.data.utils import batch_to_device + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -271,8 +271,9 @@ async def push_weights(self, policy_version: int, tp_DEPRECATED: int = 1) -> Non hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict) # TODO: Figure out how to gracefully handle which model to-vLLM conversion is needed vllm_ready_hf_sd = _qwen3_hf_to_vllm( - sd=hf_state_dict, num_layers=self.engine.model_args.n_layers, - tp=tp_DEPRECATED + sd=hf_state_dict, + num_layers=self.engine.model_args.n_layers, + tp=tp_DEPRECATED, ) conversion_time = time.perf_counter() key = f"{self.state_dict_key}{DELIM}{policy_version}" @@ -325,12 +326,14 @@ def _shard_and_concat(sources: list[torch.Tensor], dim: int, tp: int) -> torch.T combined_shards = [] for shard_idx in range(tp): - combined= torch.cat([s[shard_idx] for s in sharded_sources], dim=dim) + combined = torch.cat([s[shard_idx] for s in sharded_sources], dim=dim) combined_shards.append(combined) return torch.cat(combined_shards, dim=dim) -def _qwen3_hf_to_vllm(sd: dict[str, torch.Tensor], num_layers: int, tp: int) -> dict[str, torch.Tensor]: +def _qwen3_hf_to_vllm( + sd: dict[str, torch.Tensor], num_layers: int, tp: int +) -> dict[str, torch.Tensor]: """Convert transformers state dict to vLLM format. Specifically, this fuses QKV projection and MLP gate_up_proj layers. @@ -409,5 +412,4 @@ def unwrap(t): [gate_bias, up_bias], dim=0, tp=tp ) - return load_sd From d81f3ba17090923f7e2e5daa93de34e7fa7207f0 Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Fri, 26 Sep 2025 09:25:06 -0400 Subject: [PATCH 06/12] nit --- apps/grpo/main.py | 3 ++- src/forge/actors/trainer.py | 6 +++--- tests/integration_tests/test_policy_update.py | 8 ++++---- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index fb8e92812..003191645 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -243,6 +243,7 @@ async def main(cfg: DictConfig): group_size = cfg.group_size max_req_tokens = cfg.max_req_tokens max_res_tokens = cfg.max_res_tokens + # TODO: this is a temporary; delete it after we are confident on the vllm weight sync long term fix PR #184 policy_tp_size = cfg.policy.engine_config.tensor_parallel_size mlogger = get_metric_logger( "wandb", @@ -357,7 +358,7 @@ async def continuous_training(): loss = await trainer.train_step.route(inputs, targets) training_step += 1 mlogger.log("loss/training_step", loss, training_step) - await trainer.push_weights.fanout(training_step, tp_DEPRECATED=policy_tp_size) + await trainer.push_weights.fanout(training_step, vllm_tp_DEPRECATED=policy_tp_size) await policy.update_weights.fanout(training_step) print("Starting GRPO training loops...") diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index 4d38fe8d4..cd370657c 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -252,7 +252,7 @@ def train_step( return loss.item() @endpoint - async def push_weights(self, policy_version: int, tp_DEPRECATED: int = 1) -> None: + async def push_weights(self, policy_version: int, vllm_tp_DEPRECATED: int = 1) -> None: # Save to torchstore. Hacking in to the Checkpointer's prepped state-dict for now. 
start_time = time.perf_counter() # TODO: @@ -273,7 +273,7 @@ async def push_weights(self, policy_version: int, tp_DEPRECATED: int = 1) -> Non vllm_ready_hf_sd = _qwen3_hf_to_vllm( sd=hf_state_dict, num_layers=self.engine.model_args.n_layers, - tp=tp_DEPRECATED, + vllm_tp=vllm_tp_DEPRECATED, ) conversion_time = time.perf_counter() key = f"{self.state_dict_key}{DELIM}{policy_version}" @@ -332,7 +332,7 @@ def _shard_and_concat(sources: list[torch.Tensor], dim: int, tp: int) -> torch.T def _qwen3_hf_to_vllm( - sd: dict[str, torch.Tensor], num_layers: int, tp: int + sd: dict[str, torch.Tensor], num_layers: int, vllm_tp: int ) -> dict[str, torch.Tensor]: """Convert transformers state dict to vLLM format. Specifically, this fuses QKV projection and MLP gate_up_proj layers. diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py index 28192e5de..51ee8bbf3 100644 --- a/tests/integration_tests/test_policy_update.py +++ b/tests/integration_tests/test_policy_update.py @@ -213,10 +213,10 @@ async def test_policy_update_single(self, trainer_cfg): v0 = uuid.uuid4().int v1 = v0 + 1 - await rl_trainer.push_weights.fanout(policy_version=v0, tp_DEPRECATED=tp_size) + await rl_trainer.push_weights.fanout(policy_version=v0, vllm_tp_DEPRECATED=tp_size) # Setting everything to zero await rl_trainer.zero_out_model_states.fanout() - await rl_trainer.push_weights.fanout(policy_version=v1, tp_DEPRECATED=tp_size) + await rl_trainer.push_weights.fanout(policy_version=v1, vllm_tp_DEPRECATED=tp_size) await policy._test_save_model_params.fanout() # Sanity check that before update all the tests pass @@ -281,10 +281,10 @@ async def test_policy_update_tp(self, trainer_cfg_tp): v0 = uuid.uuid4().int v1 = v0 + 1 - await rl_trainer.push_weights.fanout(policy_version=v0, tp_DEPRECATED=tp_size) + await rl_trainer.push_weights.fanout(policy_version=v0, vllm_tp_DEPRECATED=tp_size) # Setting everything to zero await rl_trainer.zero_out_model_states.fanout() - await rl_trainer.push_weights.fanout(policy_version=v1, tp_DEPRECATED=tp_size) + await rl_trainer.push_weights.fanout(policy_version=v1, vllm_tp_DEPRECATED=tp_size) await policy._test_save_model_params.fanout() # Sanity check that before update all the tests pass From 763b5402140da442e20ccb35963f26cdd41eed88 Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Fri, 26 Sep 2025 09:27:53 -0400 Subject: [PATCH 07/12] fix --- src/forge/actors/trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index cd370657c..d0b9ab994 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -377,7 +377,7 @@ def unwrap(t): v = sd[prefix + "self_attn.v_proj.weight"] load_sd[prefix + "self_attn.qkv_proj.weight"] = _shard_and_concat( - [q, k, v], dim=0, tp=tp + [q, k, v], dim=0, tp=vllm_tp ) # Untested: QKV fusion - handle bias if present @@ -390,14 +390,14 @@ def unwrap(t): k_bias = sd[k_bias_key] v_bias = sd[v_bias_key] load_sd[prefix + "self_attn.qkv_proj.bias"] = _shard_and_concat( - [q_bias, k_bias, v_bias], dim=0, tp=tp + [q_bias, k_bias, v_bias], dim=0, tp=vllm_tp ) # MLP gate_up_proj fusion gate = sd[prefix + "mlp.gate_proj.weight"] up = sd[prefix + "mlp.up_proj.weight"] load_sd[prefix + "mlp.gate_up_proj.weight"] = _shard_and_concat( - [gate, up], dim=0, tp=tp + [gate, up], dim=0, tp=vllm_tp ) # Untested: MLP gate_up_proj fusion - handle bias if present @@ -409,7 +409,7 @@ def unwrap(t): up_bias = sd[up_bias_key] # Same 
sharding has to happen here load_sd[prefix + "mlp.gate_up_proj.bias"] = _shard_and_concat( - [gate_bias, up_bias], dim=0, tp=tp + [gate_bias, up_bias], dim=0, tp=vllm_tp ) return load_sd From 5fbfd40a54aef711940fedaea13712fd4cd3cdc5 Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Fri, 26 Sep 2025 09:33:12 -0400 Subject: [PATCH 08/12] nit --- apps/grpo/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 003191645..431ade155 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -243,7 +243,7 @@ async def main(cfg: DictConfig): group_size = cfg.group_size max_req_tokens = cfg.max_req_tokens max_res_tokens = cfg.max_res_tokens - # TODO: this is a temporary; delete it after we are confident on the vllm weight sync long term fix PR #184 + # TODO: delete this logic after we are confident on the vllm weight sync long term fix PR #184 policy_tp_size = cfg.policy.engine_config.tensor_parallel_size mlogger = get_metric_logger( "wandb", From 31cf5531b995d021075ce4e9119e4cd65193eb74 Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Fri, 26 Sep 2025 09:34:18 -0400 Subject: [PATCH 09/12] nit --- apps/grpo/qwen3_1_7b.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index b121e4eaa..5c0481528 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -111,7 +111,6 @@ services: trainer: procs: 1 num_replicas: 1 - hosts: 1 with_gpus: true replay_buffer: procs: 1 From 12f809f5c1404c397111f049403a5232f015f006 Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Fri, 26 Sep 2025 09:42:16 -0400 Subject: [PATCH 10/12] another callsites --- apps/toy_rl/sumdigits.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/apps/toy_rl/sumdigits.py b/apps/toy_rl/sumdigits.py index c77b1f60e..fa1f3e554 100644 --- a/apps/toy_rl/sumdigits.py +++ b/apps/toy_rl/sumdigits.py @@ -338,11 +338,11 @@ def train_step(self, episodes: list[Episode]) -> float: return loss.item() @endpoint - async def push_weights(self, version: int): + async def push_weights(self, version: int, vllm_tp_DEPRECATED: int) -> None: """Update policy model weights with trainer's current weights.""" key = f"{self.state_dict_key}{DELIM}{version}" # Use version as unique id new_sd = _qwen3_hf_to_vllm( - self.model.state_dict(), num_layers=self.model.config.num_hidden_layers + self.model.state_dict(), num_layers=self.model.config.num_hidden_layers, vllm_tp=vllm_tp_DEPRECATED ) start_time = time.time() await ts.put_state_dict(new_sd, key) @@ -433,6 +433,8 @@ async def main(cfg: DictConfig): group_size = cfg.group_size max_req_tokens = cfg.max_req_tokens max_res_tokens = cfg.max_res_tokens + # TODO: delete this logic after we are confident on the vllm weight sync long term fix PR #184 + policy_tp_size = cfg.policy.engine_config.tensor_parallel_size mlogger = get_metric_logger( "wandb", freq=1, @@ -520,7 +522,7 @@ async def continuous_training(): training_step += 1 mlogger.log("loss/training_step", loss, training_step) print(f"loss/training_step: {loss} at training step {training_step}") - await trainer.push_weights.fanout(training_step) + await trainer.push_weights.fanout(training_step, vllm_tp_DEPRECATED=policy_tp_size) await policy.update_weights.fanout(training_step) # NOTE: hard-coded to be on-policy for faster convergence await replay_buffer.clear.fanout() From 4f6950462ea549114c9beb4ecae244961a763c43 Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Fri, 
26 Sep 2025 09:47:15 -0400 Subject: [PATCH 11/12] format --- apps/toy_rl/sumdigits.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/apps/toy_rl/sumdigits.py b/apps/toy_rl/sumdigits.py index fa1f3e554..cef2bd229 100644 --- a/apps/toy_rl/sumdigits.py +++ b/apps/toy_rl/sumdigits.py @@ -342,7 +342,9 @@ async def push_weights(self, version: int, vllm_tp_DEPRECATED: int) -> None: """Update policy model weights with trainer's current weights.""" key = f"{self.state_dict_key}{DELIM}{version}" # Use version as unique id new_sd = _qwen3_hf_to_vllm( - self.model.state_dict(), num_layers=self.model.config.num_hidden_layers, vllm_tp=vllm_tp_DEPRECATED + self.model.state_dict(), + num_layers=self.model.config.num_hidden_layers, + vllm_tp=vllm_tp_DEPRECATED, ) start_time = time.time() await ts.put_state_dict(new_sd, key) @@ -522,7 +524,9 @@ async def continuous_training(): training_step += 1 mlogger.log("loss/training_step", loss, training_step) print(f"loss/training_step: {loss} at training step {training_step}") - await trainer.push_weights.fanout(training_step, vllm_tp_DEPRECATED=policy_tp_size) + await trainer.push_weights.fanout( + training_step, vllm_tp_DEPRECATED=policy_tp_size + ) await policy.update_weights.fanout(training_step) # NOTE: hard-coded to be on-policy for faster convergence await replay_buffer.clear.fanout() From 133f37230208ffdd0c4f7c53d8115498affd6ee6 Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Fri, 26 Sep 2025 11:33:21 -0400 Subject: [PATCH 12/12] format --- apps/grpo/main.py | 4 +++- src/forge/actors/trainer.py | 4 +++- tests/integration_tests/test_policy_update.py | 16 ++++++++++++---- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 431ade155..ac25fd52f 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -358,7 +358,9 @@ async def continuous_training(): loss = await trainer.train_step.route(inputs, targets) training_step += 1 mlogger.log("loss/training_step", loss, training_step) - await trainer.push_weights.fanout(training_step, vllm_tp_DEPRECATED=policy_tp_size) + await trainer.push_weights.fanout( + training_step, vllm_tp_DEPRECATED=policy_tp_size + ) await policy.update_weights.fanout(training_step) print("Starting GRPO training loops...") diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index d0b9ab994..2e1bc48e9 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -252,7 +252,9 @@ def train_step( return loss.item() @endpoint - async def push_weights(self, policy_version: int, vllm_tp_DEPRECATED: int = 1) -> None: + async def push_weights( + self, policy_version: int, vllm_tp_DEPRECATED: int = 1 + ) -> None: # Save to torchstore. Hacking in to the Checkpointer's prepped state-dict for now. 
start_time = time.perf_counter() # TODO: diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py index 51ee8bbf3..e77c33e7b 100644 --- a/tests/integration_tests/test_policy_update.py +++ b/tests/integration_tests/test_policy_update.py @@ -213,10 +213,14 @@ async def test_policy_update_single(self, trainer_cfg): v0 = uuid.uuid4().int v1 = v0 + 1 - await rl_trainer.push_weights.fanout(policy_version=v0, vllm_tp_DEPRECATED=tp_size) + await rl_trainer.push_weights.fanout( + policy_version=v0, vllm_tp_DEPRECATED=tp_size + ) # Setting everything to zero await rl_trainer.zero_out_model_states.fanout() - await rl_trainer.push_weights.fanout(policy_version=v1, vllm_tp_DEPRECATED=tp_size) + await rl_trainer.push_weights.fanout( + policy_version=v1, vllm_tp_DEPRECATED=tp_size + ) await policy._test_save_model_params.fanout() # Sanity check that before update all the tests pass @@ -281,10 +285,14 @@ async def test_policy_update_tp(self, trainer_cfg_tp): v0 = uuid.uuid4().int v1 = v0 + 1 - await rl_trainer.push_weights.fanout(policy_version=v0, vllm_tp_DEPRECATED=tp_size) + await rl_trainer.push_weights.fanout( + policy_version=v0, vllm_tp_DEPRECATED=tp_size + ) # Setting everything to zero await rl_trainer.zero_out_model_states.fanout() - await rl_trainer.push_weights.fanout(policy_version=v1, vllm_tp_DEPRECATED=tp_size) + await rl_trainer.push_weights.fanout( + policy_version=v1, vllm_tp_DEPRECATED=tp_size + ) await policy._test_save_model_params.fanout() # Sanity check that before update all the tests pass