
Commit 45b907c (parent d763771)

add test

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

1 file changed: +260 −0

@@ -0,0 +1,260 @@
#!/usr/bin/env python3
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from types import SimpleNamespace
from unittest.mock import MagicMock

import pytest
import torch

from nemo_automodel.components.distributed import megatron_fsdp as mfsdp


class _FakeModel:
    """Tiny stand-in with `.to()` chaining and optional checkpointing support."""

    def __init__(self, *, supports_gradient_checkpointing: bool):
        self.to_calls = []
        self.gradient_checkpointing_enabled = False
        if supports_gradient_checkpointing:
            self.gradient_checkpointing_enable = MagicMock(side_effect=self._enable_gc)

    def _enable_gc(self):
        self.gradient_checkpointing_enabled = True

    def to(self, *args, **kwargs):
        self.to_calls.append((args, kwargs))
        return self


def test_setup_distributed_raises_when_dist_not_available(monkeypatch):
    fake_dist = SimpleNamespace(is_available=lambda: False)
    monkeypatch.setattr(mfsdp, "dist", fake_dist, raising=True)

    with pytest.raises(RuntimeError, match="torch.distributed not available"):
        mfsdp.MegatronFSDPManager(world_size=1, backend="gloo")


def test_setup_distributed_raises_when_dist_not_initialized(monkeypatch):
    fake_dist = SimpleNamespace(is_available=lambda: True, is_initialized=lambda: False)
    monkeypatch.setattr(mfsdp, "dist", fake_dist, raising=True)

    with pytest.raises(RuntimeError, match="expected torch.distributed to be initialized"):
        mfsdp.MegatronFSDPManager(world_size=1, backend="gloo")


def test_setup_distributed_defaults_tp_cp_to_one_and_uses_cpu_mesh_when_backend_not_nccl(monkeypatch):
    fake_dist = SimpleNamespace(is_available=lambda: True, is_initialized=lambda: True)
    monkeypatch.setattr(mfsdp, "dist", fake_dist, raising=True)

    mesh = MagicMock()
    init_device_mesh_mock = MagicMock(return_value=mesh)
    monkeypatch.setattr(mfsdp, "init_device_mesh", init_device_mesh_mock, raising=True)

    mgr = mfsdp.MegatronFSDPManager(tp_size=0, cp_size=0, dp_size=None, world_size=4, backend="gloo")

    assert mgr.tp_size == 1
    assert mgr.cp_size == 1
    assert mgr.dp_size == 4
    assert mgr.device_mesh is mesh

    init_device_mesh_mock.assert_called_once()
    call_kwargs = init_device_mesh_mock.call_args.kwargs
    assert call_kwargs["device_type"] == "cpu"
    assert call_kwargs["mesh_shape"] == (4, 1, 1)
    assert call_kwargs["mesh_dim_names"] == ("dp", "cp", "tp")


def test_setup_distributed_infers_dp_size_and_flattens_dp_cp_when_cp_gt_one(monkeypatch):
    fake_dist = SimpleNamespace(is_available=lambda: True, is_initialized=lambda: True)
    monkeypatch.setattr(mfsdp, "dist", fake_dist, raising=True)

    mesh = MagicMock()
    tp_mesh = MagicMock()
    tp_mesh.size.return_value = 2
    dp_cp_mesh = MagicMock()

    mesh.__getitem__.side_effect = lambda key: {
        "tp": tp_mesh,
        ("dp", "cp"): dp_cp_mesh,
    }[key]

    init_device_mesh_mock = MagicMock(return_value=mesh)
    monkeypatch.setattr(mfsdp, "init_device_mesh", init_device_mesh_mock, raising=True)

    mgr = mfsdp.MegatronFSDPManager(dp_size=None, tp_size=2, cp_size=2, world_size=8, backend="nccl")

    # inferred dp_size so that dp * cp * tp == world_size
    assert mgr.dp_size == 2
    assert mgr.device_mesh is mesh

    # backend="nccl" selects cuda mesh
    init_device_mesh_mock.assert_called_once()
    call_kwargs = init_device_mesh_mock.call_args.kwargs
    assert call_kwargs["device_type"] == "cuda"
    assert call_kwargs["mesh_shape"] == (2, 2, 2)

    # cp_size > 1 triggers dp+cp flattening
    dp_cp_mesh._flatten.assert_called_once_with(mesh_dim_name="dp_cp")


def test_setup_distributed_raises_when_world_size_not_divisible_by_tp_times_cp(monkeypatch):
    fake_dist = SimpleNamespace(is_available=lambda: True, is_initialized=lambda: True)
    monkeypatch.setattr(mfsdp, "dist", fake_dist, raising=True)

    with pytest.raises(ValueError, match="must be divisible by \\(tp_size \\* cp_size\\)"):
        mfsdp.MegatronFSDPManager(dp_size=None, tp_size=3, cp_size=2, world_size=8, backend="gloo")


def test_parallelize_world_size_one_moves_to_cuda_bf16_and_enables_checkpointing_when_supported(monkeypatch):
    fake_dist = SimpleNamespace(
        is_available=lambda: True,
        is_initialized=lambda: True,
        get_world_size=lambda: 1,
    )
    monkeypatch.setattr(mfsdp, "dist", fake_dist, raising=True)
    monkeypatch.setattr(mfsdp, "init_device_mesh", MagicMock(return_value=MagicMock()), raising=True)

    mgr = mfsdp.MegatronFSDPManager(world_size=1, backend="gloo", activation_checkpointing=True)
    model = _FakeModel(supports_gradient_checkpointing=True)
    optimizer = MagicMock()

    out_model, out_opt = mgr.parallelize(model, optimizer=optimizer)
    assert out_model is model
    assert out_opt is optimizer

    # `.to("cuda").to(torch.bfloat16)` chain should be attempted even in CPU-only tests
    assert [args for (args, _kwargs) in model.to_calls] == [("cuda",), (torch.bfloat16,)]
    model.gradient_checkpointing_enable.assert_called_once_with()
    assert model.gradient_checkpointing_enabled is True


def test_parallelize_world_size_one_logs_error_when_checkpointing_not_supported(monkeypatch, caplog):
    fake_dist = SimpleNamespace(
        is_available=lambda: True,
        is_initialized=lambda: True,
        get_world_size=lambda: 1,
    )
    monkeypatch.setattr(mfsdp, "dist", fake_dist, raising=True)
    monkeypatch.setattr(mfsdp, "init_device_mesh", MagicMock(return_value=MagicMock()), raising=True)

    mgr = mfsdp.MegatronFSDPManager(world_size=1, backend="gloo", activation_checkpointing=True)
    model = _FakeModel(supports_gradient_checkpointing=False)

    caplog.set_level(logging.ERROR)
    mgr.parallelize(model, optimizer=None)
    assert "Model does not support gradient checkpointing. Skipping." in caplog.text


def test_parallelize_world_size_gt_one_selects_tp_plan_passes_dims_and_warns_on_nonzero3(monkeypatch, capsys, caplog):
    fake_dist = SimpleNamespace(
        is_available=lambda: True,
        is_initialized=lambda: True,
        get_world_size=lambda: 8,
    )
    monkeypatch.setattr(mfsdp, "dist", fake_dist, raising=True)

    # Device mesh used by manager.parallelize
    mesh = MagicMock()
    mesh.get_rank.return_value = 0
    tp_mesh = MagicMock()
    tp_mesh.size.return_value = 2
    dp_cp_mesh = MagicMock()
    mesh.__getitem__.side_effect = lambda key: {
        "tp": tp_mesh,
        ("dp", "cp"): dp_cp_mesh,
    }[key]
    monkeypatch.setattr(mfsdp, "init_device_mesh", MagicMock(return_value=mesh), raising=True)

    # Plan selection and strategy call should be delegated
    tp_plan = {"some.layer": object()}
    get_plan_mock = MagicMock(return_value=tp_plan)
    strat_mock = MagicMock(return_value=("parallel_model", "parallel_opt"))
    monkeypatch.setattr(mfsdp, "_get_parallel_plan", get_plan_mock, raising=True)
    monkeypatch.setattr(mfsdp, "megatron_fsdp_strategy_parallelize", strat_mock, raising=True)

    mgr = mfsdp.MegatronFSDPManager(
        dp_size=None,
        tp_size=2,
        cp_size=2,
        world_size=8,
        backend="gloo",
        activation_checkpointing=True,  # should log error but continue
        zero_dp_strategy=2,  # triggers warning print on rank 0
    )

    caplog.set_level(logging.ERROR)
    out_model, out_opt = mgr.parallelize(model=object(), optimizer="opt")
    assert (out_model, out_opt) == ("parallel_model", "parallel_opt")

    # Activation checkpointing is not supported here; should emit an error log.
    assert "Activation checkpointing is not yet supported with MegatronFSDP. Skipping." in caplog.text

    # zero_dp_strategy warning printed only on rank 0
    assert "Warning: MegatronFSDP zero_dp_strategy is not 3" in capsys.readouterr().out

    # TP plan should be selected when tp mesh size > 1
    get_plan_mock.assert_called_once()
    plan_args, plan_kwargs = get_plan_mock.call_args
    assert plan_args[0] is not None  # model object
    assert plan_kwargs["sequence_parallel"] is False
    assert plan_kwargs["tp_shard_plan"] is None
    assert plan_kwargs["use_hf_tp_plan"] is mgr.use_hf_tp_plan

    # Strategy should receive computed mesh dim names
    strat_mock.assert_called_once()
    strat_kwargs = strat_mock.call_args.kwargs
    assert strat_kwargs["device_mesh"] is mesh
    assert strat_kwargs["tp_shard_plan"] == tp_plan
    assert strat_kwargs["dp_shard_dim"] == "dp_cp"
    assert strat_kwargs["tp_dim"] == "tp"


def test_parallelize_world_size_gt_one_skips_tp_plan_when_tp_size_is_one(monkeypatch, capsys):
    fake_dist = SimpleNamespace(
        is_available=lambda: True,
        is_initialized=lambda: True,
        get_world_size=lambda: 2,
    )
    monkeypatch.setattr(mfsdp, "dist", fake_dist, raising=True)

    mesh = MagicMock()
    mesh.get_rank.return_value = 0
    tp_mesh = MagicMock()
    tp_mesh.size.return_value = 1
    mesh.__getitem__.side_effect = lambda key: {"tp": tp_mesh}[key]
    monkeypatch.setattr(mfsdp, "init_device_mesh", MagicMock(return_value=mesh), raising=True)

    get_plan_mock = MagicMock()
    strat_mock = MagicMock(return_value=("m", "o"))
    monkeypatch.setattr(mfsdp, "_get_parallel_plan", get_plan_mock, raising=True)
    monkeypatch.setattr(mfsdp, "megatron_fsdp_strategy_parallelize", strat_mock, raising=True)

    mgr = mfsdp.MegatronFSDPManager(dp_size=2, tp_size=1, cp_size=1, world_size=2, backend="gloo")
    out_model, out_opt = mgr.parallelize(model=object(), optimizer=object())
    assert (out_model, out_opt) == ("m", "o")

    # No TP -> do not ask for a TP plan
    get_plan_mock.assert_not_called()

    # dp_shard_dim should be "dp" when cp_size == 1
    strat_kwargs = strat_mock.call_args.kwargs
    assert strat_kwargs["tp_shard_plan"] is None
    assert strat_kwargs["dp_shard_dim"] == "dp"

    # zero_dp_strategy default is 3 -> no warning print
    assert capsys.readouterr().out == ""
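
For reference, a minimal usage sketch inferred from the assertions in these tests; the function name `shard_model` and the parameter values are illustrative only, and the authoritative constructor defaults and behavior live in nemo_automodel.components.distributed.megatron_fsdp:

# Hypothetical sketch, not part of this commit: it only exercises the surface
# the tests above assert on (constructor arguments and parallelize()).
import torch
import torch.distributed as dist

from nemo_automodel.components.distributed.megatron_fsdp import MegatronFSDPManager


def shard_model(model: torch.nn.Module, optimizer: torch.optim.Optimizer):
    # MegatronFSDPManager raises unless torch.distributed is available and
    # already initialized (first two tests above).
    assert dist.is_available() and dist.is_initialized()

    # world_size must be divisible by tp_size * cp_size; dp_size=None lets the
    # manager infer dp so that dp * cp * tp == world_size. backend="nccl"
    # builds a CUDA device mesh, any other backend a CPU mesh.
    manager = MegatronFSDPManager(
        dp_size=None,
        tp_size=2,
        cp_size=2,
        world_size=dist.get_world_size(),
        backend="nccl",
    )

    # With world_size > 1 this delegates to megatron_fsdp_strategy_parallelize;
    # with world_size == 1 the model is simply moved to CUDA and bf16.
    return manager.parallelize(model, optimizer=optimizer)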
