6 changes: 5 additions & 1 deletion apps/grpo/main.py
@@ -243,6 +243,8 @@ async def main(cfg: DictConfig):
group_size = cfg.group_size
max_req_tokens = cfg.max_req_tokens
max_res_tokens = cfg.max_res_tokens
# TODO: delete this logic once we are confident in the vllm weight sync long-term fix (PR #184)
policy_tp_size = cfg.policy.engine_config.tensor_parallel_size
mlogger = get_metric_logger(
"wandb",
freq=1,
@@ -356,7 +358,9 @@ async def continuous_training():
loss = await trainer.train_step.route(inputs, targets)
training_step += 1
mlogger.log("loss/training_step", loss, training_step)
await trainer.push_weights.fanout(training_step)
await trainer.push_weights.fanout(
training_step, vllm_tp_DEPRECATED=policy_tp_size
)
await policy.update_weights.fanout(training_step)

print("Starting GRPO training loops...")
12 changes: 9 additions & 3 deletions apps/toy_rl/sumdigits.py
@@ -338,11 +338,13 @@ def train_step(self, episodes: list[Episode]) -> float:
return loss.item()

@endpoint
async def push_weights(self, version: int):
async def push_weights(self, version: int, vllm_tp_DEPRECATED: int) -> None:
"""Update policy model weights with trainer's current weights."""
key = f"{self.state_dict_key}{DELIM}{version}" # Use version as unique id
new_sd = _qwen3_hf_to_vllm(
self.model.state_dict(), num_layers=self.model.config.num_hidden_layers
self.model.state_dict(),
num_layers=self.model.config.num_hidden_layers,
vllm_tp=vllm_tp_DEPRECATED,
)
start_time = time.time()
await ts.put_state_dict(new_sd, key)
@@ -433,6 +435,8 @@ async def main(cfg: DictConfig):
group_size = cfg.group_size
max_req_tokens = cfg.max_req_tokens
max_res_tokens = cfg.max_res_tokens
# TODO: delete this logic once we are confident in the vllm weight sync long-term fix (PR #184)
policy_tp_size = cfg.policy.engine_config.tensor_parallel_size
mlogger = get_metric_logger(
"wandb",
freq=1,
@@ -520,7 +524,9 @@ async def continuous_training():
training_step += 1
mlogger.log("loss/training_step", loss, training_step)
print(f"loss/training_step: {loss} at training step {training_step}")
await trainer.push_weights.fanout(training_step)
await trainer.push_weights.fanout(
training_step, vllm_tp_DEPRECATED=policy_tp_size
)
await policy.update_weights.fanout(training_step)
# NOTE: hard-coded to be on-policy for faster convergence
await replay_buffer.clear.fanout()
58 changes: 46 additions & 12 deletions src/forge/actors/trainer.py
@@ -252,7 +252,9 @@ def train_step(
return loss.item()

@endpoint
async def push_weights(self, policy_version: int) -> None:
async def push_weights(
self, policy_version: int, vllm_tp_DEPRECATED: int = 1
) -> None:
# Save to torchstore. Hacking into the Checkpointer's prepped state-dict for now.
start_time = time.perf_counter()
# TODO:
@@ -271,7 +273,9 @@ async def push_weights(self, policy_version: int) -> None:
hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
# TODO: Figure out how to gracefully handle which model to-vLLM conversion is needed
vllm_ready_hf_sd = _qwen3_hf_to_vllm(
sd=hf_state_dict, num_layers=self.engine.model_args.n_layers
sd=hf_state_dict,
num_layers=self.engine.model_args.n_layers,
vllm_tp=vllm_tp_DEPRECATED,
)
conversion_time = time.perf_counter()
key = f"{self.state_dict_key}{DELIM}{policy_version}"
@@ -298,7 +302,7 @@ async def push_weights(self, policy_version: int) -> None:
end_time = time.perf_counter()
logger.info(
f"Completed weights push to {key} in {end_time - start_time:.2f} seconds "
f"(to_vllm: {conversion_time - start_time:.2f}s, tranport time: {end_time - conversion_time:.2f})"
f"(hg to vllm conversion: {conversion_time - start_time:.2f}s, tranport time: {end_time - conversion_time:.2f})"
)

@endpoint
@@ -307,7 +311,31 @@ async def cleanup(self) -> None:
self.engine.checkpointer.close()


def _qwen3_hf_to_vllm(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]:
def _shard_and_concat(sources: list[torch.Tensor], dim: int, tp: int) -> torch.Tensor:
"""Shard and concatenate tensors along a given dimension.

Args:
sources (list[torch.Tensor]): Tensors to shard and concatenate.
dim (int): Dimension along which to shard and concatenate.
tp (int): Number of tensor parallel groups.

Returns:
torch.Tensor: Concatenated tensor with shards interleaved per tensor-parallel rank, so chunking it into tp pieces along dim yields one fused shard per rank.
"""
sharded_sources = []
for source in sources:
sharded_sources.append(torch.chunk(source, tp, dim=dim))

combined_shards = []
for shard_idx in range(tp):
combined = torch.cat([s[shard_idx] for s in sharded_sources], dim=dim)
combined_shards.append(combined)
return torch.cat(combined_shards, dim=dim)


def _qwen3_hf_to_vllm(
sd: dict[str, torch.Tensor], num_layers: int, vllm_tp: int
) -> dict[str, torch.Tensor]:
"""Convert transformers state dict to vLLM format. Specifically, this fuses
QKV projection and MLP gate_up_proj layers.

@@ -349,9 +377,12 @@ def unwrap(t):
q = sd[prefix + "self_attn.q_proj.weight"]
k = sd[prefix + "self_attn.k_proj.weight"]
v = sd[prefix + "self_attn.v_proj.weight"]
load_sd[prefix + "self_attn.qkv_proj.weight"] = torch.cat([q, k, v], dim=0)

# QKV fusion - handle bias if present
load_sd[prefix + "self_attn.qkv_proj.weight"] = _shard_and_concat(
[q, k, v], dim=0, tp=vllm_tp
)

# Untested: QKV fusion - handle bias if present
q_bias_key = prefix + "self_attn.q_proj.bias"
k_bias_key = prefix + "self_attn.k_proj.bias"
v_bias_key = prefix + "self_attn.v_proj.bias"
@@ -360,24 +391,27 @@ def unwrap(t):
q_bias = sd[q_bias_key]
k_bias = sd[k_bias_key]
v_bias = sd[v_bias_key]
load_sd[prefix + "self_attn.qkv_proj.bias"] = torch.cat(
[q_bias, k_bias, v_bias], dim=0
load_sd[prefix + "self_attn.qkv_proj.bias"] = _shard_and_concat(
[q_bias, k_bias, v_bias], dim=0, tp=vllm_tp
)

# MLP gate_up_proj fusion
gate = sd[prefix + "mlp.gate_proj.weight"]
up = sd[prefix + "mlp.up_proj.weight"]
load_sd[prefix + "mlp.gate_up_proj.weight"] = torch.cat([gate, up], dim=0)
load_sd[prefix + "mlp.gate_up_proj.weight"] = _shard_and_concat(
[gate, up], dim=0, tp=vllm_tp
)

# MLP gate_up_proj fusion - handle bias if present
# Untested: MLP gate_up_proj fusion - handle bias if present
gate_bias_key = prefix + "mlp.gate_proj.bias"
up_bias_key = prefix + "mlp.up_proj.bias"

if all(key in sd for key in [gate_bias_key, up_bias_key]):
gate_bias = sd[gate_bias_key]
up_bias = sd[up_bias_key]
load_sd[prefix + "mlp.gate_up_proj.bias"] = torch.cat(
[gate_bias, up_bias], dim=0
# The fused bias needs the same per-rank sharding as the weight
load_sd[prefix + "mlp.gate_up_proj.bias"] = _shard_and_concat(
[gate_bias, up_bias], dim=0, tp=vllm_tp
)

return load_sd
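
A minimal sketch (not part of this PR) of why _shard_and_concat interleaves the shards: a tensor-parallel vLLM instance loads the fused qkv_proj and gate_up_proj weights by chunking them along dim 0 into tp contiguous pieces, so each rank must receive its own [q_i; k_i; v_i] (or [gate_i; up_i]) block. The toy shapes below are illustrative only.

import torch

def _shard_and_concat(sources, dim, tp):
    # Chunk each tensor into per-rank shards, then regroup them rank by rank.
    sharded = [torch.chunk(s, tp, dim=dim) for s in sources]
    per_rank = [torch.cat([s[i] for s in sharded], dim=dim) for i in range(tp)]
    return torch.cat(per_rank, dim=dim)

q = torch.arange(0, 4).reshape(4, 1).float()  # toy q_proj rows 0..3
k = torch.arange(4, 6).reshape(2, 1).float()  # toy k_proj rows 4..5
v = torch.arange(6, 8).reshape(2, 1).float()  # toy v_proj rows 6..7

fused = _shard_and_concat([q, k, v], dim=0, tp=2)
rank0, rank1 = torch.chunk(fused, 2, dim=0)  # what each TP=2 rank would load
print(rank0.squeeze().tolist())  # [0.0, 1.0, 4.0, 6.0] -> q/k/v shard 0
print(rank1.squeeze().tolist())  # [2.0, 3.0, 5.0, 7.0] -> q/k/v shard 1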
16 changes: 12 additions & 4 deletions tests/integration_tests/test_policy_update.py
@@ -213,10 +213,14 @@ async def test_policy_update_single(self, trainer_cfg):
v0 = uuid.uuid4().int
v1 = v0 + 1

await rl_trainer.push_weights.fanout(policy_version=v0)
await rl_trainer.push_weights.fanout(
policy_version=v0, vllm_tp_DEPRECATED=tp_size
)
# Setting everything to zero
await rl_trainer.zero_out_model_states.fanout()
await rl_trainer.push_weights.fanout(policy_version=v1)
await rl_trainer.push_weights.fanout(
policy_version=v1, vllm_tp_DEPRECATED=tp_size
)
await policy._test_save_model_params.fanout()

# Sanity check that before update all the tests pass
@@ -281,10 +285,14 @@ async def test_policy_update_tp(self, trainer_cfg_tp):
v0 = uuid.uuid4().int
v1 = v0 + 1

await rl_trainer.push_weights.fanout(policy_version=v0)
await rl_trainer.push_weights.fanout(
policy_version=v0, vllm_tp_DEPRECATED=tp_size
)
# Setting everything to zero
await rl_trainer.zero_out_model_states.fanout()
await rl_trainer.push_weights.fanout(policy_version=v1)
await rl_trainer.push_weights.fanout(
policy_version=v1, vllm_tp_DEPRECATED=tp_size
)
await policy._test_save_model_params.fanout()

# Sanity check that before update all the tests pass