2929from megatron .core .pipeline_parallel import get_forward_backward_func
3030from megatron .core .utils import StragglerDetector
3131
32+ from nemo_rl .algorithms .logits_sampling_utils import TrainingSamplingParams
3233from nemo_rl .algorithms .loss import (
3334 SequencePackingLossWrapper ,
3435 prepare_loss_input ,
5657def model_forward (
5758 model : GPTModel ,
5859 data_dict : BatchedDataDict [Any ],
59- cfg : PolicyConfig ,
6060 input_ids_cp_sharded : torch .Tensor ,
6161 position_ids : torch .Tensor ,
6262 attention_mask : torch .Tensor ,
@@ -106,27 +106,26 @@ def model_forward(
106106
107107
108108def apply_temperature_scaling (
109- logits : torch .Tensor ,
110- cfg : PolicyConfig ,
109+ logits : torch .Tensor , sampling_params : Optional [TrainingSamplingParams ]
111110) -> torch .Tensor :
112111 """Apply temperature scaling to logits.
113112
114113 Args:
115114 logits: Logits tensor to scale
116- cfg: Policy configuration containing generation settings
115+ sampling_params: Sampling parameters
117116
118117 Returns:
119118 torch.Tensor: Temperature-scaled logits
120119 """
121- if "generation" in cfg and cfg [ "generation" ] is not None :
122- logits .div_ (cfg [ "generation" ][ " temperature" ] )
120+ if sampling_params is not None and sampling_params . temperature != 1.0 :
121+ logits .div_ (sampling_params . temperature )
123122 return logits
124123
125124
126125def forward_with_post_processing_fn (
127126 data_iterator : Iterator [ProcessedMicrobatch ],
128127 model : GPTModel ,
129- cfg : PolicyConfig ,
128+ sampling_params : TrainingSamplingParams ,
130129 post_processing_fn : PostProcessingFunction ,
131130 defer_fp32_logits : Optional [bool ] = False ,
132131 global_valid_seqs : Optional [torch .Tensor ] = None ,
@@ -142,7 +141,7 @@ def forward_with_post_processing_fn(
142141 Args:
143142 data_iterator: Iterator yielding ProcessedMicrobatch objects (already processed)
144143 model: The model to run forward pass on
145- cfg: Policy configuration dictionary
144+ sampling_params: Sampling parameters
146145 post_processing_fn: Post-processing function to post-process the logits
147146 defer_fp32_logits: Whether to defer FP32 conversion of logits
148147 global_valid_seqs: Global valid sequence count for loss normalization
@@ -169,7 +168,6 @@ def forward_with_post_processing_fn(
169168 output_tensor = model_forward (
170169 model = model ,
171170 data_dict = data_dict ,
172- cfg = cfg ,
173171 input_ids_cp_sharded = input_ids_cp_sharded ,
174172 position_ids = position_ids ,
175173 attention_mask = attention_mask ,
@@ -187,7 +185,7 @@ def forward_with_post_processing_fn(
187185 # Temperature scaling is element-wise, directly applying it here.
188186 # Other sampling parameters like top-k and top-p need the logits from whole vocabulary,
189187 # so applying them when gathering logits from vocab parallel (called in LossPostProcessor and LogprobsPostProcessor).
190- apply_temperature_scaling (output_tensor , cfg )
188+ apply_temperature_scaling (output_tensor , sampling_params )
191189
192190 # Use type checking to dispatch to the correct post-processing method
193191 if isinstance (post_processing_fn , LossPostProcessor ):
@@ -218,7 +216,7 @@ def forward_with_post_processing_fn(
218216
219217def megatron_forward_backward (
220218 model : GPTModel ,
221- cfg : PolicyConfig ,
219+ sampling_params : TrainingSamplingParams ,
222220 data_iterator : Iterator [ProcessedMicrobatch ],
223221 num_microbatches : int ,
224222 seq_length : int ,
@@ -238,7 +236,7 @@ def megatron_forward_backward(
238236
239237 Args:
240238 model: The model to train
241- cfg: Policy configuration dictionary
239+ sampling_params: Sampling parameters
242240 data_iterator: Iterator yielding ProcessedMicrobatch objects (already processed)
243241 num_microbatches: Number of microbatches to process
244242 seq_length: Sequence length
@@ -255,7 +253,7 @@ def megatron_forward_backward(
255253 """
256254 forward_step = partial (
257255 forward_with_post_processing_fn ,
258- cfg = cfg ,
256+ sampling_params = sampling_params ,
259257 post_processing_fn = post_processing_fn ,
260258 defer_fp32_logits = defer_fp32_logits ,
261259 global_valid_seqs = global_valid_seqs ,
@@ -282,11 +280,13 @@ def __init__(
282280 cfg : PolicyConfig ,
283281 num_microbatches : int = 1 ,
284282 cp_normalize : bool = True ,
283+ sampling_params : Optional [TrainingSamplingParams ] = None ,
285284 ):
286285 self .loss_fn = loss_fn
287286 self .cfg = cfg
288287 self .num_microbatches = num_microbatches
289288 self .cp_normalize = cp_normalize
289+ self .sampling_params = sampling_params
290290
291291 def __call__ (
292292 self ,
@@ -310,12 +310,17 @@ def __call__(
310310 Returns:
311311 Callable: Function that takes output tensor and returns (loss, metrics) tuple
312312 """
313+ # wrap prepare_loss_input with sampling_params
314+ prepare_loss_input_wrapped = partial (
315+ prepare_loss_input , sampling_params = self .sampling_params
316+ )
317+
313318 # wrap loss function with loss input preparation
314319 pack_sequences = self .cfg ["sequence_packing" ]["enabled" ]
315320 if pack_sequences and packed_seq_params is not None :
316321 loss_fn_wrapped = SequencePackingLossWrapper (
317322 loss_fn = self .loss_fn ,
318- prepare_fn = prepare_loss_input ,
323+ prepare_fn = prepare_loss_input_wrapped ,
319324 cu_seqlens_q = packed_seq_params .cu_seqlens_q ,
320325 cu_seqlens_q_padded = packed_seq_params .cu_seqlens_q_padded ,
321326 vocab_parallel_rank = get_tensor_model_parallel_rank (),
@@ -326,7 +331,7 @@ def __call__(
326331 loss_fn_wrapped = partial (
327332 wrap_loss_fn_with_input_preparation ,
328333 loss_fn = self .loss_fn ,
329- prepare_fn = prepare_loss_input ,
334+ prepare_fn = prepare_loss_input_wrapped ,
330335 vocab_parallel_rank = get_tensor_model_parallel_rank (),
331336 vocab_parallel_group = get_tensor_model_parallel_group (),
332337 context_parallel_group = get_context_parallel_group (),
@@ -365,8 +370,9 @@ def _counteract_mcore_loss_averaging(*args, **kwargs):
365370
366371
367372class LogprobsPostProcessor :
368- def __init__ (self , cfg : PolicyConfig ):
373+ def __init__ (self , cfg : PolicyConfig , sampling_params : TrainingSamplingParams ):
369374 self .cfg = cfg
375+ self .sampling_params = sampling_params
370376
371377 def __call__ (
372378 self ,
@@ -406,6 +412,7 @@ def processor_fn_inner(output_tensor):
406412 inference_only = True ,
407413 cp_group = get_context_parallel_group (),
408414 chunk_size = logprob_chunk_size ,
415+ sampling_params = self .sampling_params ,
409416 )
410417 else :
411418 token_logprobs = from_parallel_logits_to_logprobs (
@@ -416,6 +423,7 @@ def processor_fn_inner(output_tensor):
416423 tp_group = tp_grp ,
417424 inference_only = True ,
418425 chunk_size = logprob_chunk_size ,
426+ sampling_params = self .sampling_params ,
419427 )
420428
421429 # Prepend 0 logprob for first token to maintain same sequence length as input
0 commit comments