diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index 0368fac4d..b148a7f94 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -6,6 +6,7 @@
 
 import math
 import os
+import re
 import time
 from collections.abc import Mapping
 from dataclasses import dataclass, field, fields
@@ -215,7 +216,7 @@ async def push_weights(self, policy_version: int) -> None:
         )
         hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
         # TODO: Figure out how to gracefully handle which model to-vLLM conversion is needed
-        vllm_ready_hf_sd = _qwen3_hf_to_vllm(sd=hf_state_dict, num_layers=28)
+        vllm_ready_hf_sd = _qwen3_hf_to_vllm(sd=hf_state_dict)
         key = f"{self.state_dict_key}{DELIM}{policy_version}"
         start_time = time.time()
         await ts.put_state_dict(state_dict=vllm_ready_hf_sd, key=key)
@@ -230,7 +231,7 @@ async def cleanup(self) -> None:
         self.engine.checkpointer.close()
 
 
-def _qwen3_hf_to_vllm(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]:
+def _qwen3_hf_to_vllm(sd: dict[str, Tensor]) -> dict[str, Tensor]:
     """Convert transformers state dict to vLLM format. Specifically, this fuses
     QKV projection and MLP gate_up_proj layers.
 
@@ -243,6 +244,17 @@ def _qwen3_hf_to_vllm(sd: dict[str, Tensor]) -> dict[str, Tensor]:
     """
     load_sd = {}
 
+    # Infer num_layers from the state dict by finding the highest layer index
+    layer_indices = []
+    pattern = re.compile(r"model\.layers\.(\d+)\.")
+    for k in sd.keys():
+        m = pattern.search(k)
+        if m:
+            layer_indices.append(int(m.group(1)))
+    if not layer_indices:
+        raise ValueError("Could not infer num_layers from state dict keys.")
+    num_layers = max(layer_indices) + 1
+
     # Copy over directly mapped keys
     for k in sd:
         if any(
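
A quick standalone sanity check of the layer-inference logic added above (a sketch, not part of the patch; _infer_num_layers and fake_sd are hypothetical names introduced here for illustration). It mirrors the added code: scan keys for "model.layers.<i>." and return the highest index plus one, matching the previously hardcoded value of 28.

import re

import torch


def _infer_num_layers(sd: dict[str, torch.Tensor]) -> int:
    # Same approach as the patch: highest "model.layers.<i>." index + 1.
    pattern = re.compile(r"model\.layers\.(\d+)\.")
    indices = []
    for key in sd:
        match = pattern.search(key)
        if match:
            indices.append(int(match.group(1)))
    if not indices:
        raise ValueError("Could not infer num_layers from state dict keys.")
    return max(indices) + 1


# Hypothetical minimal state dict with layer indices 0..27, i.e. 28 layers,
# the value that push_weights previously passed explicitly.
fake_sd = {
    f"model.layers.{i}.self_attn.q_proj.weight": torch.empty(0) for i in range(28)
}
assert _infer_num_layers(fake_sd) == 28

Inferring the layer count from the checkpoint keys removes a model-specific constant from push_weights, so the same call works for any Qwen3 size whose keys follow the "model.layers.<i>." naming scheme.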