Skip to content

Commit 7931cb0

Browse files
committed
add topp topk in LogprobsPostProcessor
Signed-off-by: Yuki Huang <yukih@nvidia.com>
1 parent 32fbfbe commit 7931cb0

File tree

4 files changed

+59
-13
lines changed

4 files changed

+59
-13
lines changed

nemo_rl/algorithms/loss/loss_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def __call__(
278278
if self.reference_policy_kl_penalty != 0:
279279
# When top-k/top-p filtering is enabled, we need special handling for KL:
280280
# - reference_policy_logprobs is computed **without** filtering (see use_reference_model)
281-
# - curr_logprobs is computed **with** filtering (for actor loss compatibility)
281+
# - curr_logprobs/prev_logprobs are computed **with** filtering (for actor loss compatibility)
282282
# - For KL, we need curr_logprobs **without** filtering to be consistent with ref logprobs
283283
# - For importance weights, we also use unfiltered curr_logprobs_for_kl since we're
284284
# reweighting samples from π_gen_filtered to π_curr_unfiltered

nemo_rl/models/automodel/train.py

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@
3232
from torch import nn
3333
from torch.distributed.tensor import DTensor, Shard
3434

35-
from nemo_rl.algorithms.logits_sampling_utils import TrainingSamplingParams
35+
from nemo_rl.algorithms.logits_sampling_utils import (
36+
TrainingSamplingParams,
37+
apply_top_k_top_p,
38+
need_top_k_or_top_p_filtering,
39+
)
3640
from nemo_rl.algorithms.loss import SequencePackingLossWrapper, prepare_loss_input
3741
from nemo_rl.algorithms.loss.interfaces import LossFunction
3842
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
@@ -124,20 +128,42 @@ def extract_logits(
124128

125129

126130
def apply_temperature_scaling(
127-
logits: torch.Tensor,
128-
cfg: PolicyConfig,
131+
logits: torch.Tensor, sampling_params: Optional[TrainingSamplingParams]
129132
) -> torch.Tensor:
130133
"""Apply temperature scaling to logits.
131134
132135
Args:
133136
logits: Logits tensor to scale
134-
cfg: Configuration dictionary containing generation settings
137+
sampling_params: Sampling parameters
135138
136139
Returns:
137140
torch.Tensor: Temperature-scaled logits
138141
"""
139-
if "generation" in cfg and cfg["generation"] is not None:
140-
logits.div_(cfg["generation"]["temperature"])
142+
if sampling_params is not None and sampling_params.temperature != 1.0:
143+
logits.div_(sampling_params.temperature)
144+
return logits
145+
146+
147+
def apply_top_k_top_p_filtering_for_local_logits(
148+
logits: torch.Tensor, sampling_params: Optional[TrainingSamplingParams]
149+
) -> torch.Tensor:
150+
"""Apply top-k and top-p filtering to the non-distributed logits.
151+
152+
Args:
153+
logits: Logits tensor to filter
154+
sampling_params: Sampling parameters
155+
156+
Returns:
157+
torch.Tensor: Filtered logits
158+
"""
159+
if sampling_params is not None and need_top_k_or_top_p_filtering(
160+
sampling_params.top_k, sampling_params.top_p
161+
):
162+
logits, _ = apply_top_k_top_p(
163+
logits,
164+
top_k=sampling_params.top_k,
165+
top_p=sampling_params.top_p,
166+
)
141167
return logits
142168

143169

@@ -233,7 +259,7 @@ def prepare_data_for_cp(
233259

234260
def forward_with_post_processing_fn(
235261
model: nn.Module,
236-
cfg: PolicyConfig,
262+
sampling_params: TrainingSamplingParams,
237263
post_processing_fn: PostProcessingFunction,
238264
processed_mb: ProcessedMicrobatch,
239265
is_reward_model: bool = False,
@@ -253,7 +279,7 @@ def forward_with_post_processing_fn(
253279
254280
Args:
255281
model: The model to run forward pass on
256-
cfg: Configuration dictionary
282+
sampling_params: Sampling parameters
257283
post_processing_fn: Post-processing function to apply to the logits
258284
processed_mb: Pre-fetched ProcessedMicrobatch containing data and processed inputs
259285
is_reward_model: Whether this is a reward model
@@ -290,7 +316,10 @@ def forward_with_post_processing_fn(
290316
post_processing_fn,
291317
(LossPostProcessor, LogprobsPostProcessor, TopkLogitsPostProcessor),
292318
):
293-
logits = apply_temperature_scaling(logits, cfg)
319+
# Temperature scaling is element-wise, directly applying it here.
320+
# Other sampling parameters like top-k and top-p need the logits from the whole vocabulary,
321+
# so they are applied when gathering logits from vocab parallel (called in LossPostProcessor and LogprobsPostProcessor).
322+
logits = apply_temperature_scaling(logits, sampling_params)
294323

295324
# Apply the post-processing function directly based on type
296325
if isinstance(post_processing_fn, LossPostProcessor):
@@ -558,6 +587,7 @@ def __init__(
558587
tp_mesh: Any,
559588
cp_size: int,
560589
enable_seq_packing: bool = False,
590+
sampling_params: Optional[TrainingSamplingParams] = None,
561591
):
562592
"""Initialize LogprobsPostProcessor.
563593
@@ -568,13 +598,15 @@ def __init__(
568598
tp_mesh: Tensor parallel mesh
569599
cp_size: Context parallel size
570600
enable_seq_packing: Whether sequence packing is enabled
601+
sampling_params: Sampling parameters
571602
"""
572603
self.cfg = cfg
573604
self.device_mesh = device_mesh
574605
self.cp_mesh = cp_mesh
575606
self.tp_mesh = tp_mesh
576607
self.cp_size = cp_size
577608
self.enable_seq_packing = enable_seq_packing
609+
self.sampling_params = sampling_params
578610
self.logprob_chunk_size = cfg.get("logprob_chunk_size", None)
579611

580612
def __call__(
@@ -627,17 +659,21 @@ def __call__(
627659
input_ids_dtensor,
628660
seq_index_tensor,
629661
chunk_size=self.logprob_chunk_size,
662+
sampling_params=self.sampling_params, # top-k and top-p filtering
630663
)
631664

632665
assert token_logprobs.shape[1] == seq_len - 1
633666
else:
634667
if isinstance(logits, DTensor):
668+
# DTensor path with TP sharding
635669
token_logprobs = get_logprobs_from_vocab_parallel_logits(
636670
logits,
637671
processed_inputs.input_ids,
638672
chunk_size=self.logprob_chunk_size,
673+
sampling_params=self.sampling_params, # top-k and top-p filtering
639674
)
640675
else:
676+
# Non-DTensor path (no TP sharding)
641677
token_logprobs = self._compute_local_logprobs(
642678
logits, processed_inputs.input_ids
643679
)
@@ -703,12 +739,18 @@ def _compute_local_logprobs(
703739
(chunk_idx + 1) * self.logprob_chunk_size,
704740
)
705741
chunk_logits = logits[:, chunk_start:chunk_end, :].to(torch.float32)
742+
chunk_logits = apply_top_k_top_p_filtering_for_local_logits(
743+
chunk_logits, self.sampling_params
744+
)
706745
log_probs = torch.nn.functional.log_softmax(chunk_logits, dim=-1)
707746
chunked_log_probs.append(log_probs)
708747
log_probs = torch.cat(chunked_log_probs, dim=1)
709748
del chunked_log_probs
710749
else:
711750
logits = logits.to(torch.float32)
751+
logits = apply_top_k_top_p_filtering_for_local_logits(
752+
logits, self.sampling_params
753+
)
712754
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
713755

714756
# Extract logprobs for each token in the sequence by gathering the logprob

nemo_rl/models/megatron/train.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@ def forward_with_post_processing_fn(
184184
post_processing_fn,
185185
(LossPostProcessor, LogprobsPostProcessor, TopkLogitsPostProcessor),
186186
):
187+
# Temperature scaling is element-wise, directly applying it here.
188+
# Other sampling parameters like top-k and top-p need the logits from the whole vocabulary,
189+
# so they are applied when gathering logits from vocab parallel (called in LossPostProcessor and LogprobsPostProcessor).
187190
apply_temperature_scaling(output_tensor, cfg)
188191

189192
# Use type checking to dispatch to the correct post-processing method

nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,7 @@ def get_logprobs(
574574
tp_mesh=self.tp_mesh,
575575
cp_size=self.cp_size,
576576
enable_seq_packing=self.enable_seq_packing,
577+
sampling_params=self.sampling_params,
577578
)
578579

579580
with torch.no_grad():
@@ -602,7 +603,7 @@ def get_logprobs(
602603
# Use forward_with_post_processing_fn for forward pass and post-processing
603604
token_logprobs, _metrics, _ = forward_with_post_processing_fn(
604605
model=self.model,
605-
cfg=self.cfg,
606+
sampling_params=self.sampling_params,
606607
post_processing_fn=logprobs_post_processor,
607608
processed_mb=processed_mb,
608609
is_reward_model=False,
@@ -671,7 +672,7 @@ def score(self, data: BatchedDataDict) -> BatchedDataDict[ScoreOutputSpec]:
671672
# Use forward_with_post_processing_fn for forward pass and post-processing
672673
rm_scores, _metrics, _ = forward_with_post_processing_fn(
673674
model=self.model,
674-
cfg=self.cfg,
675+
sampling_params=self.sampling_params,
675676
post_processing_fn=score_post_processor,
676677
processed_mb=processed_mb,
677678
is_reward_model=True,
@@ -761,7 +762,7 @@ def get_topk_logits(
761762
# Use forward_with_post_processing_fn for forward pass and post-processing
762763
(vals, idx), _metrics, _ = forward_with_post_processing_fn(
763764
model=self.model,
764-
cfg=self.cfg,
765+
sampling_params=self.sampling_params,
765766
post_processing_fn=topk_post_processor,
766767
processed_mb=processed_mb,
767768
is_reward_model=False,

0 commit comments

Comments
 (0)