Commit 41a4e68

move metric aggregation to a function matching automodel
Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
1 parent a54b079 commit 41a4e68

File tree: 3 files changed (+165, −14 lines)

nemo_rl/models/megatron/train.py

Lines changed: 41 additions & 1 deletion
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections import defaultdict
 from contextlib import nullcontext
 from functools import partial
-from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
 
 import torch
 from megatron.core.models.gpt import GPTModel
@@ -543,3 +544,42 @@ def processor_fn_inner(output_tensor):
         }
 
     return processor_fn_inner
+
+
+def aggregate_training_statistics(
+    all_mb_metrics: List[Dict[str, Any]],
+    losses: List[float],
+    data_parallel_group: torch.distributed.ProcessGroup,
+) -> Tuple[Dict[str, List[Any]], torch.Tensor]:
+    """Aggregate training statistics across microbatches and data-parallel ranks.
+
+    Computes a global loss by all-reducing per-gradient-buffer losses across the
+    data-parallel group, then collects per-microbatch metrics into lists keyed by
+    metric name.
+
+    Args:
+        all_mb_metrics: List of metric dicts from each microbatch.
+        losses: List of per-gradient-buffer scalar losses on this rank.
+        data_parallel_group: The data-parallel process group for all-reduce.
+
+    Returns:
+        Tuple of:
+        - mb_metrics: Dict mapping metric names to lists of values across microbatches.
+        - global_loss: Tensor of losses summed across all data-parallel ranks.
+    """
+    # Compute global loss across all data-parallel ranks
+    with torch.no_grad():
+        global_loss = torch.tensor(losses, device="cuda")
+        torch.distributed.all_reduce(
+            global_loss,
+            op=torch.distributed.ReduceOp.SUM,
+            group=data_parallel_group,
+        )
+
+    # Aggregate metrics across all microbatches
+    mb_metrics: Dict[str, List[Any]] = defaultdict(list)
+    for m in all_mb_metrics:
+        for k, v in m.items():
+            mb_metrics[k].append(v)
+
+    return dict(mb_metrics), global_loss
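
For reference, the new helper can be exercised outside the trainer. The following is a minimal single-GPU sketch (hypothetical driver, not part of this commit) that sets up a one-rank NCCL process group, so the all-reduce is effectively a no-op:

# Hypothetical single-GPU sketch: with world_size == 1, the all-reduce
# returns the local losses unchanged, making the helper's behavior easy to see.
import os

import torch.distributed as dist

from nemo_rl.models.megatron.train import aggregate_training_statistics

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="nccl", rank=0, world_size=1)

mb_metrics, global_loss = aggregate_training_statistics(
    all_mb_metrics=[{"loss": 0.5}, {"loss": 0.3}],
    losses=[0.8],
    data_parallel_group=dist.group.WORLD,
)
print(mb_metrics)   # {'loss': [0.5, 0.3]}
print(global_loss)  # tensor([0.8000], device='cuda:0') when world_size == 1

dist.destroy_process_group()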

nemo_rl/models/policy/workers/megatron_policy_worker.py

Lines changed: 7 additions & 13 deletions
@@ -82,6 +82,7 @@
     LogprobsPostProcessor,
     LossPostProcessor,
     TopkLogitsPostProcessor,
+    aggregate_training_statistics,
     megatron_forward_backward,
 )
 from nemo_rl.models.policy import PolicyConfig
@@ -416,25 +417,18 @@ def train(
         self.scheduler.step(increment=gbs)
 
         # Aggregate metrics across all microbatches
-        mb_metrics = defaultdict(list)
-        for m in all_mb_metrics:
-            for k, v in m.items():
-                mb_metrics[k].append(v)
-
-        with torch.no_grad():
-            global_loss = torch.tensor(losses, device="cuda")
-            torch.distributed.all_reduce(
-                global_loss,
-                op=torch.distributed.ReduceOp.SUM,
-                group=parallel_state.get_data_parallel_group(),
-            )
+        mb_metrics, global_loss = aggregate_training_statistics(
+            all_mb_metrics=all_mb_metrics,
+            losses=losses,
+            data_parallel_group=parallel_state.get_data_parallel_group(),
+        )
 
         metrics = {
             "global_loss": global_loss.cpu(),
             "rank": torch.distributed.get_rank(),
             "gpu_name": torch.cuda.get_device_name(),
             "model_dtype": self.dtype,
-            "all_mb_metrics": dict(mb_metrics),
+            "all_mb_metrics": mb_metrics,
             "grad_norm": torch.tensor([grad_norm]),
         }
         # Collect MoE aux metrics averaged across microbatches
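
Since the returned `mb_metrics` maps each metric name to a list of per-microbatch values, a downstream consumer that wants scalars must reduce those lists itself. A minimal sketch (hypothetical helper, not from this commit) that averages them:

# Hypothetical consumer sketch: collapse each per-microbatch metric list
# to its mean, assuming the values are numeric.
from typing import Any, Dict, List

def mean_mb_metrics(all_mb_metrics: Dict[str, List[Any]]) -> Dict[str, float]:
    """Average each metric's values across microbatches."""
    return {k: sum(v) / len(v) for k, v in all_mb_metrics.items() if v}

# e.g. mean_mb_metrics({"loss": [0.5, 0.3]}) == {"loss": 0.4}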

tests/unit/models/megatron/test_train.py

Lines changed: 117 additions & 0 deletions
@@ -1123,3 +1123,120 @@ def fake_allgather(local_tensor, group, seq_dim):
     # Output should be unpacked: (batch_size=2, unpacked_seqlen=6, k=3)
     assert result["topk_logits"].shape == (2, unpacked_seqlen, k)
     assert result["topk_indices"].shape == (2, unpacked_seqlen, k)
+
+
+class TestAggregateTrainingStatistics:
+    """Tests for aggregate_training_statistics function."""
+
+    @patch("torch.distributed.all_reduce")
+    def test_aggregates_metrics_across_microbatches(self, mock_all_reduce):
+        """Test that per-microbatch metrics are collected into lists by key."""
+        from nemo_rl.models.megatron.train import aggregate_training_statistics
+
+        all_mb_metrics = [
+            {"loss": 0.5, "lr": 1e-4},
+            {"loss": 0.3, "lr": 1e-4},
+            {"loss": 0.2, "lr": 1e-4},
+        ]
+
+        mock_dp_group = MagicMock()
+
+        mb_metrics, _ = aggregate_training_statistics(
+            all_mb_metrics=all_mb_metrics,
+            losses=[1.0],
+            data_parallel_group=mock_dp_group,
+        )
+
+        assert mb_metrics["loss"] == [0.5, 0.3, 0.2]
+        assert mb_metrics["lr"] == [1e-4, 1e-4, 1e-4]
+        assert len(mb_metrics) == 2
+
+    @patch("torch.distributed.all_reduce")
+    def test_returns_plain_dict(self, mock_all_reduce):
+        """Test that the returned mb_metrics is a plain dict, not defaultdict."""
+        from nemo_rl.models.megatron.train import aggregate_training_statistics
+
+        mb_metrics, _ = aggregate_training_statistics(
+            all_mb_metrics=[{"loss": 0.5}],
+            losses=[1.0],
+            data_parallel_group=MagicMock(),
+        )
+
+        assert type(mb_metrics) is dict
+
+    @patch("torch.distributed.all_reduce")
+    def test_global_loss_tensor_from_losses(self, mock_all_reduce):
+        """Test that losses list is converted to a CUDA tensor for all-reduce."""
+        from nemo_rl.models.megatron.train import aggregate_training_statistics
+
+        mock_dp_group = MagicMock()
+
+        _, global_loss = aggregate_training_statistics(
+            all_mb_metrics=[],
+            losses=[0.5, 0.3, 0.2],
+            data_parallel_group=mock_dp_group,
+        )
+
+        # Verify all_reduce was called with correct args
+        mock_all_reduce.assert_called_once()
+        call_args = mock_all_reduce.call_args
+        assert call_args[1]["op"] == torch.distributed.ReduceOp.SUM
+        assert call_args[1]["group"] is mock_dp_group
+
+        # Verify tensor shape matches losses list
+        reduced_tensor = call_args[0][0]
+        assert reduced_tensor.shape == (3,)
+
+    @patch("torch.distributed.all_reduce")
+    def test_empty_metrics(self, mock_all_reduce):
+        """Test with empty microbatch metrics list."""
+        from nemo_rl.models.megatron.train import aggregate_training_statistics
+
+        mb_metrics, global_loss = aggregate_training_statistics(
+            all_mb_metrics=[],
+            losses=[1.0],
+            data_parallel_group=MagicMock(),
+        )
+
+        assert mb_metrics == {}
+        mock_all_reduce.assert_called_once()
+
+    @patch("torch.distributed.all_reduce")
+    def test_handles_heterogeneous_metric_keys(self, mock_all_reduce):
+        """Test that microbatches with different metric keys are handled correctly."""
+        from nemo_rl.models.megatron.train import aggregate_training_statistics
+
+        all_mb_metrics = [
+            {"loss": 0.5, "lr": 1e-4},
+            {"loss": 0.3, "global_valid_seqs": 8},
+        ]
+
+        mb_metrics, _ = aggregate_training_statistics(
+            all_mb_metrics=all_mb_metrics,
+            losses=[0.8],
+            data_parallel_group=MagicMock(),
+        )
+
+        assert mb_metrics["loss"] == [0.5, 0.3]
+        assert mb_metrics["lr"] == [1e-4]
+        assert mb_metrics["global_valid_seqs"] == [8]
+
+    @patch("torch.distributed.all_reduce")
+    def test_no_grad_context(self, mock_all_reduce):
+        """Test that all-reduce runs under torch.no_grad context."""
+        from nemo_rl.models.megatron.train import aggregate_training_statistics
+
+        grad_enabled_during_all_reduce = []
+
+        def capture_grad_state(*args, **kwargs):
+            grad_enabled_during_all_reduce.append(torch.is_grad_enabled())
+
+        mock_all_reduce.side_effect = capture_grad_state
+
+        aggregate_training_statistics(
+            all_mb_metrics=[],
+            losses=[1.0],
+            data_parallel_group=MagicMock(),
+        )
+
+        assert grad_enabled_during_all_reduce == [False]
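
Assuming pytest is the project's test runner (which the test style suggests), the new class can be run in isolation; a hypothetical invocation via pytest's Python API, equivalent to passing -k on the command line:

# Hypothetical: run only the new test class from the repo root.
import pytest

pytest.main([
    "tests/unit/models/megatron/test_train.py",
    "-k", "TestAggregateTrainingStatistics",
])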
