diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 05922b644..ac25fd52f 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -243,6 +243,8 @@ async def main(cfg: DictConfig):
     group_size = cfg.group_size
     max_req_tokens = cfg.max_req_tokens
     max_res_tokens = cfg.max_res_tokens
+    # TODO: delete this logic once we are confident in the vLLM weight sync long-term fix (PR #184)
+    policy_tp_size = cfg.policy.engine_config.tensor_parallel_size
     mlogger = get_metric_logger(
         "wandb",
         freq=1,
@@ -356,7 +358,9 @@ async def continuous_training():
                 loss = await trainer.train_step.route(inputs, targets)
                 training_step += 1
                 mlogger.log("loss/training_step", loss, training_step)
-                await trainer.push_weights.fanout(training_step)
+                await trainer.push_weights.fanout(
+                    training_step, vllm_tp_DEPRECATED=policy_tp_size
+                )
                 await policy.update_weights.fanout(training_step)
 
     print("Starting GRPO training loops...")
diff --git a/apps/toy_rl/sumdigits.py b/apps/toy_rl/sumdigits.py
index c77b1f60e..cef2bd229 100644
--- a/apps/toy_rl/sumdigits.py
+++ b/apps/toy_rl/sumdigits.py
@@ -338,11 +338,13 @@ def train_step(self, episodes: list[Episode]) -> float:
         return loss.item()
 
     @endpoint
-    async def push_weights(self, version: int):
+    async def push_weights(self, version: int, vllm_tp_DEPRECATED: int) -> None:
         """Update policy model weights with trainer's current weights."""
         key = f"{self.state_dict_key}{DELIM}{version}"  # Use version as unique id
         new_sd = _qwen3_hf_to_vllm(
-            self.model.state_dict(), num_layers=self.model.config.num_hidden_layers
+            self.model.state_dict(),
+            num_layers=self.model.config.num_hidden_layers,
+            vllm_tp=vllm_tp_DEPRECATED,
         )
         start_time = time.time()
         await ts.put_state_dict(new_sd, key)
@@ -433,6 +435,8 @@ async def main(cfg: DictConfig):
     group_size = cfg.group_size
     max_req_tokens = cfg.max_req_tokens
     max_res_tokens = cfg.max_res_tokens
+    # TODO: delete this logic once we are confident in the vLLM weight sync long-term fix (PR #184)
+    policy_tp_size = cfg.policy.engine_config.tensor_parallel_size
     mlogger = get_metric_logger(
         "wandb",
         freq=1,
@@ -520,7 +524,9 @@ async def continuous_training():
                 training_step += 1
                 mlogger.log("loss/training_step", loss, training_step)
                 print(f"loss/training_step: {loss} at training step {training_step}")
-                await trainer.push_weights.fanout(training_step)
+                await trainer.push_weights.fanout(
+                    training_step, vllm_tp_DEPRECATED=policy_tp_size
+                )
                 await policy.update_weights.fanout(training_step)
                 # NOTE: hard-coded to be on-policy for faster convergence
                 await replay_buffer.clear.fanout()
diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index f0f5d328a..2e1bc48e9 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -252,7 +252,9 @@ def train_step(
         return loss.item()
 
     @endpoint
-    async def push_weights(self, policy_version: int) -> None:
+    async def push_weights(
+        self, policy_version: int, vllm_tp_DEPRECATED: int = 1
+    ) -> None:
         # Save to torchstore. Hacking in to the Checkpointer's prepped state-dict for now.
         start_time = time.perf_counter()
         # TODO:
@@ -271,7 +273,9 @@ async def push_weights(self, policy_version: int) -> None:
         hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
         # TODO: Figure out how to gracefully handle which model to-vLLM conversion is needed
         vllm_ready_hf_sd = _qwen3_hf_to_vllm(
-            sd=hf_state_dict, num_layers=self.engine.model_args.n_layers
+            sd=hf_state_dict,
+            num_layers=self.engine.model_args.n_layers,
+            vllm_tp=vllm_tp_DEPRECATED,
         )
         conversion_time = time.perf_counter()
         key = f"{self.state_dict_key}{DELIM}{policy_version}"
@@ -298,7 +302,7 @@ async def push_weights(self, policy_version: int) -> None:
         end_time = time.perf_counter()
         logger.info(
             f"Completed weights push to {key} in {end_time - start_time:.2f} seconds "
-            f"(to_vllm: {conversion_time - start_time:.2f}s, tranport time: {end_time - conversion_time:.2f})"
+            f"(hf to vllm conversion: {conversion_time - start_time:.2f}s, transport time: {end_time - conversion_time:.2f}s)"
         )
 
     @endpoint
@@ -307,7 +311,31 @@ async def cleanup(self) -> None:
         self.engine.checkpointer.close()
 
 
-def _qwen3_hf_to_vllm(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]:
+def _shard_and_concat(sources: list[torch.Tensor], dim: int, tp: int) -> torch.Tensor:
+    """Shard and concatenate tensors along a given dimension.
+
+    Args:
+        sources (list[torch.Tensor]): List of tensors to shard and concatenate.
+        dim (int): Dimension along which to shard and concatenate.
+        tp (int): Tensor parallel degree (number of ranks).
+
+    Returns:
+        torch.Tensor: Fused tensor laid out so that each tensor parallel rank's
+            contiguous slice contains that rank's shard of every source tensor.
+    """
+    sharded_sources = []
+    for source in sources:
+        sharded_sources.append(torch.chunk(source, tp, dim=dim))
+
+    combined_shards = []
+    for shard_idx in range(tp):
+        combined = torch.cat([s[shard_idx] for s in sharded_sources], dim=dim)
+        combined_shards.append(combined)
+    return torch.cat(combined_shards, dim=dim)
+
+
+def _qwen3_hf_to_vllm(
+    sd: dict[str, torch.Tensor], num_layers: int, vllm_tp: int
+) -> dict[str, torch.Tensor]:
     """Convert transformers state dict to vLLM format. Specifically, this fuses
     QKV projection and MLP gate_up_proj layers.
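
Review note (not part of the diff): the reason push_weights now needs the policy's tensor_parallel_size is that vLLM's fused qkv_proj and gate_up_proj weights are column-parallel, so each TP rank loads one contiguous slice of the fused tensor and expects that slice to contain its own q/k/v (or gate/up) shards. A plain torch.cat would hand rank 0 nothing but the top rows of q. The sketch below is a hypothetical sanity check of that layout; it mirrors the _shard_and_concat helper above with purely illustrative shapes and is not part of the PR.

    # Hypothetical sanity check (not part of the PR); shapes are illustrative only.
    import torch

    def shard_and_concat(sources, dim, tp):
        # Mirrors _shard_and_concat above: chunk each source per rank, then
        # concatenate rank-by-rank so each rank's slice holds its own shards.
        sharded = [torch.chunk(s, tp, dim=dim) for s in sources]
        return torch.cat(
            [torch.cat([s[rank] for s in sharded], dim=dim) for rank in range(tp)],
            dim=dim,
        )

    tp = 2
    q = torch.randn(8, 4)  # q_proj.weight (more rows: more query heads)
    k = torch.randn(4, 4)  # k_proj.weight
    v = torch.randn(4, 4)  # v_proj.weight

    fused = shard_and_concat([q, k, v], dim=0, tp=tp)
    rows_per_rank = fused.shape[0] // tp
    for rank in range(tp):
        # The slice a column-parallel rank would load from the fused qkv weight...
        rank_slice = fused[rank * rows_per_rank : (rank + 1) * rows_per_rank]
        # ...should equal that rank's q, k, and v shards stacked together.
        expected = torch.cat([t.chunk(tp, dim=0)[rank] for t in (q, k, v)], dim=0)
        assert torch.equal(rank_slice, expected)
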
@@ -349,9 +377,12 @@ def unwrap(t):
         q = sd[prefix + "self_attn.q_proj.weight"]
         k = sd[prefix + "self_attn.k_proj.weight"]
         v = sd[prefix + "self_attn.v_proj.weight"]
-        load_sd[prefix + "self_attn.qkv_proj.weight"] = torch.cat([q, k, v], dim=0)
-        # QKV fusion - handle bias if present
+        load_sd[prefix + "self_attn.qkv_proj.weight"] = _shard_and_concat(
+            [q, k, v], dim=0, tp=vllm_tp
+        )
+
+        # Untested: QKV fusion - handle bias if present
         q_bias_key = prefix + "self_attn.q_proj.bias"
         k_bias_key = prefix + "self_attn.k_proj.bias"
         v_bias_key = prefix + "self_attn.v_proj.bias"
@@ -360,24 +391,27 @@ def unwrap(t):
             q_bias = sd[q_bias_key]
             k_bias = sd[k_bias_key]
             v_bias = sd[v_bias_key]
-            load_sd[prefix + "self_attn.qkv_proj.bias"] = torch.cat(
-                [q_bias, k_bias, v_bias], dim=0
+            load_sd[prefix + "self_attn.qkv_proj.bias"] = _shard_and_concat(
+                [q_bias, k_bias, v_bias], dim=0, tp=vllm_tp
             )
 
         # MLP gate_up_proj fusion
         gate = sd[prefix + "mlp.gate_proj.weight"]
         up = sd[prefix + "mlp.up_proj.weight"]
-        load_sd[prefix + "mlp.gate_up_proj.weight"] = torch.cat([gate, up], dim=0)
+        load_sd[prefix + "mlp.gate_up_proj.weight"] = _shard_and_concat(
+            [gate, up], dim=0, tp=vllm_tp
+        )
 
-        # MLP gate_up_proj fusion - handle bias if present
+        # Untested: MLP gate_up_proj fusion - handle bias if present
         gate_bias_key = prefix + "mlp.gate_proj.bias"
         up_bias_key = prefix + "mlp.up_proj.bias"
         if all(key in sd for key in [gate_bias_key, up_bias_key]):
             gate_bias = sd[gate_bias_key]
             up_bias = sd[up_bias_key]
-            load_sd[prefix + "mlp.gate_up_proj.bias"] = torch.cat(
-                [gate_bias, up_bias], dim=0
+            # The bias needs the same per-rank sharding as the fused weight
+            load_sd[prefix + "mlp.gate_up_proj.bias"] = _shard_and_concat(
+                [gate_bias, up_bias], dim=0, tp=vllm_tp
             )
 
     return load_sd
diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py
index 543956877..e77c33e7b 100644
--- a/tests/integration_tests/test_policy_update.py
+++ b/tests/integration_tests/test_policy_update.py
@@ -213,10 +213,14 @@ async def test_policy_update_single(self, trainer_cfg):
         v0 = uuid.uuid4().int
         v1 = v0 + 1
 
-        await rl_trainer.push_weights.fanout(policy_version=v0)
+        await rl_trainer.push_weights.fanout(
+            policy_version=v0, vllm_tp_DEPRECATED=tp_size
+        )
         # Setting everything to zero
         await rl_trainer.zero_out_model_states.fanout()
-        await rl_trainer.push_weights.fanout(policy_version=v1)
+        await rl_trainer.push_weights.fanout(
+            policy_version=v1, vllm_tp_DEPRECATED=tp_size
+        )
         await policy._test_save_model_params.fanout()
 
         # Sanity check that before update all the tests pass
@@ -281,10 +285,14 @@ async def test_policy_update_tp(self, trainer_cfg_tp):
         v0 = uuid.uuid4().int
         v1 = v0 + 1
 
-        await rl_trainer.push_weights.fanout(policy_version=v0)
+        await rl_trainer.push_weights.fanout(
+            policy_version=v0, vllm_tp_DEPRECATED=tp_size
+        )
         # Setting everything to zero
         await rl_trainer.zero_out_model_states.fanout()
-        await rl_trainer.push_weights.fanout(policy_version=v1)
+        await rl_trainer.push_weights.fanout(
+            policy_version=v1, vllm_tp_DEPRECATED=tp_size
+        )
         await policy._test_save_model_params.fanout()
 
         # Sanity check that before update all the tests pass
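
One more review note on the default (not part of the diff): RLTrainer.push_weights declares vllm_tp_DEPRECATED: int = 1, and with tp=1 _shard_and_concat chunks every source into a single shard, so the fused tensor is identical to the old torch.cat result and single-GPU policies are unaffected. The hypothetical snippet below illustrates that backward-compatible default and the per-rank gate/up layout at tp=2; names and shapes are illustrative, not part of the PR.

    # Hypothetical check (not in the PR): tp=1 reproduces the old torch.cat layout.
    import torch

    def shard_and_concat(sources, dim, tp):
        sharded = [torch.chunk(s, tp, dim=dim) for s in sources]
        return torch.cat(
            [torch.cat([s[r] for s in sharded], dim=dim) for r in range(tp)], dim=dim
        )

    gate = torch.randn(6, 3)  # mlp.gate_proj.weight
    up = torch.randn(6, 3)    # mlp.up_proj.weight

    # tp=1: identical to the pre-PR fused layout
    assert torch.equal(shard_and_concat([gate, up], 0, 1), torch.cat([gate, up], 0))

    # tp=2: each rank's half of gate_up_proj holds its gate shard, then its up shard
    fused = shard_and_concat([gate, up], 0, 2)
    assert torch.equal(fused[:6], torch.cat([gate[:3], up[:3]], dim=0))
    assert torch.equal(fused[6:], torch.cat([gate[3:], up[3:]], dim=0))
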