1616from typing import Any , Callable , Dict , Iterator , Optional , Tuple , Union
1717
1818import torch
19-
2019from megatron .core .models .gpt import GPTModel
2120from megatron .core .packed_seq_params import PackedSeqParams
2221from megatron .core .parallel_state import (
3736)
3837from nemo_rl .models .megatron .data import ProcessedMicrobatch
3938
40-
4139# Union type for any post-processing function (defined after classes below)
4240PostProcessingFunction = Union [
4341 "LossPostProcessor" ,
@@ -56,9 +54,8 @@ def model_forward(
5654 packed_seq_params : Optional [PackedSeqParams ] = None ,
5755 defer_fp32_logits : Optional [bool ] = None ,
5856) -> torch .Tensor :
59- """
60- Perform a single forward pass through the model.
61-
57+ """Perform a single forward pass through the model.
58+
6259 Args:
6360 model: The model to run forward pass on
6461 data_dict (BatchedDataDict): Dictionary containing batch data
@@ -68,7 +65,7 @@ def model_forward(
6865 attention_mask: Attention mask for the sequence
6966 packed_seq_params: Parameters for packed sequences (optional)
7067 defer_fp32_logits (Optional[bool]): Whether to skip the conversion of logits to fp32
71-
68+
7269 Returns:
7370 torch.Tensor: Output tensor from the model (logits)
7471 """
@@ -84,7 +81,7 @@ def model_forward(
8481 additional_kwargs ["packed_seq_params" ] = packed_seq_params
8582 if defer_fp32_logits :
8683 additional_kwargs ["fp32_output" ] = False
87- #with straggler_timer:
84+ # with straggler_timer:
8885 output_tensor = model (
8986 input_ids = input_ids_cp_sharded ,
9087 position_ids = position_ids ,
@@ -95,14 +92,12 @@ def model_forward(
9592
9693 # Apply temperature scaling to logits for training
9794 # This matches the dtensor worker's _apply_temperature_scaling in the train method
98- if (
99- "generation" in cfg
100- and cfg ["generation" ] is not None
101- ):
95+ if "generation" in cfg and cfg ["generation" ] is not None :
10296 output_tensor .div_ (cfg ["generation" ]["temperature" ])
103-
97+
10498 return output_tensor
10599
100+
106101def forward_with_post_processing_fn (
107102 data_iterator : Iterator [ProcessedMicrobatch ],
108103 model : GPTModel ,
@@ -112,22 +107,21 @@ def forward_with_post_processing_fn(
112107 global_valid_seqs : Optional [torch .Tensor ] = None ,
113108 global_valid_toks : Optional [torch .Tensor ] = None ,
114109) -> Tuple [torch .Tensor , Callable ]:
115- """
116- Perform forward pass with pre-processed microbatch and return output tensor and post-processing function.
117-
110+ """Perform forward pass with pre-processed microbatch and return output tensor and post-processing function.
111+
118112 This function takes a pre-processed microbatch (with sequence packing already handled),
119113 runs the forward step through the model, and prepares the wrapped post-processing
120114 function that is returned alongside the output tensor.
121-
115+
122116 Args:
123117 data_iterator: Iterator yielding ProcessedMicrobatch objects (already processed)
124- model: The model to run forward pass on
118+ model: The model to run forward pass on
125119 cfg (dict): Configuration dictionary
126120 post_processing_fn: Post-processing function to post-process the logits
127121 defer_fp32_logits: Whether to defer FP32 conversion of logits
128122 global_valid_seqs: Global valid sequence count for loss normalization
129123 global_valid_toks: Global valid token count for loss normalization
130-
124+
131125 Returns:
132126 tuple: (output_tensor, post_processing_fn_wrapped)
133127 - output_tensor: Raw model outputs (logits)
@@ -177,10 +171,13 @@ def forward_with_post_processing_fn(
177171 cu_seqlens_padded = cu_seqlens_padded ,
178172 )
179173 else :
180- raise TypeError (f"Unknown post-processing function type: { type (post_processing_fn )} " )
174+ raise TypeError (
175+ f"Unknown post-processing function type: { type (post_processing_fn )} "
176+ )
181177
182178 return output_tensor , post_processing_fn_wrapped
183179
180+
184181def megatron_forward_backward (
185182 model : GPTModel ,
186183 cfg : Dict [str , Any ],
@@ -195,13 +192,12 @@ def megatron_forward_backward(
195192 global_valid_toks : Optional [torch .Tensor ] = None ,
196193 do_not_average_loss : bool = False ,
197194) -> Any :
198- """
199- Execute forward and backward passes using Megatron's utilities.
200-
195+ """Execute forward and backward passes using Megatron's utilities.
196+
201197 This is the main training loop function that coordinates forward and backward
202198 passes across multiple microbatches using Megatron's pipeline parallel
203199 execution framework.
204-
200+
205201 Args:
206202 model: The model to train
207203 cfg (dict): Configuration dictionary
@@ -214,7 +210,7 @@ def megatron_forward_backward(
214210 defer_fp32_logits (Optional[bool]): Whether to skip the conversion of logits to fp32
215211 global_valid_seqs: Global valid sequence count for loss normalization
216212 global_valid_toks: Global valid token count for loss normalization
217-
213+
218214 Returns:
219215 BatchedDataDict: Results from the forward/backward execution
220216 """
@@ -239,8 +235,8 @@ def megatron_forward_backward(
239235 do_not_average_loss = do_not_average_loss ,
240236 )
241237
242- class LossPostProcessor :
243238
239+ class LossPostProcessor :
244240 def __init__ (
245241 self ,
246242 loss_fn : LossFunction ,
@@ -250,16 +246,16 @@ def __init__(
250246 self .loss_fn = loss_fn
251247 self .cfg = cfg
252248 self .cp_normalize = cp_normalize
253-
254- def __call__ (self ,
249+
250+ def __call__ (
251+ self ,
255252 data_dict : BatchedDataDict [Any ],
256253 packed_seq_params : Optional [PackedSeqParams ] = None ,
257254 global_valid_seqs : Optional [torch .Tensor ] = None ,
258255 global_valid_toks : Optional [torch .Tensor ] = None ,
259256 ) -> Callable [[torch .Tensor ], Tuple [torch .Tensor , Dict [str , Any ]]]:
260- """
261- Create a loss post-processing function for training.
262-
257+ """Create a loss post-processing function for training.
258+
263259 This function wraps a loss function with the necessary context and parameters
264260 to compute loss and metrics from model outputs. It handles sequence packing
265261 and context parallelism normalization.
@@ -274,7 +270,6 @@ def __call__(self,
274270 Returns:
275271 Callable: Function that takes output tensor and returns (loss, metrics) tuple
276272 """
277-
278273 loss_fn = self .loss_fn
279274 pack_sequences = self .cfg ["sequence_packing" ]["enabled" ]
280275 if pack_sequences and packed_seq_params is not None :
@@ -307,8 +302,8 @@ def _div_by_cp_size(*args, **kwargs):
307302
308303 return loss_fn_wrapped
309304
310- class LogprobsPostProcessor :
311305
306+ class LogprobsPostProcessor :
312307 def __init__ (self , cfg : Dict [str , Any ]):
313308 self .cfg = cfg
314309
@@ -318,8 +313,7 @@ def __call__(
318313 input_ids : torch .Tensor ,
319314 cu_seqlens_padded : torch .Tensor ,
320315 ) -> Callable [[torch .Tensor ], Tuple [torch .Tensor , Dict [str , torch .Tensor ]]]:
321- """
322- Create a post-processing function that computes token log probabilities.
316+ """Create a post-processing function that computes token log probabilities.
323317
324318 This function returns a processor that takes model logits and converts them
325319 to token-level log probabilities, handling both packed and unpacked sequences.
@@ -370,11 +364,11 @@ def processor_fn_inner(output_tensor):
370364 return torch .tensor (0.0 , device = token_logprobs .device ), {
371365 "logprobs" : token_logprobs
372366 }
367+
373368 return processor_fn_inner
374369
375370
376371class TopkLogitsPostProcessor :
377-
378372 def __init__ (self , cfg : Dict [str , Any ], k : int ):
379373 self .cfg = cfg
380374 self .k = k
@@ -384,8 +378,7 @@ def __call__(
384378 data_dict : BatchedDataDict [Any ],
385379 cu_seqlens_padded : torch .Tensor ,
386380 ) -> Callable [[torch .Tensor ], Tuple [torch .Tensor , Dict [str , torch .Tensor ]]]:
387- """
388- Create a post-processing function that computes top-k logits and indices.
381+ """Create a post-processing function that computes top-k logits and indices.
389382
390383 This function returns a processor that extracts the top-k highest logits
391384 and their corresponding vocabulary indices from model outputs. It handles
@@ -396,10 +389,9 @@ def __call__(
396389 cu_seqlens_padded: Cumulative sequence lengths for packed sequences
397390
398391 Returns:
399- Callable: Function that takes output tensor and returns
392+ Callable: Function that takes output tensor and returns
400393 (dummy_loss, {"topk_logits": values, "topk_indices": indices})
401394 """
402-
403395 pack = self .cfg ["sequence_packing" ]["enabled" ]
404396 cp_size = self .cfg ["megatron_cfg" ]["context_parallel_size" ]
405397 unpacked_seqlen = data_dict ["input_ids" ].shape [1 ]
@@ -521,4 +513,5 @@ def processor_fn_inner(output_tensor):
521513 "topk_logits" : topk_vals_full ,
522514 "topk_indices" : topk_idx_full ,
523515 }
524- return processor_fn_inner
516+
517+ return processor_fn_inner
0 commit comments