Skip to content

Commit c3429d1

Browse files
committed
add test
Signed-off-by: HuiyingLi <[email protected]>
1 parent fd3574c commit c3429d1

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

tests/unit_tests/models/qwen3_vl_moe/test_qwen3_vl_moe_state_dict_adapter.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,47 @@ def test_aggregates_dtensor_path_uses_split_helper(self, adapter, monkeypatch):
181181
assert global_down.shape[0] == adapter.moe_config.n_routed_experts
182182
torch.testing.assert_close(global_down[2], local_slice)
183183

184+
def test_all_gather_path_populates_global_tensor(self, adapter, monkeypatch):
    """All-gather path: local expert shard plus a simulated peer rank fills the global tensor."""
    # This rank owns experts 0 and 1 (shape [2, 1, 1]); a fake peer rank supplies 2 and 3.
    shard = torch.tensor([[[1.0]], [[2.0]]], dtype=adapter.dtype)

    # Minimal EP device mesh stand-in: one "ep" dim whose group resolves on dim 0.
    mesh = Mock()
    mesh.mesh_dim_names = ["ep"]
    mesh.get_group = lambda dim: "ep_group" if dim == 0 else None

    # Pin this rank's expert range to [0, 2) and pretend a 2-rank initialized world.
    monkeypatch.setattr(
        "nemo_automodel.components.moe.state_dict_utils.get_expert_range_for_rank_from_mesh",
        lambda mesh, n_experts: (0, 2),
    )
    monkeypatch.setattr("torch.distributed.is_initialized", lambda: True)
    monkeypatch.setattr("torch.distributed.get_world_size", lambda group=None: 2)

    def _fake_gather(gathered, payload, group=None):
        # Slot 0 is this rank's payload (experts [0,1]); slot 1 mimics the peer with [2,3].
        gathered[0] = payload
        peer_weights = [
            torch.tensor([[3.0]], dtype=adapter.dtype),
            torch.tensor([[4.0]], dtype=adapter.dtype),
        ]
        gathered[1] = ([2, 3], peer_weights)

    monkeypatch.setattr("torch.distributed.all_gather_object", _fake_gather)

    out = adapter.to_hf(
        {"model.language_model.layers.0.mlp.experts.gate_and_up_projs": shard},
        device_mesh=mesh,
    )

    merged = out["model.language_model.layers.0.mlp.experts.gate_up_proj"]

    # Every expert slot must be populated: locals first, then the peer's contributions.
    assert merged.shape == (adapter.moe_config.n_routed_experts, 1, 1)
    for idx, value in enumerate([1.0, 2.0, 3.0, 4.0]):
        torch.testing.assert_close(merged[idx], torch.tensor([[value]], dtype=adapter.dtype))
224+
184225

185226
class TestFromHF:
186227
def test_detects_model_prefix(self, adapter):

0 commit comments

Comments
 (0)