221 changes: 4 additions & 217 deletions src/forge/actors/trainer.py
@@ -6,7 +6,6 @@

import logging
import os
import shutil

import time
from collections.abc import Mapping
@@ -53,45 +52,6 @@
logger.setLevel(logging.DEBUG)


def cleanup_old_weight_versions(
state_dict_key: str,
delim: str,
current_policy_version: int,
) -> None:
"""Delete old weight versions, keeping only current and N-1 versions.

TODO - issues/194: provide a more robust way to handle eviction.

Args:
state_dict_key: The base key for state dict storage
delim: The delimiter used between key and version
current_policy_version: The current policy version to keep
"""
if current_policy_version <= 1:
return # No cleanup needed for versions 0 or 1

prefix = f"{state_dict_key}{delim}"
current_weights = f"{prefix}{current_policy_version}"
previous_weights = f"{prefix}{current_policy_version - 1}"

# Find all weight directories that match our pattern
parent_dir = os.path.dirname(prefix) or "."
if os.path.exists(parent_dir):
for item in os.listdir(parent_dir):
item_path = os.path.join(parent_dir, item)
if (
item.startswith(os.path.basename(prefix))
and item != os.path.basename(current_weights)
and item != os.path.basename(previous_weights)
and os.path.isdir(item_path)
):
try:
shutil.rmtree(item_path, ignore_errors=True)
logger.debug(f"Removed old weights at {item_path}")
except OSError as e:
logger.debug(f"Error deleting {item_path}: {e}")


@dataclass
class RLTrainer(ForgeActor):
"""A reinforcement learning trainer actor for policy optimization training.
Expand Down Expand Up @@ -135,19 +95,10 @@ class RLTrainer(ForgeActor):
dcp_path: str = "forge_dcp_tmp"

def __post_init__(self):
"""Initializes config types and env variables.

torchrun normally handles env variables, but we need to do it ourselves
in monarch for now.

"""
super().__init__()

if self.use_dcp:
# DCP specific optimization
torch.serialization.set_crc32_options(False)

# Instantiate dict fields
for f in fields(self):
attr = getattr(self, f.name)
if isinstance(attr, Mapping):
@@ -184,73 +135,23 @@ def forward_backward(
) -> Tensor:
model_parts = self.engine.model_parts
parallel_dims = self.engine.parallel_dims

# apply context parallelism if cp is enabled
# ensure CP handles the separate freqs_cis buffer for each pp stage
# if getattr(self.engine.model_args, "use_flex_attn", False):
# cp_mesh = (
# parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
# )
# init_attention_mask(
# inputs, self.engine.tokenizer.base_tokenizer.eos_id, cp_mesh
# )

# optional_context_parallel_ctx = (
# dist_utils.create_context_parallel_ctx(
# cp_mesh=parallel_dims.world_mesh["cp"],
# cp_buffers=[inputs, targets] + [m.freqs_cis for m in model_parts],
# cp_seq_dims=[1, 1] + [0 for _ in model_parts],
# cp_no_restore_buffers={inputs, targets},
# cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
# )
# if parallel_dims.cp_enabled
# else None
# )
optional_context_parallel_ctx = None

if parallel_dims.pp_enabled:
raise NotImplementedError("PP not implemented yet")
# TODO implement PP
# # Pipeline Parallel forward / backward inside step() call
# with self.train_context(optional_context_parallel_ctx):
# targets, losses = (
# (labels, []) if self.pp_has_last_stage else (None, None)
# )
# if self.pp_has_first_stage:
# self.pp_schedule.step(
# inputs, target=targets, losses=losses, input_batch=inputs
# )
# else:
# self.pp_schedule.step(
# target=targets, losses=losses, input_batch=inputs
# )
#
# # accumulate losses across pipeline microbatches
# # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
# loss = (
# torch.mean(torch.stack(losses)).to(self.device)
# if self.pp_has_last_stage
# else torch.tensor([-1.0], device=self.device)
# )
else:
# Non-PP forward / backward
with self.engine.train_context(optional_context_parallel_ctx):
assert len(model_parts) == 1
with self.engine.maybe_enable_amp:
logits = model_parts[0](**inputs)
loss = self.loss(logits, **targets)
# need to free to before bwd to avoid peaking memory
del logits
del logits  # Free before bwd to avoid peak memory
loss.backward()

return loss

@endpoint
async def train_step(
self, inputs: list[dict[str, Tensor]], targets: list[dict[str, Tensor]]
) -> float:

# Log timesteps
t = Tracer("rl_trainer_perf/step", timer="gpu", track_memory=True)
t.start()

@@ -259,18 +160,12 @@ async def train_step(
local_targets = targets[self.engine.dp_rank]
batch_to_device(local_inputs, self.engine.device)
batch_to_device(local_targets, self.engine.device)
# compute policy logprobs
# TODO implement gradient accumulation
# with GradientAccumulation(
# self.gradient_accumulation_steps,
# self.model,
# self.data_parallel_size,
# ) as grad_acc:

loss = self.forward_backward(local_inputs, local_targets)
torch.distributed.all_reduce(loss)

t.step("forward_backward")

# Get learning rate from scheduler
current_lr = (
self.engine.lr_schedulers.get_last_lr()[0]
if hasattr(self.engine.lr_schedulers, "get_last_lr")
@@ -283,13 +178,11 @@
self.engine.lr_schedulers.step()
t.step("optimizer_step")

# Record training metrics
# TODO: delete item() to avoid cpu-gpu sync
loss = loss.detach().cpu().item()
loss = loss.detach().item()
record_metric("rl_trainer/count_training_steps", 1, Reduce.SUM)
record_metric("rl_trainer/avg_grpo_loss", loss, Reduce.MEAN)

# TODO: Extract actual KL divergence and policy entropy from the loss computation
# These are placeholder values until the loss function exposes these metrics
# record_metric("rl_trainer/step/avg_kl_divergence", 0.0, Reduce.MEAN)
# record_metric("rl_trainer/step/std_kl_divergence", 0.0, Reduce.STD)
@@ -351,109 +244,3 @@ async def push_weights(self, policy_version: int) -> None:
async def cleanup(self) -> None:
if self.engine.checkpointer:
self.engine.checkpointer.close()


def _shard_and_concat(sources: list[torch.Tensor], dim: int, tp: int) -> torch.Tensor:
"""Shard and concatenate tensors along a given dimension.

Args:
sources (list[torch.Tensor]): List of tensors to shard and concatenate.
dim (int): Dimension along which to shard and concatenate.
tp (int): Number of tensor parallel groups.

Returns:
torch.Tensor: Concatenated tensor with each source's shards interleaved per tensor parallel group.
"""
sharded_sources = []
for source in sources:
sharded_sources.append(torch.chunk(source, tp, dim=dim))

combined_shards = []
for shard_idx in range(tp):
combined = torch.cat([s[shard_idx] for s in sharded_sources], dim=dim)
combined_shards.append(combined)
return torch.cat(combined_shards, dim=dim)
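# A minimal sketch of the interleaving this helper produces, assuming tp=2 and
# two small toy tensors; the re-packed row order [a0, b0, a1, b1] is the
# per-rank layout the fused projections below rely on.
_a = torch.arange(8).reshape(4, 2)      # stands in for e.g. a q_proj weight
_b = torch.arange(8, 16).reshape(4, 2)  # stands in for e.g. a k_proj weight
_packed = _shard_and_concat([_a, _b], dim=0, tp=2)
# Rows are ordered a[0:2], b[0:2], a[2:4], b[2:4]
assert torch.equal(_packed[:2], _a[:2])
assert torch.equal(_packed[2:4], _b[:2])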


def _qwen3_hf_to_vllm(
sd: dict[str, torch.Tensor], num_layers: int, vllm_tp: int
) -> dict[str, torch.Tensor]:
"""Convert transformers state dict to vLLM format. Specifically, this fuses
QKV projection and MLP gate_up_proj layers.

Args:
sd (dict): State dict from HF model.
num_layers (int): Number of layers in the model.
vllm_tp (int): vLLM tensor parallel size used to shard the fused weights.

Returns:
dict: State dict in vLLM format.
"""
load_sd = {}

def unwrap(t):
"""Unwrap a DTensor to a Tensor."""
return t.full_tensor() if isinstance(t, torch.distributed.tensor.DTensor) else t

for key in sd.keys():
sd[key] = unwrap(sd[key]).cpu()

# Copy over directly mapped keys
for k in sd:
if any(
x in k
for x in [
"down_proj",
"input_layernorm",
"post_attention_layernorm",
"o_proj",
"norm.weight",
"embed_tokens.weight",
"lm_head.weight",
]
):
load_sd[k] = sd[k]

for i in range(num_layers):
prefix = f"model.layers.{i}."
# QKV fusion
q = sd[prefix + "self_attn.q_proj.weight"]
k = sd[prefix + "self_attn.k_proj.weight"]
v = sd[prefix + "self_attn.v_proj.weight"]

load_sd[prefix + "self_attn.qkv_proj.weight"] = _shard_and_concat(
[q, k, v], dim=0, tp=vllm_tp
)

# Untested: QKV fusion - handle bias if present
q_bias_key = prefix + "self_attn.q_proj.bias"
k_bias_key = prefix + "self_attn.k_proj.bias"
v_bias_key = prefix + "self_attn.v_proj.bias"

if all(key in sd for key in [q_bias_key, k_bias_key, v_bias_key]):
q_bias = sd[q_bias_key]
k_bias = sd[k_bias_key]
v_bias = sd[v_bias_key]
load_sd[prefix + "self_attn.qkv_proj.bias"] = _shard_and_concat(
[q_bias, k_bias, v_bias], dim=0, tp=vllm_tp
)

# MLP gate_up_proj fusion
gate = sd[prefix + "mlp.gate_proj.weight"]
up = sd[prefix + "mlp.up_proj.weight"]
load_sd[prefix + "mlp.gate_up_proj.weight"] = _shard_and_concat(
[gate, up], dim=0, tp=vllm_tp
)

# Untested: MLP gate_up_proj fusion - handle bias if present
gate_bias_key = prefix + "mlp.gate_proj.bias"
up_bias_key = prefix + "mlp.up_proj.bias"

if all(key in sd for key in [gate_bias_key, up_bias_key]):
gate_bias = sd[gate_bias_key]
up_bias = sd[up_bias_key]
# Apply the same TP sharding to the fused bias
load_sd[prefix + "mlp.gate_up_proj.bias"] = _shard_and_concat(
[gate_bias, up_bias], dim=0, tp=vllm_tp
)

return load_sd
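# A self-contained sketch of the conversion on a fake single-layer checkpoint;
# the hidden size, vocab size, and tensor shapes are arbitrary illustrative
# assumptions, not real Qwen3 dimensions.
_hidden = 8
_fake_hf_sd = {
    "model.embed_tokens.weight": torch.randn(16, _hidden),
    "model.layers.0.self_attn.q_proj.weight": torch.randn(_hidden, _hidden),
    "model.layers.0.self_attn.k_proj.weight": torch.randn(_hidden, _hidden),
    "model.layers.0.self_attn.v_proj.weight": torch.randn(_hidden, _hidden),
    "model.layers.0.self_attn.o_proj.weight": torch.randn(_hidden, _hidden),
    "model.layers.0.mlp.gate_proj.weight": torch.randn(16, _hidden),
    "model.layers.0.mlp.up_proj.weight": torch.randn(16, _hidden),
    "model.layers.0.mlp.down_proj.weight": torch.randn(_hidden, 16),
    "model.layers.0.input_layernorm.weight": torch.randn(_hidden),
    "model.layers.0.post_attention_layernorm.weight": torch.randn(_hidden),
    "model.norm.weight": torch.randn(_hidden),
    "lm_head.weight": torch.randn(16, _hidden),
}
_vllm_sd = _qwen3_hf_to_vllm(_fake_hf_sd, num_layers=1, vllm_tp=2)
# The fused keys replace the separate q/k/v and gate/up projections.
assert "model.layers.0.self_attn.qkv_proj.weight" in _vllm_sd
assert "model.layers.0.mlp.gate_up_proj.weight" in _vllm_sd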
102 changes: 0 additions & 102 deletions tests/unit_tests/test_trainer.py

This file was deleted.
