Skip to content

Commit d63506c

Browse files
committed
address feedback
Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
1 parent 4b137cc commit d63506c

File tree

2 files changed

+16
-21
lines changed

2 files changed

+16
-21
lines changed

nemo_rl/models/megatron/pipeline_parallel.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1- # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
1+ # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -57,14 +57,12 @@ def broadcast_obj_from_pp_rank(obj: Any) -> Any:
5757
# ------------------------------------------------------------------
5858
# 2. Identify the owning rank (the only rank with True flag)
5959
# ------------------------------------------------------------------
60-
src_rank = None # Rank *inside* the PP group
61-
for rank, flag in enumerate(obj_flags):
62-
if flag:
63-
src_rank = rank
64-
break
65-
66-
if src_rank is None:
60+
true_ranks = [rank for rank, flag in enumerate(obj_flags) if flag]
61+
if not true_ranks:
6762
raise ValueError("Object must exist on at least one PP rank")
63+
if len(true_ranks) > 1:
64+
raise ValueError(f"Object present on multiple PP ranks: {true_ranks}")
65+
src_rank = true_ranks[0]
6866

6967
# ------------------------------------------------------------------
7068
# 3. Broadcast the object from the source rank to all ranks
@@ -135,12 +133,11 @@ def broadcast_tensors_from_last_stage(
135133
if is_pipeline_last_stage(ignore_virtual=True):
136134
# Broadcast tensors from last stage
137135
for name, tensor in tensors.items():
138-
if tensor is not None:
139-
broadcasted_tensors[name] = broadcast_tensor(
140-
tensor, current_rank, pp_group
136+
if tensor is None:
137+
raise ValueError(
138+
f"Last PP stage must provide tensor '{name}' for broadcast."
141139
)
142-
else:
143-
broadcasted_tensors[name] = None
140+
broadcasted_tensors[name] = broadcast_tensor(tensor, current_rank, pp_group)
144141
else:
145142
# Receive tensors on other stages
146143
for name in tensors.keys():

nemo_rl/models/megatron/train.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1- # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
1+ # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -96,8 +96,6 @@ def model_forward(
9696
**multimodal_data,
9797
)
9898

99-
apply_temperature_scaling(output_tensor, cfg)
100-
10199
return output_tensor
102100

103101

@@ -174,7 +172,11 @@ def forward_with_post_processing_fn(
174172
straggler_timer=straggler_timer,
175173
)
176174

177-
## calling post_processing_fn will return a function that takes the output tensor and returns a tuple of (loss, metrics)
175+
# Apply temperature scaling only for sampling-oriented post-processors.
176+
# Loss computation should use unscaled logits.
177+
if isinstance(post_processing_fn, (LogprobsPostProcessor, TopkLogitsPostProcessor)):
178+
apply_temperature_scaling(output_tensor, cfg)
179+
178180
# Use type checking to dispatch to the correct post-processing method
179181
if isinstance(post_processing_fn, LossPostProcessor):
180182
post_processing_fn_wrapped = post_processing_fn(
@@ -425,10 +427,6 @@ def __call__(
425427
seq_lengths = data_dict["input_lengths"]
426428

427429
def processor_fn_inner(output_tensor):
428-
# Only the last PP stage produces final logits/top-k; earlier stages return empty
429-
# if not is_pipeline_last_stage(ignore_virtual=True):
430-
# return output_tensor.new_zeros(()), {}
431-
432430
tp_grp = get_tensor_model_parallel_group()
433431
tp_rank = get_tensor_model_parallel_rank()
434432
vocab_shard_size = output_tensor.shape[-1]

0 commit comments

Comments (0)