Skip to content

Commit 7f66e95

Browse files
committed
update
Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
1 parent 1cadb1e commit 7f66e95

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

nemo_rl/models/megatron/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def forward_with_post_processing_fn(
174174

175175
# Apply temperature scaling only for sampling-oriented post-processors.
176176
# Loss computation should use unscaled logits.
177-
if isinstance(post_processing_fn, (LogprobsPostProcessor, TopkLogitsPostProcessor)):
177+
if isinstance(post_processing_fn, (LossPostProcessor, LogprobsPostProcessor, TopkLogitsPostProcessor)):
178178
apply_temperature_scaling(output_tensor, cfg)
179179

180180
# Use type checking to dispatch to the correct post-processing method

0 commit comments

Comments
 (0)