@@ -120,6 +120,109 @@ def test_respects_exclude_regex(self, adapter):
         assert "exclude.me" not in out
 
 
+    def test_aggregates_with_device_mesh_non_dtensor(self, adapter, monkeypatch):
+        local_experts = torch.tensor(
+            [
+                [[1.0, 2.0], [3.0, 4.0]],
+                [[5.0, 6.0], [7.0, 8.0]],
+            ],
+            dtype=adapter.dtype,
+        )  # shape: [2, 2, 2]
+
+        # Only experts 1 and 2 live on this rank
+        monkeypatch.setattr(
+            "nemo_automodel.components.moe.state_dict_utils.get_expert_range_for_rank_from_mesh",
+            lambda mesh, n_experts: (1, 3),
+        )
+        # No distributed init => skip all_gather branch
+        monkeypatch.setattr("torch.distributed.is_initialized", lambda: False)
+
+        device_mesh = Mock()
+        device_mesh.mesh_dim_names = ["ep"]
+
+        state_dict = {
+            "model.language_model.layers.0.mlp.experts.gate_and_up_projs": local_experts,
+        }
+
+        out = adapter.to_hf(state_dict, device_mesh=device_mesh)
+        gate_key = "model.language_model.layers.0.mlp.experts.gate_up_proj"
+        global_gate = out[gate_key]
+
+        assert global_gate.shape == (adapter.moe_config.n_routed_experts, 2, 2)
+        # Experts 1 and 2 should be populated from local_experts; others remain zero
+        torch.testing.assert_close(global_gate[1:3], local_experts)
+        assert torch.all(global_gate[0] == 0)
+        assert torch.all(global_gate[3] == 0)
+
+    def test_aggregates_dtensor_path_uses_split_helper(self, adapter, monkeypatch):
+        local_slice = torch.tensor([[9.0, 10.0]], dtype=adapter.dtype)  # shape: [1, 2]
+
+        monkeypatch.setattr(
+            "nemo_automodel.components.moe.state_dict_utils.is_dtensor", lambda tensor: True
+        )
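+        # The stubbed split helper returns a single local slice that belongs to global expert index 2.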
+        monkeypatch.setattr(
+            "nemo_automodel.components.moe.state_dict_utils.split_experts_weights_dtensor_aware",
+            lambda weight, n_experts: ([local_slice], [2]),
+        )
+        monkeypatch.setattr("torch.distributed.is_initialized", lambda: False)
+
+        device_mesh = Mock()
+        device_mesh.mesh_dim_names = ["ep"]
+
+        state_dict = {
+            "model.language_model.layers.0.mlp.experts.down_projs": torch.empty(1, 1, 2),
+        }
+
+        out = adapter.to_hf(state_dict, device_mesh=device_mesh)
+        down_key = "model.language_model.layers.0.mlp.experts.down_proj"
+        global_down = out[down_key]
+
+        assert global_down.shape[0] == adapter.moe_config.n_routed_experts
+        torch.testing.assert_close(global_down[2], local_slice)
+
+    def test_all_gather_path_populates_global_tensor(self, adapter, monkeypatch):
+        # Local shard has experts 0 and 1; simulate another rank providing experts 2 and 3
+        local_experts = torch.tensor(
+            [
+                [[1.0]],
+                [[2.0]],
+            ],
+            dtype=adapter.dtype,
+        )  # shape: [2, 1, 1]
+
+        device_mesh = Mock()
+        device_mesh.mesh_dim_names = ["ep"]
+        device_mesh.get_group = lambda dim: "ep_group" if dim == 0 else None
+
+        monkeypatch.setattr(
+            "nemo_automodel.components.moe.state_dict_utils.get_expert_range_for_rank_from_mesh",
+            lambda mesh, n_experts: (0, 2),
+        )
+        monkeypatch.setattr("torch.distributed.is_initialized", lambda: True)
+        monkeypatch.setattr("torch.distributed.get_world_size", lambda group=None: 2)
+
+        def fake_all_gather_object(gathered, payload, group=None):
+            # payload from this rank for experts [0, 1]; simulate other rank with [2, 3]
+            gathered[0] = payload
+            other_weights = [torch.tensor([[3.0]], dtype=adapter.dtype), torch.tensor([[4.0]], dtype=adapter.dtype)]
+            gathered[1] = ([2, 3], other_weights)
+
+        monkeypatch.setattr("torch.distributed.all_gather_object", fake_all_gather_object)
+
+        state_dict = {"model.language_model.layers.0.mlp.experts.gate_and_up_projs": local_experts}
+        out = adapter.to_hf(state_dict, device_mesh=device_mesh)
+
+        gate_key = "model.language_model.layers.0.mlp.experts.gate_up_proj"
+        global_gate = out[gate_key]
+
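+        # Experts 0 and 1 come from this rank's shard; experts 2 and 3 from the simulated peer rank.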
+        assert global_gate.shape == (adapter.moe_config.n_routed_experts, 1, 1)
+        torch.testing.assert_close(global_gate[0], torch.tensor([[1.0]], dtype=adapter.dtype))
+        torch.testing.assert_close(global_gate[1], torch.tensor([[2.0]], dtype=adapter.dtype))
+        torch.testing.assert_close(global_gate[2], torch.tensor([[3.0]], dtype=adapter.dtype))
+        torch.testing.assert_close(global_gate[3], torch.tensor([[4.0]], dtype=adapter.dtype))
+
+
 class TestFromHF:
     def test_detects_model_prefix(self, adapter):
         hf_state = {
@@ -173,6 +276,9 @@ def __init__(self, data):
             def to_local(self):
                 return self._data
 
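+            # Support item access so the fake DTensor can be indexed like a real tensor.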
+            def __getitem__(self, idx):
+                return self._data[idx]
+
         captured = {"locals": []}
 
         monkeypatch.setattr(