@@ -6,7 +6,6 @@
 
 import logging
 import os
-import shutil
 
 import time
 from collections.abc import Mapping
@@ -53,45 +52,6 @@
 logger.setLevel(logging.DEBUG)
 
 
-def cleanup_old_weight_versions(
-    state_dict_key: str,
-    delim: str,
-    current_policy_version: int,
-) -> None:
-    """Delete old weight versions, keeping only current and N-1 versions.
-
-    TODO - issues/194: provide a more robust way to handle eviction.
-
-    Args:
-        state_dict_key: The base key for state dict storage.
-        delim: The delimiter used between key and version.
-        current_policy_version: The current policy version to keep.
-    """
-    if current_policy_version <= 1:
-        return  # No cleanup needed for versions 0 or 1
-
-    prefix = f"{state_dict_key}{delim}"
-    current_weights = f"{prefix}{current_policy_version}"
-    previous_weights = f"{prefix}{current_policy_version - 1}"
-
-    # Find all weight directories that match our pattern
-    parent_dir = os.path.dirname(prefix) or "."
-    if os.path.exists(parent_dir):
-        for item in os.listdir(parent_dir):
-            item_path = os.path.join(parent_dir, item)
-            if (
-                item.startswith(os.path.basename(prefix))
-                and item != os.path.basename(current_weights)
-                and item != os.path.basename(previous_weights)
-                and os.path.isdir(item_path)
-            ):
-                try:
-                    shutil.rmtree(item_path, ignore_errors=True)
-                    logger.debug(f"Removed old weights at {item_path}")
-                except OSError as e:
-                    logger.debug(f"Error deleting {item_path}: {e}")
-
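# A worked example of the retention rule above (the key and delimiter are
# hypothetical): with state_dict_key="weights/policy", delim="_v", and
# current_policy_version=5, prefix is "weights/policy_v", so a cleanup pass
# keeps weights/policy_v5 and weights/policy_v4 and deletes any other
# weights/policy_v* directory, e.g. weights/policy_v3.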
-
 @dataclass
 class RLTrainer(ForgeActor):
     """A reinforcement learning trainer actor for policy optimization training.
@@ -135,19 +95,10 @@ class RLTrainer(ForgeActor):
     dcp_path: str = "forge_dcp_tmp"
 
     def __post_init__(self):
-        """Initializes config types and env variables.
-
-        torchrun normally handles env variables, but we need to do it ourselves
-        in monarch for now.
-
-        """
         super().__init__()
-
         if self.use_dcp:
-            # DCP-specific optimization
             torch.serialization.set_crc32_options(False)
 
-        # Instantiate dict fields
         for f in fields(self):
             attr = getattr(self, f.name)
             if isinstance(attr, Mapping):
@@ -184,73 +135,23 @@ def forward_backward(
     ) -> Tensor:
         model_parts = self.engine.model_parts
         parallel_dims = self.engine.parallel_dims
-
-        # Apply context parallelism if CP is enabled;
-        # ensure CP handles the separate freqs_cis buffer for each PP stage.
-        # if getattr(self.engine.model_args, "use_flex_attn", False):
-        #     cp_mesh = (
-        #         parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
-        #     )
-        #     init_attention_mask(
-        #         inputs, self.engine.tokenizer.base_tokenizer.eos_id, cp_mesh
-        #     )
-
-        # optional_context_parallel_ctx = (
-        #     dist_utils.create_context_parallel_ctx(
-        #         cp_mesh=parallel_dims.world_mesh["cp"],
-        #         cp_buffers=[inputs, targets] + [m.freqs_cis for m in model_parts],
-        #         cp_seq_dims=[1, 1] + [0 for _ in model_parts],
-        #         cp_no_restore_buffers={inputs, targets},
-        #         cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
-        #     )
-        #     if parallel_dims.cp_enabled
-        #     else None
-        # )
         optional_context_parallel_ctx = None
-
         if parallel_dims.pp_enabled:
             raise NotImplementedError("PP not implemented yet")
-            # TODO implement PP
-            # # Pipeline-parallel forward / backward inside step() call
-            # with self.train_context(optional_context_parallel_ctx):
-            #     targets, losses = (
-            #         (labels, []) if self.pp_has_last_stage else (None, None)
-            #     )
-            #     if self.pp_has_first_stage:
-            #         self.pp_schedule.step(
-            #             inputs, target=targets, losses=losses, input_batch=inputs
-            #         )
-            #     else:
-            #         self.pp_schedule.step(
-            #             target=targets, losses=losses, input_batch=inputs
-            #         )
-            #
-            #     # accumulate losses across pipeline microbatches
-            #     # TODO: PP+FSDP unexpectedly puts the loss back on the CPU
-            #     loss = (
-            #         torch.mean(torch.stack(losses)).to(self.device)
-            #         if self.pp_has_last_stage
-            #         else torch.tensor([-1.0], device=self.device)
-            #     )
         else:
-            # Non-PP forward / backward
             with self.engine.train_context(optional_context_parallel_ctx):
                 assert len(model_parts) == 1
                 with self.engine.maybe_enable_amp:
                     logits = model_parts[0](**inputs)
                     loss = self.loss(logits, **targets)
-                # need to free logits before backward to avoid peak memory
-                del logits
+                del logits  # free logits before backward to avoid peak memory
                 loss.backward()
-
         return loss
 
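# A hedged sketch of the forward_backward contract (the key names here are
# assumptions, not this repo's actual batch schema): **inputs must match the
# model's forward signature and **targets must match self.loss.
#
#   inputs = {"tokens": token_ids}    # e.g. shape [batch, seq]
#   targets = {"target_ids": labels}  # plus whatever self.loss expects
#   loss = self.forward_backward(inputs, targets)  # scalar; grads populated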
     @endpoint
     async def train_step(
         self, inputs: list[dict[str, Tensor]], targets: list[dict[str, Tensor]]
     ) -> float:
-
-        # Log timesteps
         t = Tracer("rl_trainer_perf/step", timer="gpu", track_memory=True)
         t.start()
 
@@ -259,18 +160,12 @@ async def train_step(
         local_targets = targets[self.engine.dp_rank]
         batch_to_device(local_inputs, self.engine.device)
         batch_to_device(local_targets, self.engine.device)
-        # compute policy logprobs
-        # TODO implement gradient accumulation
-        # with GradientAccumulation(
-        #     self.gradient_accumulation_steps,
-        #     self.model,
-        #     self.data_parallel_size,
-        # ) as grad_acc:
+
         loss = self.forward_backward(local_inputs, local_targets)
         torch.distributed.all_reduce(loss)
+
         t.step("forward_backward")
 
-        # Get learning rate from scheduler
         current_lr = (
             self.engine.lr_schedulers.get_last_lr()[0]
             if hasattr(self.engine.lr_schedulers, "get_last_lr")
@@ -283,13 +178,11 @@ async def train_step(
         self.engine.lr_schedulers.step()
         t.step("optimizer_step")
 
-        # Record training metrics
         # TODO: delete item() to avoid cpu-gpu sync
-        loss = loss.detach().cpu().item()
+        loss = loss.detach().item()
         record_metric("rl_trainer/count_training_steps", 1, Reduce.SUM)
         record_metric("rl_trainer/avg_grpo_loss", loss, Reduce.MEAN)
 
-        # TODO: extract actual KL divergence and policy entropy from the loss computation
         # These are placeholder values until the loss function exposes these metrics
         # record_metric("rl_trainer/step/avg_kl_divergence", 0.0, Reduce.MEAN)
         # record_metric("rl_trainer/step/std_kl_divergence", 0.0, Reduce.STD)
@@ -351,109 +244,3 @@ async def push_weights(self, policy_version: int) -> None:
     async def cleanup(self) -> None:
         if self.engine.checkpointer:
             self.engine.checkpointer.close()
-
-
-def _shard_and_concat(sources: list[torch.Tensor], dim: int, tp: int) -> torch.Tensor:
-    """Shard and concatenate tensors along a given dimension.
-
-    Args:
-        sources (list[torch.Tensor]): List of tensors to shard and concatenate.
-        dim (int): Dimension along which to shard and concatenate.
-        tp (int): Number of tensor parallel groups.
-
-    Returns:
-        torch.Tensor: Concatenated tensor.
-    """
-    sharded_sources = []
-    for source in sources:
-        sharded_sources.append(torch.chunk(source, tp, dim=dim))
-
-    combined_shards = []
-    for shard_idx in range(tp):
-        combined = torch.cat([s[shard_idx] for s in sharded_sources], dim=dim)
-        combined_shards.append(combined)
-    return torch.cat(combined_shards, dim=dim)
-
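# A worked example of the interleaving above (tp=2 is an assumed tensor
# parallel degree): chunks with the same index are concatenated, so each
# rank's contiguous slice of the fused tensor holds its own shard of every
# source.
#
#   q = torch.zeros(4, 8)
#   k = torch.ones(4, 8)
#   v = torch.full((4, 8), 2.0)
#   fused = _shard_and_concat([q, k, v], dim=0, tp=2)
#   # Along dim 0: q[0:2], k[0:2], v[0:2], q[2:4], k[2:4], v[2:4];
#   # rank 0 reads rows 0:6 and rank 1 reads rows 6:12.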
-
-def _qwen3_hf_to_vllm(
-    sd: dict[str, torch.Tensor], num_layers: int, vllm_tp: int
-) -> dict[str, torch.Tensor]:
-    """Convert a transformers state dict to vLLM format. Specifically, this fuses
-    the QKV projection and MLP gate_up_proj layers.
-
-    Args:
-        sd (dict): State dict from the HF model.
-        num_layers (int): Number of layers in the model.
-        vllm_tp (int): vLLM tensor parallel degree.
-
-    Returns:
-        dict: State dict in vLLM format.
-    """
-    load_sd = {}
-
-    def unwrap(t):
-        """Unwrap a DTensor to a plain Tensor."""
-        return t.full_tensor() if isinstance(t, torch.distributed.tensor.DTensor) else t
-
-    for key in sd.keys():
-        sd[key] = unwrap(sd[key]).cpu()
-
-    # Copy over directly mapped keys
-    for k in sd:
-        if any(
-            x in k
-            for x in [
-                "down_proj",
-                "input_layernorm",
-                "post_attention_layernorm",
-                "o_proj",
-                "norm.weight",
-                "embed_tokens.weight",
-                "lm_head.weight",
-            ]
-        ):
-            load_sd[k] = sd[k]
-
-    for i in range(num_layers):
-        prefix = f"model.layers.{i}."
-        # QKV fusion
-        q = sd[prefix + "self_attn.q_proj.weight"]
-        k = sd[prefix + "self_attn.k_proj.weight"]
-        v = sd[prefix + "self_attn.v_proj.weight"]
-
-        load_sd[prefix + "self_attn.qkv_proj.weight"] = _shard_and_concat(
-            [q, k, v], dim=0, tp=vllm_tp
-        )
-
-        # Untested: QKV fusion - handle bias if present
-        q_bias_key = prefix + "self_attn.q_proj.bias"
-        k_bias_key = prefix + "self_attn.k_proj.bias"
-        v_bias_key = prefix + "self_attn.v_proj.bias"
-
-        if all(key in sd for key in [q_bias_key, k_bias_key, v_bias_key]):
-            q_bias = sd[q_bias_key]
-            k_bias = sd[k_bias_key]
-            v_bias = sd[v_bias_key]
-            load_sd[prefix + "self_attn.qkv_proj.bias"] = _shard_and_concat(
-                [q_bias, k_bias, v_bias], dim=0, tp=vllm_tp
-            )
-
-        # MLP gate_up_proj fusion
-        gate = sd[prefix + "mlp.gate_proj.weight"]
-        up = sd[prefix + "mlp.up_proj.weight"]
-        load_sd[prefix + "mlp.gate_up_proj.weight"] = _shard_and_concat(
-            [gate, up], dim=0, tp=vllm_tp
-        )
-
-        # Untested: MLP gate_up_proj fusion - handle bias if present
-        gate_bias_key = prefix + "mlp.gate_proj.bias"
-        up_bias_key = prefix + "mlp.up_proj.bias"
-
-        if all(key in sd for key in [gate_bias_key, up_bias_key]):
-            gate_bias = sd[gate_bias_key]
-            up_bias = sd[up_bias_key]
-            # The same sharding has to happen here
-            load_sd[prefix + "mlp.gate_up_proj.bias"] = _shard_and_concat(
-                [gate_bias, up_bias], dim=0, tp=vllm_tp
-            )
-
-    return load_sd
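# A hedged usage sketch (the layer count and tp degree below are assumptions
# for a Qwen3-style checkpoint, not values from this repo):
#
#   hf_sd = hf_model.state_dict()
#   vllm_sd = _qwen3_hf_to_vllm(hf_sd, num_layers=28, vllm_tp=2)
#   # "model.layers.0.self_attn.qkv_proj.weight" now holds the fused,
#   # TP-interleaved q/k/v weights; gate/up are fused into "mlp.gate_up_proj".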