diff --git a/.gitmodules b/.gitmodules index 81d066b8b0..c1b0c5a56f 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,7 +1,7 @@ [submodule "3rdparty/Megatron-LM"] path = 3rdparty/Megatron-LM-workspace/Megatron-LM url = https://github.com/yaoyu-33/Megatron-LM.git - branch = main + branch = yifu/remove_do_not_average_loss shallow = true [submodule "3rdparty/Megatron-Bridge"] path = 3rdparty/Megatron-Bridge-workspace/Megatron-Bridge diff --git a/3rdparty/Megatron-LM-workspace/Megatron-LM b/3rdparty/Megatron-LM-workspace/Megatron-LM index 193463c4f8..b12071b947 160000 --- a/3rdparty/Megatron-LM-workspace/Megatron-LM +++ b/3rdparty/Megatron-LM-workspace/Megatron-LM @@ -1 +1 @@ -Subproject commit 193463c4f8414e6906a40dd527a450bca50706b1 +Subproject commit b12071b947f9ee3c6616306662069fc4ca77be4c diff --git a/nemo_rl/models/megatron/train.py b/nemo_rl/models/megatron/train.py index 95ccc3761d..8459eada93 100644 --- a/nemo_rl/models/megatron/train.py +++ b/nemo_rl/models/megatron/train.py @@ -220,7 +220,6 @@ def megatron_forward_backward( defer_fp32_logits: Optional[bool] = False, global_valid_seqs: Optional[torch.Tensor] = None, global_valid_toks: Optional[torch.Tensor] = None, - do_not_average_loss: bool = False, straggler_timer: Optional[StragglerDetector] = None, ) -> Any: """Execute forward and backward passes using Megatron's utilities. @@ -241,7 +240,6 @@ def megatron_forward_backward( defer_fp32_logits: Whether to skip the conversion of logits to fp32 global_valid_seqs: Global valid sequence count for loss normalization global_valid_toks: Global valid token count for loss normalization - do_not_average_loss: If True, do not average loss across microbatches straggler_timer: Straggler detector for profiling the forward pass Returns: @@ -266,7 +264,6 @@ def megatron_forward_backward( micro_batch_size=mbs, decoder_seq_length=seq_length, forward_only=forward_only, - do_not_average_loss=do_not_average_loss, ) @@ -275,10 +272,12 @@ def __init__( self, loss_fn: LossFunction, cfg: PolicyConfig, + num_microbatches: int = 1, cp_normalize: bool = True, ): self.loss_fn = loss_fn self.cfg = cfg + self.num_microbatches = num_microbatches self.cp_normalize = cp_normalize def __call__( @@ -325,14 +324,26 @@ def __call__( if self.cp_normalize: cp_size = get_context_parallel_world_size() - orig_loss_fn_wrapped = loss_fn_wrapped + prev_loss_fn = loss_fn_wrapped def _div_by_cp_size(*args, **kwargs): - loss, metrics = orig_loss_fn_wrapped(*args, **kwargs) + loss, metrics = prev_loss_fn(*args, **kwargs) return loss / cp_size, metrics loss_fn_wrapped = _div_by_cp_size + # Counteract Megatron's default loss averaging in schedules.py, + # which applies (* cp_size / num_microbatches) to the loss. + cp_size = get_context_parallel_world_size() + num_microbatches = self.num_microbatches + loss_fn_before_mcore_scaling = loss_fn_wrapped + + def _counteract_mcore_loss_averaging(*args, **kwargs): + loss, metrics = loss_fn_before_mcore_scaling(*args, **kwargs) + return loss * num_microbatches / cp_size, metrics + + loss_fn_wrapped = _counteract_mcore_loss_averaging + return loss_fn_wrapped diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index d9a1c3d8a3..5f1483ed9a 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -318,6 +318,7 @@ def train( loss_post_processor = LossPostProcessor( loss_fn=loss_fn, cfg=self.cfg, + num_microbatches=num_microbatches, ) rerun_state_machine = get_rerun_state_machine() @@ -339,7 +340,6 @@ def train( defer_fp32_logits=self.defer_fp32_logits, global_valid_seqs=global_valid_seqs, global_valid_toks=global_valid_toks, - do_not_average_loss=True, straggler_timer=self.mcore_state.straggler_timer, ) diff --git a/tests/unit/models/megatron/test_train.py b/tests/unit/models/megatron/test_train.py index cf261c3d75..24dda67eec 100644 --- a/tests/unit/models/megatron/test_train.py +++ b/tests/unit/models/megatron/test_train.py @@ -719,13 +719,15 @@ def test_loss_post_processor_no_packing( def test_loss_post_processor_with_cp_normalize( self, mock_cp_size, mock_cp_grp, mock_tp_grp, mock_tp_rank ): - """Test LossPostProcessor with CP normalization.""" + """Test LossPostProcessor with CP normalization and microbatch pre-scaling.""" from nemo_rl.models.megatron.train import LossPostProcessor mock_loss_fn = MagicMock(return_value=(torch.tensor(1.0), {})) cfg = {"sequence_packing": {"enabled": False}} - processor = LossPostProcessor(loss_fn=mock_loss_fn, cfg=cfg, cp_normalize=True) + processor = LossPostProcessor( + loss_fn=mock_loss_fn, cfg=cfg, num_microbatches=4, cp_normalize=True + ) # Set up mock return values for process groups mock_tp_grp.return_value = MagicMock() @@ -736,8 +738,8 @@ def test_loss_post_processor_with_cp_normalize( output_tensor = torch.randn(2, 10, 100) loss, _ = wrapped_fn(output_tensor) - # Loss should be divided by CP size (2) - assert torch.isclose(loss, torch.tensor(0.5)) + # Loss should be scaled by num_microbatches / (cp_size * cp_size) = 4 / (2 * 2) = 1.0 + assert torch.isclose(loss, torch.tensor(1.0)) @patch( "nemo_rl.models.megatron.train.get_tensor_model_parallel_rank", return_value=0