Skip to content

Commit 53cd230

Browse files
committed
add do_not_average_loss arg
Signed-off-by: ashors1 <ashors@nvidia.com>
1 parent b745aef commit 53cd230

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

nemo_rl/models/megatron/train.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -180,6 +180,7 @@ def megatron_forward_backward(
180180
defer_fp32_logits: Optional[bool] = None,
181181
global_valid_seqs: Optional[torch.Tensor] = None,
182182
global_valid_toks: Optional[torch.Tensor] = None,
183+
do_not_average_loss: bool = False,
183184
) -> Any:
184185
"""
185186
Execute forward and backward passes using Megatron's utilities.

nemo_rl/models/policy/workers/megatron_policy_worker.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -963,10 +963,10 @@ def train(
963963
mbs=micro_batch_size,
964964
post_processing_fn=loss_fn_wrapped,
965965
forward_only=eval_mode,
966-
#do_not_average_loss=True, ## TODO!
967966
defer_fp32_logits=self.defer_fp32_logits,
968967
global_valid_seqs=global_valid_seqs,
969968
global_valid_toks=global_valid_toks,
969+
do_not_average_loss=True,
970970
)
971971

972972
# Empty unused memory.

0 commit comments

Comments (0)