3232from torch import nn
3333from torch .distributed .tensor import DTensor , Shard
3434
35- from nemo_rl .algorithms .logits_sampling_utils import TrainingSamplingParams
35+ from nemo_rl .algorithms .logits_sampling_utils import (
36+ TrainingSamplingParams ,
37+ apply_top_k_top_p ,
38+ need_top_k_or_top_p_filtering ,
39+ )
3640from nemo_rl .algorithms .loss import SequencePackingLossWrapper , prepare_loss_input
3741from nemo_rl .algorithms .loss .interfaces import LossFunction
3842from nemo_rl .distributed .batched_data_dict import BatchedDataDict
@@ -124,20 +128,42 @@ def extract_logits(
124128
125129
def apply_temperature_scaling(
    logits: torch.Tensor, sampling_params: Optional[TrainingSamplingParams]
) -> torch.Tensor:
    """Scale logits in place by the sampling temperature.

    Division is skipped entirely when no sampling parameters are given or
    when the temperature is exactly 1.0 (a no-op), avoiding a pointless
    element-wise pass over the tensor.

    Args:
        logits: Logits tensor to scale (mutated in place when scaling applies)
        sampling_params: Sampling parameters; may be ``None``

    Returns:
        torch.Tensor: The (possibly temperature-scaled) logits tensor
    """
    if sampling_params is None:
        return logits
    temperature = sampling_params.temperature
    if temperature != 1.0:
        # In-place division to avoid allocating a second logits-sized tensor.
        logits.div_(temperature)
    return logits
145+
146+
def apply_top_k_top_p_filtering_for_local_logits(
    logits: torch.Tensor, sampling_params: Optional[TrainingSamplingParams]
) -> torch.Tensor:
    """Apply top-k and top-p filtering to the non-distributed logits.

    Guard clauses return the input untouched when there are no sampling
    parameters, or when neither top-k nor top-p filtering is requested.

    Args:
        logits: Logits tensor to filter
        sampling_params: Sampling parameters; may be ``None``

    Returns:
        torch.Tensor: Filtered logits
    """
    if sampling_params is None:
        return logits
    if not need_top_k_or_top_p_filtering(
        sampling_params.top_k, sampling_params.top_p
    ):
        return logits
    filtered, _ = apply_top_k_top_p(
        logits,
        top_k=sampling_params.top_k,
        top_p=sampling_params.top_p,
    )
    return filtered
142168
143169
@@ -233,7 +259,7 @@ def prepare_data_for_cp(
233259
234260def forward_with_post_processing_fn (
235261 model : nn .Module ,
236- cfg : PolicyConfig ,
262+ sampling_params : TrainingSamplingParams ,
237263 post_processing_fn : PostProcessingFunction ,
238264 processed_mb : ProcessedMicrobatch ,
239265 is_reward_model : bool = False ,
@@ -253,7 +279,7 @@ def forward_with_post_processing_fn(
253279
254280 Args:
255281 model: The model to run forward pass on
256- cfg: Configuration dictionary
282+ sampling_params: Sampling parameters
257283 post_processing_fn: Post-processing function to apply to the logits
258284 processed_mb: Pre-fetched ProcessedMicrobatch containing data and processed inputs
259285 is_reward_model: Whether this is a reward model
@@ -290,7 +316,10 @@ def forward_with_post_processing_fn(
290316 post_processing_fn ,
291317 (LossPostProcessor , LogprobsPostProcessor , TopkLogitsPostProcessor ),
292318 ):
293- logits = apply_temperature_scaling (logits , cfg )
319+ # Temperature scaling is element-wise, directly applying it here.
320+ # Other sampling parameters like top-k and top-p need the logits from whole vocabulary,
321+ # so applying them when gathering logits from vocab parallel (called in LossPostProcessor and LogprobsPostProcessor).
322+ logits = apply_temperature_scaling (logits , sampling_params )
294323
295324 # Apply the post-processing function directly based on type
296325 if isinstance (post_processing_fn , LossPostProcessor ):
@@ -558,6 +587,7 @@ def __init__(
558587 tp_mesh : Any ,
559588 cp_size : int ,
560589 enable_seq_packing : bool = False ,
590+ sampling_params : Optional [TrainingSamplingParams ] = None ,
561591 ):
562592 """Initialize LogprobsPostProcessor.
563593
@@ -568,13 +598,15 @@ def __init__(
568598 tp_mesh: Tensor parallel mesh
569599 cp_size: Context parallel size
570600 enable_seq_packing: Whether sequence packing is enabled
601+ sampling_params: Sampling parameters
571602 """
572603 self .cfg = cfg
573604 self .device_mesh = device_mesh
574605 self .cp_mesh = cp_mesh
575606 self .tp_mesh = tp_mesh
576607 self .cp_size = cp_size
577608 self .enable_seq_packing = enable_seq_packing
609+ self .sampling_params = sampling_params
578610 self .logprob_chunk_size = cfg .get ("logprob_chunk_size" , None )
579611
580612 def __call__ (
@@ -627,17 +659,21 @@ def __call__(
627659 input_ids_dtensor ,
628660 seq_index_tensor ,
629661 chunk_size = self .logprob_chunk_size ,
662+ sampling_params = self .sampling_params , # top-k and top-p filtering
630663 )
631664
632665 assert token_logprobs .shape [1 ] == seq_len - 1
633666 else :
634667 if isinstance (logits , DTensor ):
668+ # DTensor path with TP sharding
635669 token_logprobs = get_logprobs_from_vocab_parallel_logits (
636670 logits ,
637671 processed_inputs .input_ids ,
638672 chunk_size = self .logprob_chunk_size ,
673+ sampling_params = self .sampling_params , # top-k and top-p filtering
639674 )
640675 else :
676+ # Non-DTensor path (no TP sharding)
641677 token_logprobs = self ._compute_local_logprobs (
642678 logits , processed_inputs .input_ids
643679 )
@@ -703,12 +739,18 @@ def _compute_local_logprobs(
703739 (chunk_idx + 1 ) * self .logprob_chunk_size ,
704740 )
705741 chunk_logits = logits [:, chunk_start :chunk_end , :].to (torch .float32 )
742+ chunk_logits = apply_top_k_top_p_filtering_for_local_logits (
743+ chunk_logits , self .sampling_params
744+ )
706745 log_probs = torch .nn .functional .log_softmax (chunk_logits , dim = - 1 )
707746 chunked_log_probs .append (log_probs )
708747 log_probs = torch .cat (chunked_log_probs , dim = 1 )
709748 del chunked_log_probs
710749 else :
711750 logits = logits .to (torch .float32 )
751+ logits = apply_top_k_top_p_filtering_for_local_logits (
752+ logits , self .sampling_params
753+ )
712754 log_probs = torch .nn .functional .log_softmax (logits , dim = - 1 )
713755
714756 # Extract logprobs for each token in the sequence by gathering the logprob
0 commit comments