
Commit 710703e

Author: Felipe Mello (committed)
Merge branch 'main' of https://github.com/meta-pytorch/forge into sft_metrics
2 parents: 6ec9733 + 8bd8d5d

File tree: 2 files changed, +4 / -319 lines

src/forge/actors/trainer.py

Lines changed: 4 additions & 217 deletions
@@ -6,7 +6,6 @@
 
 import logging
 import os
-import shutil
 
 import time
 from collections.abc import Mapping
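
The next hunk removes the cleanup_old_weight_versions helper. For reference, a minimal sketch (not part of this commit) of how that helper was invoked; the storage path and delimiter below are hypothetical:

    # Hypothetical call: with delim="_v" and current_policy_version=5, the
    # directories .../policy_v5 and .../policy_v4 are kept, and any other
    # .../policy_v* directory is deleted via shutil.rmtree.
    cleanup_old_weight_versions(
        state_dict_key="/tmp/forge_weights/policy",
        delim="_v",
        current_policy_version=5,
    )
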
@@ -53,45 +52,6 @@
 logger.setLevel(logging.DEBUG)
 
 
-def cleanup_old_weight_versions(
-    state_dict_key: str,
-    delim: str,
-    current_policy_version: int,
-) -> None:
-    """Delete old weight versions, keeping only current and N-1 versions.
-
-    TODO - issues/194: provide a more robust way to handle eviction.
-
-    Args:
-        state_dict_key: The base key for state dict storage
-        delim: The delimiter used between key and version
-        current_policy_version: The current policy version to keep
-    """
-    if current_policy_version <= 1:
-        return  # No cleanup needed for versions 0 or 1
-
-    prefix = f"{state_dict_key}{delim}"
-    current_weights = f"{prefix}{current_policy_version}"
-    previous_weights = f"{prefix}{current_policy_version - 1}"
-
-    # Find all weight directories that match our pattern
-    parent_dir = os.path.dirname(prefix) or "."
-    if os.path.exists(parent_dir):
-        for item in os.listdir(parent_dir):
-            item_path = os.path.join(parent_dir, item)
-            if (
-                item.startswith(os.path.basename(prefix))
-                and item != os.path.basename(current_weights)
-                and item != os.path.basename(previous_weights)
-                and os.path.isdir(item_path)
-            ):
-                try:
-                    shutil.rmtree(item_path, ignore_errors=True)
-                    logger.debug(f"Removed old weights at {item_path}")
-                except OSError as e:
-                    logger.debug(f"Error deleting {item_path}: {e}")
-
-
 @dataclass
 class RLTrainer(ForgeActor):
     """A reinforcement learning trainer actor for policy optimization training.
@@ -135,19 +95,10 @@ class RLTrainer(ForgeActor):
     dcp_path: str = "forge_dcp_tmp"
 
     def __post_init__(self):
-        """Initializes config types and env variables.
-
-        torchrun normally hands env variables, but we need to do it ourselves
-        in monarch for now.
-
-        """
         super().__init__()
-
         if self.use_dcp:
-            # DCP specific optimization
             torch.serialization.set_crc32_options(False)
 
-        # Instantiate dict fields
         for f in fields(self):
             attr = getattr(self, f.name)
             if isinstance(attr, Mapping):
@@ -184,73 +135,23 @@ def forward_backward(
     ) -> Tensor:
         model_parts = self.engine.model_parts
         parallel_dims = self.engine.parallel_dims
-
-        # apply context parallelism if cp is enabled
-        # ensure CP handles the separate freqs_cis buffer for each pp stage
-        # if getattr(self.engine.model_args, "use_flex_attn", False):
-        #     cp_mesh = (
-        #         parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
-        #     )
-        #     init_attention_mask(
-        #         inputs, self.engine.tokenizer.base_tokenizer.eos_id, cp_mesh
-        #     )
-
-        # optional_context_parallel_ctx = (
-        #     dist_utils.create_context_parallel_ctx(
-        #         cp_mesh=parallel_dims.world_mesh["cp"],
-        #         cp_buffers=[inputs, targets] + [m.freqs_cis for m in model_parts],
-        #         cp_seq_dims=[1, 1] + [0 for _ in model_parts],
-        #         cp_no_restore_buffers={inputs, targets},
-        #         cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
-        #     )
-        #     if parallel_dims.cp_enabled
-        #     else None
-        # )
         optional_context_parallel_ctx = None
-
         if parallel_dims.pp_enabled:
             raise NotImplementedError("PP not implemented yet")
-            # TODO implement PP
-            # # Pipeline Parallel forward / backward inside step() call
-            # with self.train_context(optional_context_parallel_ctx):
-            #     targets, losses = (
-            #         (labels, []) if self.pp_has_last_stage else (None, None)
-            #     )
-            #     if self.pp_has_first_stage:
-            #         self.pp_schedule.step(
-            #             inputs, target=targets, losses=losses, input_batch=inputs
-            #         )
-            #     else:
-            #         self.pp_schedule.step(
-            #             target=targets, losses=losses, input_batch=inputs
-            #         )
-            #
-            #     # accumulate losses across pipeline microbatches
-            #     # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
-            #     loss = (
-            #         torch.mean(torch.stack(losses)).to(self.device)
-            #         if self.pp_has_last_stage
-            #         else torch.tensor([-1.0], device=self.device)
-            #     )
         else:
-            # Non-PP forward / backward
             with self.engine.train_context(optional_context_parallel_ctx):
                 assert len(model_parts) == 1
                 with self.engine.maybe_enable_amp:
                     logits = model_parts[0](**inputs)
                     loss = self.loss(logits, **targets)
-                # need to free to before bwd to avoid peaking memory
-                del logits
+                del logits  # Free to before bwd to avoid peaking memory
                 loss.backward()
-
         return loss
 
     @endpoint
     async def train_step(
         self, inputs: list[dict[str, Tensor]], targets: list[dict[str, Tensor]]
     ) -> float:
-
-        # Log timesteps
         t = Tracer("rl_trainer_perf/step", timer="gpu", track_memory=True)
         t.start()
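
The only functional change in the hunk above keeps del logits before loss.backward() and folds the explanatory comment onto the same line. A minimal, self-contained sketch of the pattern with a generic model and loss (not Forge code):

    import torch

    model = torch.nn.Linear(1024, 50_000)    # stand-in for model_parts[0]
    inputs = torch.randn(8, 1024)
    targets = torch.randint(0, 50_000, (8,))

    logits = model(inputs)
    loss = torch.nn.functional.cross_entropy(logits, targets)
    # Drop the last Python reference to the large logits tensor before backward;
    # when the loss graph does not itself retain it, that memory can be reused
    # during the backward pass instead of adding to the peak.
    del logits
    loss.backward()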

@@ -259,18 +160,12 @@ async def train_step(
         local_targets = targets[self.engine.dp_rank]
         batch_to_device(local_inputs, self.engine.device)
         batch_to_device(local_targets, self.engine.device)
-        # compute policy logprobs
-        # TODO implement gradient accumulation
-        # with GradientAccumulation(
-        #     self.gradient_accumulation_steps,
-        #     self.model,
-        #     self.data_parallel_size,
-        # ) as grad_acc:
+
         loss = self.forward_backward(local_inputs, local_targets)
         torch.distributed.all_reduce(loss)
+
         t.step("forward_backward")
 
-        # Get learning rate from scheduler
         current_lr = (
             self.engine.lr_schedulers.get_last_lr()[0]
             if hasattr(self.engine.lr_schedulers, "get_last_lr")
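
Note on torch.distributed.all_reduce(loss): the default reduction op is SUM, so after this call every data-parallel rank holds the summed loss. A hedged sketch (the helper name is ours, not the repo's) of converting that sum to a mean for logging:

    import torch
    import torch.distributed as dist

    def reduce_loss_across_dp(loss: torch.Tensor) -> torch.Tensor:
        """Sum the scalar loss over all ranks, then divide to report a mean."""
        dist.all_reduce(loss)                # in-place SUM across the process group
        return loss / dist.get_world_size()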
@@ -283,13 +178,11 @@ async def train_step(
         self.engine.lr_schedulers.step()
         t.step("optimizer_step")
 
-        # Record training metrics
         # TODO: delete item() to avoid cpu-gpu sync
-        loss = loss.detach().cpu().item()
+        loss = loss.detach().item()
         record_metric("rl_trainer/count_training_steps", 1, Reduce.SUM)
         record_metric("rl_trainer/avg_grpo_loss", loss, Reduce.MEAN)
 
-        # TODO: Extract actual KL divergence and policy entropy from the loss computation
         # These are placeholder values until the loss function exposes these metrics
         # record_metric("rl_trainer/step/avg_kl_divergence", 0.0, Reduce.MEAN)
         # record_metric("rl_trainer/step/std_kl_divergence", 0.0, Reduce.STD)
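
On the loss.detach().cpu().item() -> loss.detach().item() change: .item() on a CUDA tensor already copies the value to the host (the cpu-gpu sync the TODO refers to), so the explicit .cpu() was redundant. A tiny illustration:

    import torch

    loss = torch.tensor([1.5])
    if torch.cuda.is_available():
        loss = loss.cuda()
    value = loss.detach().item()   # one host sync; no separate .cpu() call needed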
@@ -351,109 +244,3 @@ async def push_weights(self, policy_version: int) -> None:
     async def cleanup(self) -> None:
         if self.engine.checkpointer:
             self.engine.checkpointer.close()
-
-
-def _shard_and_concat(sources: list[torch.Tensor], dim: int, tp: int) -> torch.Tensor:
-    """Shard and concatenate tensors along a given dimension.
-
-    Args:
-        source (list[torch.Tensor]): List of tensors to shard and concatenate.
-        dim (int): Dimension along which to shard and concatenate.
-        tp (int): Number of tensor parallel groups.
-
-    Returns:
-        torch.Tensor: Concatenated tensor.
-    """
-    sharded_sources = []
-    for source in sources:
-        sharded_sources.append(torch.chunk(source, tp, dim=dim))
-
-    combined_shards = []
-    for shard_idx in range(tp):
-        combined = torch.cat([s[shard_idx] for s in sharded_sources], dim=dim)
-        combined_shards.append(combined)
-    return torch.cat(combined_shards, dim=dim)
-
-
-def _qwen3_hf_to_vllm(
-    sd: dict[str, torch.Tensor], num_layers: int, vllm_tp: int
-) -> dict[str, torch.Tensor]:
-    """Convert transformers state dict to vLLM format. Specifically, this fuses
-    QKV projection and MLP gate_up_proj layers.
-
-    Args:
-        sd (dict): State dict from HF model.
-        num_layers (int): Number of layers in the model.
-
-    Returns:
-        dict: State dict in vLLM format.
-    """
-    load_sd = {}
-
-    def unwrap(t):
-        """Unwrap a DTensor to a Tensor."""
-        return t.full_tensor() if isinstance(t, torch.distributed.tensor.DTensor) else t
-
-    for key in sd.keys():
-        sd[key] = unwrap(sd[key]).cpu()
-
-    # Copy over directly mapped keys
-    for k in sd:
-        if any(
-            x in k
-            for x in [
-                "down_proj",
-                "input_layernorm",
-                "post_attention_layernorm",
-                "o_proj",
-                "norm.weight",
-                "embed_tokens.weight",
-                "lm_head.weight",
-            ]
-        ):
-            load_sd[k] = sd[k]
-
-    for i in range(num_layers):
-        prefix = f"model.layers.{i}."
-        # QKV fusion
-        q = sd[prefix + "self_attn.q_proj.weight"]
-        k = sd[prefix + "self_attn.k_proj.weight"]
-        v = sd[prefix + "self_attn.v_proj.weight"]
-
-        load_sd[prefix + "self_attn.qkv_proj.weight"] = _shard_and_concat(
-            [q, k, v], dim=0, tp=vllm_tp
-        )
-
-        # Untested: QKV fusion - handle bias if present
-        q_bias_key = prefix + "self_attn.q_proj.bias"
-        k_bias_key = prefix + "self_attn.k_proj.bias"
-        v_bias_key = prefix + "self_attn.v_proj.bias"
-
-        if all(key in sd for key in [q_bias_key, k_bias_key, v_bias_key]):
-            q_bias = sd[q_bias_key]
-            k_bias = sd[k_bias_key]
-            v_bias = sd[v_bias_key]
-            load_sd[prefix + "self_attn.qkv_proj.bias"] = _shard_and_concat(
-                [q_bias, k_bias, v_bias], dim=0, tp=vllm_tp
-            )
-
-        # MLP gate_up_proj fusion
-        gate = sd[prefix + "mlp.gate_proj.weight"]
-        up = sd[prefix + "mlp.up_proj.weight"]
-        load_sd[prefix + "mlp.gate_up_proj.weight"] = _shard_and_concat(
-            [gate, up], dim=0, tp=vllm_tp
-        )
-
-        # Untested: MLP gate_up_proj fusion - handle bias if present
-        gate_bias_key = prefix + "mlp.gate_proj.bias"
-        up_bias_key = prefix + "mlp.up_proj.bias"
-
-        if all(key in sd for key in [gate_bias_key, up_bias_key]):
-            gate_bias = sd[gate_bias_key]
-            up_bias = sd[up_bias_key]
-            # Same sharding has to happen here
-            load_sd[prefix + "mlp.gate_up_proj.bias"] = _shard_and_concat(
-                [gate_bias, up_bias], dim=0, tp=vllm_tp
-            )
-
-    return load_sd
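
To make the removed fusion logic concrete, a small self-contained demo (illustrative only) of what _shard_and_concat produced for QKV fusion with vllm_tp=2; the helper is re-declared locally with the same behavior, using toy tensors:

    import torch

    def _shard_and_concat(sources, dim, tp):
        # chunk every source into tp shards, then group shard i of each source together
        shards = [torch.chunk(s, tp, dim=dim) for s in sources]
        return torch.cat(
            [torch.cat([s[i] for s in shards], dim=dim) for i in range(tp)], dim=dim
        )

    q = torch.arange(4).reshape(4, 1)        # toy q_proj rows 0..3
    k = torch.arange(10, 14).reshape(4, 1)   # toy k_proj rows 10..13
    v = torch.arange(20, 24).reshape(4, 1)   # toy v_proj rows 20..23

    fused = _shard_and_concat([q, k, v], dim=0, tp=2)
    print(fused.flatten().tolist())
    # [0, 1, 10, 11, 20, 21, 2, 3, 12, 13, 22, 23]
    # rank 0's slice holds its own q/k/v shards contiguously; rank 1 gets the rest.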

tests/unit_tests/test_trainer.py

Lines changed: 0 additions & 102 deletions
This file was deleted.
