1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ from contextlib import nullcontext
1516from functools import partial
1617from typing import Any , Callable , Dict , Iterator , Optional , Tuple , Union
1718
2526 get_tensor_model_parallel_rank ,
2627)
2728from megatron .core .pipeline_parallel import get_forward_backward_func
29+ from megatron .core .utils import StragglerDetector
2830
2931from nemo_rl .algorithms .loss_functions import LossFunction , SequencePackingLossWrapper
3032from nemo_rl .distributed .batched_data_dict import BatchedDataDict
3537 from_parallel_logits_to_logprobs_packed_sequences ,
3638)
3739from nemo_rl .models .megatron .data import ProcessedMicrobatch
40+ from nemo_rl .models .policy import PolicyConfig
3841
3942# Union type for any post-processing function (defined after classes below)
4043PostProcessingFunction = Union [
4750def model_forward (
4851 model : GPTModel ,
4952 data_dict : BatchedDataDict [Any ],
50- cfg : Dict [ str , Any ] ,
53+ cfg : PolicyConfig ,
5154 input_ids_cp_sharded : torch .Tensor ,
5255 position_ids : torch .Tensor ,
5356 attention_mask : torch .Tensor ,
5457 packed_seq_params : Optional [PackedSeqParams ] = None ,
55- defer_fp32_logits : Optional [bool ] = None ,
58+ defer_fp32_logits : Optional [bool ] = False ,
59+ straggler_timer : Optional [StragglerDetector ] = None ,
5660) -> torch .Tensor :
5761 """Perform a single forward pass through the model.
5862
5963 Args:
6064 model: The model to run forward pass on
61- data_dict (BatchedDataDict) : Dictionary containing batch data
62- cfg (dict): Configuration dictionary
65+ data_dict: Dictionary containing batch data
66+ cfg: Policy configuration dictionary
6367 input_ids_cp_sharded: Context-parallel sharded input token IDs
6468 position_ids: Position IDs for tokens
6569 attention_mask: Attention mask for the sequence
6670 packed_seq_params: Parameters for packed sequences (optional)
67- defer_fp32_logits (Optional[bool]): Whether to skip the conversion of logits to fp32
71+ defer_fp32_logits: Whether to skip the conversion of logits to fp32
72+ straggler_timer: Straggler detector for profiling the forward pass
6873
6974 Returns:
7075 torch.Tensor: Output tensor from the model (logits)
@@ -81,31 +86,48 @@ def model_forward(
8186 additional_kwargs ["packed_seq_params" ] = packed_seq_params
8287 if defer_fp32_logits :
8388 additional_kwargs ["fp32_output" ] = False
84- # with straggler_timer:
85- output_tensor = model (
86- input_ids = input_ids_cp_sharded ,
87- position_ids = position_ids ,
88- attention_mask = attention_mask ,
89- ** additional_kwargs ,
90- ** multimodal_data ,
91- )
9289
93- # Apply temperature scaling to logits for training
94- # This matches the dtensor worker's _apply_temperature_scaling in the train method
95- if "generation" in cfg and cfg ["generation" ] is not None :
96- output_tensor .div_ (cfg ["generation" ]["temperature" ])
90+ with straggler_timer () if straggler_timer is not None else nullcontext ():
91+ output_tensor = model (
92+ input_ids = input_ids_cp_sharded ,
93+ position_ids = position_ids ,
94+ attention_mask = attention_mask ,
95+ ** additional_kwargs ,
96+ ** multimodal_data ,
97+ )
98+
99+ apply_temperature_scaling (output_tensor , cfg )
97100
98101 return output_tensor
99102
100103
def apply_temperature_scaling(
    logits: torch.Tensor,
    cfg: "PolicyConfig",
) -> torch.Tensor:
    """Apply in-place temperature scaling to logits.

    Matches the dtensor worker's temperature handling: when the policy config
    carries a non-None "generation" section, the logits are divided in place
    (via ``div_``) by its "temperature" entry; otherwise they are returned
    unchanged.

    Args:
        logits: Logits tensor to scale. NOTE: modified in place.
        cfg: Policy configuration; may contain a "generation" mapping with a
            "temperature" key.

    Returns:
        torch.Tensor: The same tensor object, temperature-scaled when a
        generation config is present.
    """
    # Single dict lookup instead of the double-lookup
    # `"generation" in cfg and cfg["generation"] is not None`.
    generation_cfg = cfg.get("generation")
    if generation_cfg is not None:
        logits.div_(generation_cfg["temperature"])
    return logits
120+
121+
101122def forward_with_post_processing_fn (
102123 data_iterator : Iterator [ProcessedMicrobatch ],
103124 model : GPTModel ,
104- cfg : Dict [ str , Any ] ,
125+ cfg : PolicyConfig ,
105126 post_processing_fn : PostProcessingFunction ,
106- defer_fp32_logits : Optional [bool ] = True ,
127+ defer_fp32_logits : Optional [bool ] = False ,
107128 global_valid_seqs : Optional [torch .Tensor ] = None ,
108129 global_valid_toks : Optional [torch .Tensor ] = None ,
130+ straggler_timer : Optional [StragglerDetector ] = None ,
109131) -> Tuple [torch .Tensor , Callable ]:
110132 """Perform forward pass with pre-processed microbatch and return output tensor and post-processing function.
111133
@@ -116,11 +138,12 @@ def forward_with_post_processing_fn(
116138 Args:
117139 data_iterator: Iterator yielding ProcessedMicrobatch objects (already processed)
118140 model: The model to run forward pass on
119- cfg (dict): Configuration dictionary
141+ cfg: Policy configuration dictionary
120142 post_processing_fn: Post-processing function to post-process the logits
121143 defer_fp32_logits: Whether to defer FP32 conversion of logits
122144 global_valid_seqs: Global valid sequence count for loss normalization
123145 global_valid_toks: Global valid token count for loss normalization
146+ straggler_timer: Straggler detector for profiling the forward pass
124147
125148 Returns:
126149 tuple: (output_tensor, post_processing_fn_wrapped)
@@ -140,14 +163,15 @@ def forward_with_post_processing_fn(
140163 cu_seqlens_padded = processed_mb .cu_seqlens_padded
141164
142165 output_tensor = model_forward (
143- model ,
144- data_dict ,
145- cfg ,
146- input_ids_cp_sharded ,
147- position_ids ,
148- attention_mask ,
149- packed_seq_params ,
150- defer_fp32_logits ,
166+ model = model ,
167+ data_dict = data_dict ,
168+ cfg = cfg ,
169+ input_ids_cp_sharded = input_ids_cp_sharded ,
170+ position_ids = position_ids ,
171+ attention_mask = attention_mask ,
172+ packed_seq_params = packed_seq_params ,
173+ defer_fp32_logits = defer_fp32_logits ,
174+ straggler_timer = straggler_timer ,
151175 )
152176
153177 ## calling post_processing_fn will return a function that takes the output tensor and returns a tuple of (loss, metrics)
@@ -180,17 +204,18 @@ def forward_with_post_processing_fn(
180204
181205def megatron_forward_backward (
182206 model : GPTModel ,
183- cfg : Dict [ str , Any ] ,
207+ cfg : PolicyConfig ,
184208 data_iterator : Iterator [ProcessedMicrobatch ],
185209 num_microbatches : int ,
186210 seq_length : int ,
187211 mbs : int ,
188212 post_processing_fn : PostProcessingFunction ,
189213 forward_only : bool = False ,
190- defer_fp32_logits : Optional [bool ] = None ,
214+ defer_fp32_logits : Optional [bool ] = False ,
191215 global_valid_seqs : Optional [torch .Tensor ] = None ,
192216 global_valid_toks : Optional [torch .Tensor ] = None ,
193217 do_not_average_loss : bool = False ,
218+ straggler_timer : Optional [StragglerDetector ] = None ,
194219) -> Any :
195220 """Execute forward and backward passes using Megatron's utilities.
196221
@@ -200,19 +225,21 @@ def megatron_forward_backward(
200225
201226 Args:
202227 model: The model to train
203- cfg (dict): Configuration dictionary
228+ cfg: Policy configuration dictionary
204229 data_iterator: Iterator yielding ProcessedMicrobatch objects (already processed)
205- num_microbatches (int) : Number of microbatches to process
206- seq_length (int) : Sequence length
207- mbs (int) : Micro batch size
230+ num_microbatches: Number of microbatches to process
231+ seq_length: Sequence length
232+ mbs: Micro batch size
208233 post_processing_fn: Post-processing function to post-process the logits
209- forward_only (bool) : If True, skip backward pass
210- defer_fp32_logits (Optional[bool]) : Whether to skip the conversion of logits to fp32
234+ forward_only: If True, skip backward pass
235+ defer_fp32_logits: Whether to skip the conversion of logits to fp32
211236 global_valid_seqs: Global valid sequence count for loss normalization
212237 global_valid_toks: Global valid token count for loss normalization
238+ do_not_average_loss: If True, do not average loss across microbatches
239+ straggler_timer: Straggler detector for profiling the forward pass
213240
214241 Returns:
215- BatchedDataDict: Results from the forward/backward execution
242+ Results from the forward/backward execution
216243 """
217244 forward_step = partial (
218245 forward_with_post_processing_fn ,
@@ -221,6 +248,7 @@ def megatron_forward_backward(
221248 defer_fp32_logits = defer_fp32_logits ,
222249 global_valid_seqs = global_valid_seqs ,
223250 global_valid_toks = global_valid_toks ,
251+ straggler_timer = straggler_timer ,
224252 )
225253 forward_backward_func = get_forward_backward_func ()
226254 return forward_backward_func (
@@ -240,7 +268,7 @@ class LossPostProcessor:
240268 def __init__ (
241269 self ,
242270 loss_fn : LossFunction ,
243- cfg : Dict [ str , Any ] ,
271+ cfg : PolicyConfig ,
244272 cp_normalize : bool = True ,
245273 ):
246274 self .loss_fn = loss_fn
@@ -261,11 +289,10 @@ def __call__(
261289 and context parallelism normalization.
262290
263291 Args:
264- loss_fn: The base loss function to wrap
265- cfg (dict): Configuration dictionary
266- data_dict: Batched data dictionary
292+ data_dict: Batched data dictionary for the current microbatch
267293 packed_seq_params: Parameters for packed sequences (optional)
268- cp_normalize (bool): Whether to normalize by context parallel size
294+ global_valid_seqs: Global valid sequence count for loss normalization
295+ global_valid_toks: Global valid token count for loss normalization
269296
270297 Returns:
271298 Callable: Function that takes output tensor and returns (loss, metrics) tuple
@@ -304,7 +331,7 @@ def _div_by_cp_size(*args, **kwargs):
304331
305332
306333class LogprobsPostProcessor :
307- def __init__ (self , cfg : Dict [ str , Any ] ):
334+ def __init__ (self , cfg : PolicyConfig ):
308335 self .cfg = cfg
309336
310337 def __call__ (
@@ -369,7 +396,7 @@ def processor_fn_inner(output_tensor):
369396
370397
371398class TopkLogitsPostProcessor :
372- def __init__ (self , cfg : Dict [ str , Any ] , k : int ):
399+ def __init__ (self , cfg : PolicyConfig , k : int ):
373400 self .cfg = cfg
374401 self .k = k
375402
0 commit comments