Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions src/forge/actors/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import math
import os
import re
import time
from collections.abc import Mapping
from dataclasses import dataclass, field, fields
Expand Down Expand Up @@ -215,7 +216,7 @@ async def push_weights(self, policy_version: int) -> None:
)
hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
# TODO: Figure out how to gracefully handle which model to-vLLM conversion is needed
vllm_ready_hf_sd = _qwen3_hf_to_vllm(sd=hf_state_dict, num_layers=28)
vllm_ready_hf_sd = _qwen3_hf_to_vllm(sd=hf_state_dict)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this just be simplified by using num_layers=self.model.config.num_hidden_layers? see this

Ideally you should not be needing this method in the trainer at all. The trainer should be agnostic to the type/arch of generator.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should work with torchtitan:

num_layers=self.engine.model_args.n_layers)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess simply reading the maximum layer index via regex has the advantage that it's agnostic of the trainer implementation, as long as the state dict is in Hugging Face format.
Let me know what you think. @Ritesh1905 @allenwang28

key = f"{self.state_dict_key}{DELIM}{policy_version}"
start_time = time.time()
await ts.put_state_dict(state_dict=vllm_ready_hf_sd, key=key)
Expand All @@ -230,7 +231,7 @@ async def cleanup(self) -> None:
self.engine.checkpointer.close()


def _qwen3_hf_to_vllm(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]:
def _qwen3_hf_to_vllm(sd: dict[str, Tensor]) -> dict[str, Tensor]:
"""Convert transformers state dict to vLLM format. Specifically, this fuses
QKV projection and MLP gate_up_proj layers.

Expand All @@ -243,6 +244,17 @@ def _qwen3_hf_to_vllm(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tenso
"""
load_sd = {}

# Infer num_layers from the state dict by finding the highest layer index
layer_indices = []
pattern = re.compile(r"model\.layers\.(\d+)\.")
for k in sd.keys():
m = pattern.search(k)
if m:
layer_indices.append(int(m.group(1)))
if not layer_indices:
raise ValueError("Could not infer num_layers from state dict keys.")
num_layers = max(layer_indices) + 1

# Copy over directly mapped keys
for k in sd:
if any(
Expand Down
Loading