2 changes: 1 addition & 1 deletion .gitmodules
@@ -1,7 +1,7 @@
[submodule "3rdparty/Megatron-LM"]
path = 3rdparty/Megatron-LM-workspace/Megatron-LM
url = https://github.com/yaoyu-33/Megatron-LM.git
branch = main
branch = yifu/remove_do_not_average_loss
shallow = true
[submodule "3rdparty/Megatron-Bridge"]
path = 3rdparty/Megatron-Bridge-workspace/Megatron-Bridge
21 changes: 16 additions & 5 deletions nemo_rl/models/megatron/train.py
@@ -220,7 +220,6 @@ def megatron_forward_backward(
defer_fp32_logits: Optional[bool] = False,
global_valid_seqs: Optional[torch.Tensor] = None,
global_valid_toks: Optional[torch.Tensor] = None,
do_not_average_loss: bool = False,
straggler_timer: Optional[StragglerDetector] = None,
) -> Any:
"""Execute forward and backward passes using Megatron's utilities.
@@ -241,7 +240,6 @@ def megatron_forward_backward(
defer_fp32_logits: Whether to skip the conversion of logits to fp32
global_valid_seqs: Global valid sequence count for loss normalization
global_valid_toks: Global valid token count for loss normalization
do_not_average_loss: If True, do not average loss across microbatches
straggler_timer: Straggler detector for profiling the forward pass

Returns:
@@ -266,7 +264,6 @@
micro_batch_size=mbs,
decoder_seq_length=seq_length,
forward_only=forward_only,
do_not_average_loss=do_not_average_loss,
)


@@ -275,10 +272,12 @@ def __init__(
self,
loss_fn: LossFunction,
cfg: PolicyConfig,
num_microbatches: int = 1,
cp_normalize: bool = True,
):
self.loss_fn = loss_fn
self.cfg = cfg
self.num_microbatches = num_microbatches
self.cp_normalize = cp_normalize

def __call__(
@@ -325,14 +324,26 @@ def __call__(

if self.cp_normalize:
cp_size = get_context_parallel_world_size()
orig_loss_fn_wrapped = loss_fn_wrapped
prev_loss_fn = loss_fn_wrapped

def _div_by_cp_size(*args, **kwargs):
loss, metrics = orig_loss_fn_wrapped(*args, **kwargs)
loss, metrics = prev_loss_fn(*args, **kwargs)
return loss / cp_size, metrics

loss_fn_wrapped = _div_by_cp_size

# Counteract Megatron's default loss averaging in schedules.py,
# which applies (* cp_size / num_microbatches) to the loss.
cp_size = get_context_parallel_world_size()
num_microbatches = self.num_microbatches
loss_fn_before_mcore_scaling = loss_fn_wrapped

def _counteract_mcore_loss_averaging(*args, **kwargs):
loss, metrics = loss_fn_before_mcore_scaling(*args, **kwargs)
return loss * num_microbatches / cp_size, metrics

loss_fn_wrapped = _counteract_mcore_loss_averaging

return loss_fn_wrapped


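For intuition, below is a minimal standalone sketch (not part of the diff; cp_size, num_microbatches, and the base loss value are illustrative placeholders) of how the two wrappers above compose with the * cp_size / num_microbatches scaling that Megatron's schedules.py applies afterward: the microbatch averaging cancels out and only the CP normalization survives.

# Illustrative values only; hypothetical configuration, not taken from the PR.
cp_size = 2
num_microbatches = 4

def base_loss_fn(output_tensor):
    # Stand-in for the user loss function: returns a per-microbatch loss and metrics.
    return 1.0, {}

def div_by_cp_size(output_tensor):
    # Mirrors _div_by_cp_size: normalize the loss across context-parallel ranks.
    loss, metrics = base_loss_fn(output_tensor)
    return loss / cp_size, metrics

def counteract_mcore_loss_averaging(output_tensor):
    # Mirrors _counteract_mcore_loss_averaging: pre-scale so Megatron's
    # (* cp_size / num_microbatches) averaging cancels out.
    loss, metrics = div_by_cp_size(output_tensor)
    return loss * num_microbatches / cp_size, metrics

wrapped_loss, _ = counteract_mcore_loss_averaging(None)  # 1.0 / 2 * 4 / 2 = 1.0
after_mcore = wrapped_loss * cp_size / num_microbatches  # 1.0 * 2 / 4 = 0.5 = base_loss / cp_size
print(wrapped_loss, after_mcore)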
2 changes: 1 addition & 1 deletion nemo_rl/models/policy/workers/megatron_policy_worker.py
@@ -318,6 +318,7 @@ def train(
loss_post_processor = LossPostProcessor(
loss_fn=loss_fn,
cfg=self.cfg,
num_microbatches=num_microbatches,
)

rerun_state_machine = get_rerun_state_machine()
@@ -339,7 +340,6 @@
defer_fp32_logits=self.defer_fp32_logits,
global_valid_seqs=global_valid_seqs,
global_valid_toks=global_valid_toks,
do_not_average_loss=True,
straggler_timer=self.mcore_state.straggler_timer,
)

10 changes: 6 additions & 4 deletions tests/unit/models/megatron/test_train.py
@@ -719,13 +719,15 @@ def test_loss_post_processor_no_packing(
def test_loss_post_processor_with_cp_normalize(
self, mock_cp_size, mock_cp_grp, mock_tp_grp, mock_tp_rank
):
"""Test LossPostProcessor with CP normalization."""
"""Test LossPostProcessor with CP normalization and microbatch pre-scaling."""
from nemo_rl.models.megatron.train import LossPostProcessor

mock_loss_fn = MagicMock(return_value=(torch.tensor(1.0), {}))
cfg = {"sequence_packing": {"enabled": False}}

processor = LossPostProcessor(loss_fn=mock_loss_fn, cfg=cfg, cp_normalize=True)
processor = LossPostProcessor(
loss_fn=mock_loss_fn, cfg=cfg, num_microbatches=4, cp_normalize=True
)

# Set up mock return values for process groups
mock_tp_grp.return_value = MagicMock()
@@ -736,8 +738,8 @@ def test_loss_post_processor_with_cp_normalize(
output_tensor = torch.randn(2, 10, 100)
loss, _ = wrapped_fn(output_tensor)

# Loss should be divided by CP size (2)
assert torch.isclose(loss, torch.tensor(0.5))
# Loss should be scaled by num_microbatches / (cp_size * cp_size) = 4 / (2 * 2) = 1.0
assert torch.isclose(loss, torch.tensor(1.0))

@patch(
"nemo_rl.models.megatron.train.get_tensor_model_parallel_rank", return_value=0