
Commit 84bede0

feat: Remove do_not_average_loss (#1988)
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
1 parent c40dba3

File tree

5 files changed: +25 -12 lines changed

.gitmodules

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 [submodule "3rdparty/Megatron-LM"]
 	path = 3rdparty/Megatron-LM-workspace/Megatron-LM
 	url = https://github.com/yaoyu-33/Megatron-LM.git
-	branch = main
+	branch = yifu/remove_do_not_average_loss
 	shallow = true
 [submodule "3rdparty/Megatron-Bridge"]
 	path = 3rdparty/Megatron-Bridge-workspace/Megatron-Bridge
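The submodule pin moves from main to a feature branch carrying the matching Megatron-LM change. A side note, not part of the commit: existing checkouts typically need `git submodule sync` followed by `git submodule update --init --recursive` to pick up a branch change like this.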

nemo_rl/models/megatron/train.py

Lines changed: 16 additions & 5 deletions

@@ -220,7 +220,6 @@ def megatron_forward_backward(
     defer_fp32_logits: Optional[bool] = False,
     global_valid_seqs: Optional[torch.Tensor] = None,
     global_valid_toks: Optional[torch.Tensor] = None,
-    do_not_average_loss: bool = False,
     straggler_timer: Optional[StragglerDetector] = None,
 ) -> Any:
     """Execute forward and backward passes using Megatron's utilities.
@@ -241,7 +240,6 @@
         defer_fp32_logits: Whether to skip the conversion of logits to fp32
         global_valid_seqs: Global valid sequence count for loss normalization
         global_valid_toks: Global valid token count for loss normalization
-        do_not_average_loss: If True, do not average loss across microbatches
         straggler_timer: Straggler detector for profiling the forward pass

     Returns:
@@ -266,7 +264,6 @@
         micro_batch_size=mbs,
         decoder_seq_length=seq_length,
         forward_only=forward_only,
-        do_not_average_loss=do_not_average_loss,
     )


@@ -275,10 +272,12 @@ def __init__(
         self,
         loss_fn: LossFunction,
         cfg: PolicyConfig,
+        num_microbatches: int = 1,
         cp_normalize: bool = True,
     ):
         self.loss_fn = loss_fn
         self.cfg = cfg
+        self.num_microbatches = num_microbatches
         self.cp_normalize = cp_normalize

     def __call__(
@@ -325,14 +324,26 @@ def __call__(

         if self.cp_normalize:
             cp_size = get_context_parallel_world_size()
-            orig_loss_fn_wrapped = loss_fn_wrapped
+            prev_loss_fn = loss_fn_wrapped

             def _div_by_cp_size(*args, **kwargs):
-                loss, metrics = orig_loss_fn_wrapped(*args, **kwargs)
+                loss, metrics = prev_loss_fn(*args, **kwargs)
                 return loss / cp_size, metrics

             loss_fn_wrapped = _div_by_cp_size

+        # Counteract Megatron's default loss averaging in schedules.py,
+        # which applies (* cp_size / num_microbatches) to the loss.
+        cp_size = get_context_parallel_world_size()
+        num_microbatches = self.num_microbatches
+        loss_fn_before_mcore_scaling = loss_fn_wrapped
+
+        def _counteract_mcore_loss_averaging(*args, **kwargs):
+            loss, metrics = loss_fn_before_mcore_scaling(*args, **kwargs)
+            return loss * num_microbatches / cp_size, metrics
+
+        loss_fn_wrapped = _counteract_mcore_loss_averaging
+
         return loss_fn_wrapped
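For orientation, here is a minimal runnable sketch of how the two wrappers in `__call__` compose. It is not the repository's code: `make_loss_pipeline` and `mcore_averaging` are hypothetical stand-ins, and the scaling they model (Megatron multiplying each microbatch loss by cp_size / num_microbatches) is taken from the commit's own comment, not from Megatron's source.

```python
from typing import Any, Callable, Dict, Tuple

LossFn = Callable[..., Tuple[float, Dict]]


def make_loss_pipeline(
    loss_fn: LossFn, cp_size: int, num_microbatches: int, cp_normalize: bool = True
) -> LossFn:
    """Compose the two wrappers the way LossPostProcessor.__call__ does."""
    wrapped = loss_fn

    if cp_normalize:
        prev = wrapped

        def div_by_cp(*args: Any, **kwargs: Any) -> Tuple[float, Dict]:
            loss, metrics = prev(*args, **kwargs)
            return loss / cp_size, metrics  # normalize across CP ranks

        wrapped = div_by_cp

    before_mcore = wrapped

    def counteract(*args: Any, **kwargs: Any) -> Tuple[float, Dict]:
        loss, metrics = before_mcore(*args, **kwargs)
        # Pre-scale so Megatron's later (* cp_size / num_microbatches)
        # cancels out, leaving the loss un-averaged across microbatches.
        return loss * num_microbatches / cp_size, metrics

    return counteract


def mcore_averaging(loss: float, cp_size: int, num_microbatches: int) -> float:
    """Stand-in for the scaling Megatron's schedules.py applies afterwards."""
    return loss * cp_size / num_microbatches


def raw_loss_fn() -> Tuple[float, Dict]:
    return 1.0, {}


pipeline = make_loss_pipeline(raw_loss_fn, cp_size=2, num_microbatches=4)
loss, _ = pipeline()
print(loss)  # 1.0 / 2 * 4 / 2 = 1.0 (the value the unit test below asserts)
print(mcore_averaging(loss, cp_size=2, num_microbatches=4))  # 0.5 = raw / cp_size
```

Net effect: the microbatch factor cancels against Megatron's averaging, so only the CP normalization (loss / cp_size) survives, which is what `do_not_average_loss=True` used to guarantee.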

nemo_rl/models/policy/workers/megatron_policy_worker.py

Lines changed: 1 addition & 1 deletion

@@ -318,6 +318,7 @@ def train(
         loss_post_processor = LossPostProcessor(
             loss_fn=loss_fn,
             cfg=self.cfg,
+            num_microbatches=num_microbatches,
         )

         rerun_state_machine = get_rerun_state_machine()
@@ -339,7 +340,6 @@
             defer_fp32_logits=self.defer_fp32_logits,
             global_valid_seqs=global_valid_seqs,
             global_valid_toks=global_valid_toks,
-            do_not_average_loss=True,
             straggler_timer=self.mcore_state.straggler_timer,
         )
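The only call-site change is that the worker now forwards num_microbatches instead of forcing do_not_average_loss=True. A quick arithmetic check, with hypothetical numbers and assuming the scaling described in the commit's comment, that the two paths land on the same final loss:

```python
cp_size, num_microbatches, raw = 2, 4, 1.0

# Old path: do_not_average_loss=True meant Megatron skipped its averaging,
# so only the cp_normalize division reached the final loss.
old_final = raw / cp_size

# New path: pre-scale in LossPostProcessor, then let Megatron average as usual.
pre_scaled = (raw / cp_size) * num_microbatches / cp_size
new_final = pre_scaled * cp_size / num_microbatches

assert old_final == new_final == 0.5
```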

tests/unit/models/megatron/test_train.py

Lines changed: 6 additions & 4 deletions

@@ -719,13 +719,15 @@ def test_loss_post_processor_no_packing(
     def test_loss_post_processor_with_cp_normalize(
         self, mock_cp_size, mock_cp_grp, mock_tp_grp, mock_tp_rank
     ):
-        """Test LossPostProcessor with CP normalization."""
+        """Test LossPostProcessor with CP normalization and microbatch pre-scaling."""
         from nemo_rl.models.megatron.train import LossPostProcessor

         mock_loss_fn = MagicMock(return_value=(torch.tensor(1.0), {}))
         cfg = {"sequence_packing": {"enabled": False}}

-        processor = LossPostProcessor(loss_fn=mock_loss_fn, cfg=cfg, cp_normalize=True)
+        processor = LossPostProcessor(
+            loss_fn=mock_loss_fn, cfg=cfg, num_microbatches=4, cp_normalize=True
+        )

         # Set up mock return values for process groups
         mock_tp_grp.return_value = MagicMock()
@@ -736,8 +738,8 @@ def test_loss_post_processor_with_cp_normalize(
         output_tensor = torch.randn(2, 10, 100)
         loss, _ = wrapped_fn(output_tensor)

-        # Loss should be divided by CP size (2)
-        assert torch.isclose(loss, torch.tensor(0.5))
+        # Loss should be scaled by num_microbatches / (cp_size * cp_size) = 4 / (2 * 2) = 1.0
+        assert torch.isclose(loss, torch.tensor(1.0))

     @patch(
         "nemo_rl.models.megatron.train.get_tensor_model_parallel_rank", return_value=0
