@@ -1123,3 +1123,120 @@ def fake_allgather(local_tensor, group, seq_dim):
     # Output should be unpacked: (batch_size=2, unpacked_seqlen=6, k=3)
     assert result["topk_logits"].shape == (2, unpacked_seqlen, k)
     assert result["topk_indices"].shape == (2, unpacked_seqlen, k)
+
+
+class TestAggregateTrainingStatistics:
+    """Tests for aggregate_training_statistics function."""
+
+    @patch("torch.distributed.all_reduce")
+    def test_aggregates_metrics_across_microbatches(self, mock_all_reduce):
+        """Test that per-microbatch metrics are collected into lists by key."""
+        from nemo_rl.models.megatron.train import aggregate_training_statistics
+
+        all_mb_metrics = [
+            {"loss": 0.5, "lr": 1e-4},
+            {"loss": 0.3, "lr": 1e-4},
+            {"loss": 0.2, "lr": 1e-4},
+        ]
+
+        mock_dp_group = MagicMock()
+
+        mb_metrics, _ = aggregate_training_statistics(
+            all_mb_metrics=all_mb_metrics,
+            losses=[1.0],
+            data_parallel_group=mock_dp_group,
+        )
+
+        assert mb_metrics["loss"] == [0.5, 0.3, 0.2]
+        assert mb_metrics["lr"] == [1e-4, 1e-4, 1e-4]
+        assert len(mb_metrics) == 2
+
+    @patch("torch.distributed.all_reduce")
+    def test_returns_plain_dict(self, mock_all_reduce):
+        """Test that the returned mb_metrics is a plain dict, not a defaultdict."""
+        from nemo_rl.models.megatron.train import aggregate_training_statistics
+
+        mb_metrics, _ = aggregate_training_statistics(
+            all_mb_metrics=[{"loss": 0.5}],
+            losses=[1.0],
+            data_parallel_group=MagicMock(),
+        )
+
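+        # defaultdict is a dict subclass, so an isinstance check would pass
+        # for either; comparing the exact type catches an un-converted
+        # defaultdict.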
+        assert type(mb_metrics) is dict
+
+    @patch("torch.distributed.all_reduce")
+    def test_global_loss_tensor_from_losses(self, mock_all_reduce):
+        """Test that the losses list is converted to a CUDA tensor for all-reduce."""
+        from nemo_rl.models.megatron.train import aggregate_training_statistics
+
+        mock_dp_group = MagicMock()
+
+        _, global_loss = aggregate_training_statistics(
+            all_mb_metrics=[],
+            losses=[0.5, 0.3, 0.2],
+            data_parallel_group=mock_dp_group,
+        )
+
+        # Verify all_reduce was called with the correct op and group
+        mock_all_reduce.assert_called_once()
+        call_args = mock_all_reduce.call_args
+        assert call_args[1]["op"] == torch.distributed.ReduceOp.SUM
+        assert call_args[1]["group"] is mock_dp_group
+
+        # Verify the tensor shape matches the losses list
+        reduced_tensor = call_args[0][0]
+        assert reduced_tensor.shape == (3,)
+
+    @patch("torch.distributed.all_reduce")
+    def test_empty_metrics(self, mock_all_reduce):
+        """Test with an empty microbatch metrics list."""
+        from nemo_rl.models.megatron.train import aggregate_training_statistics
+
+        mb_metrics, global_loss = aggregate_training_statistics(
+            all_mb_metrics=[],
+            losses=[1.0],
+            data_parallel_group=MagicMock(),
+        )
+
+        assert mb_metrics == {}
+        mock_all_reduce.assert_called_once()
+
+    @patch("torch.distributed.all_reduce")
+    def test_handles_heterogeneous_metric_keys(self, mock_all_reduce):
+        """Test that microbatches with different metric keys are handled correctly."""
+        from nemo_rl.models.megatron.train import aggregate_training_statistics
+
+        all_mb_metrics = [
+            {"loss": 0.5, "lr": 1e-4},
+            {"loss": 0.3, "global_valid_seqs": 8},
+        ]
+
+        mb_metrics, _ = aggregate_training_statistics(
+            all_mb_metrics=all_mb_metrics,
+            losses=[0.8],
+            data_parallel_group=MagicMock(),
+        )
+
+        assert mb_metrics["loss"] == [0.5, 0.3]
+        assert mb_metrics["lr"] == [1e-4]
+        assert mb_metrics["global_valid_seqs"] == [8]
+
+    @patch("torch.distributed.all_reduce")
+    def test_no_grad_context(self, mock_all_reduce):
+        """Test that the all-reduce runs under a torch.no_grad context."""
+        from nemo_rl.models.megatron.train import aggregate_training_statistics
+
+        grad_enabled_during_all_reduce = []
+
+        def capture_grad_state(*args, **kwargs):
+            grad_enabled_during_all_reduce.append(torch.is_grad_enabled())
+
+        mock_all_reduce.side_effect = capture_grad_state
+
+        aggregate_training_statistics(
+            all_mb_metrics=[],
+            losses=[1.0],
+            data_parallel_group=MagicMock(),
+        )
+
+        assert grad_enabled_during_all_reduce == [False]
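+
+
+# A minimal sketch of a function that would satisfy the tests above. This is
+# an assumption reconstructed from the assertions, not the actual
+# nemo_rl.models.megatron.train implementation:
+#
+#     from collections import defaultdict
+#
+#     def aggregate_training_statistics(all_mb_metrics, losses, data_parallel_group):
+#         # Collect each per-microbatch value into a list keyed by metric name.
+#         collected = defaultdict(list)
+#         for mb in all_mb_metrics:
+#             for key, value in mb.items():
+#                 collected[key].append(value)
+#         # Sum the per-microbatch losses across the data-parallel group;
+#         # no_grad keeps the all-reduce out of the autograd graph.
+#         with torch.no_grad():
+#             global_loss = torch.tensor(losses, device="cuda")
+#             torch.distributed.all_reduce(
+#                 global_loss,
+#                 op=torch.distributed.ReduceOp.SUM,
+#                 group=data_parallel_group,
+#             )
+#         # Convert back to a plain dict so missing keys raise KeyError.
+#         return dict(collected), global_loss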