@@ -181,6 +181,47 @@ def test_aggregates_dtensor_path_uses_split_helper(self, adapter, monkeypatch):
181181 assert global_down .shape [0 ] == adapter .moe_config .n_routed_experts
182182 torch .testing .assert_close (global_down [2 ], local_slice )
183183
184+ def test_all_gather_path_populates_global_tensor (self , adapter , monkeypatch ):
185+ # Local shard has experts 0 and 1; simulate another rank providing experts 2 and 3
186+ local_experts = torch .tensor (
187+ [
188+ [[1.0 ]],
189+ [[2.0 ]],
190+ ],
191+ dtype = adapter .dtype ,
192+ ) # shape: [2, 1, 1]
193+
194+ device_mesh = Mock ()
195+ device_mesh .mesh_dim_names = ["ep" ]
196+ device_mesh .get_group = lambda dim : "ep_group" if dim == 0 else None
197+
198+ monkeypatch .setattr (
199+ "nemo_automodel.components.moe.state_dict_utils.get_expert_range_for_rank_from_mesh" ,
200+ lambda mesh , n_experts : (0 , 2 ),
201+ )
202+ monkeypatch .setattr ("torch.distributed.is_initialized" , lambda : True )
203+ monkeypatch .setattr ("torch.distributed.get_world_size" , lambda group = None : 2 )
204+
205+ def fake_all_gather_object (gathered , payload , group = None ):
206+ # payload from this rank for experts [0,1]; simulate other rank with [2,3]
207+ gathered [0 ] = payload
208+ other_weights = [torch .tensor ([[3.0 ]], dtype = adapter .dtype ), torch .tensor ([[4.0 ]], dtype = adapter .dtype )]
209+ gathered [1 ] = ([2 , 3 ], other_weights )
210+
211+ monkeypatch .setattr ("torch.distributed.all_gather_object" , fake_all_gather_object )
212+
213+ state_dict = {"model.language_model.layers.0.mlp.experts.gate_and_up_projs" : local_experts }
214+ out = adapter .to_hf (state_dict , device_mesh = device_mesh )
215+
216+ gate_key = "model.language_model.layers.0.mlp.experts.gate_up_proj"
217+ global_gate = out [gate_key ]
218+
219+ assert global_gate .shape == (adapter .moe_config .n_routed_experts , 1 , 1 )
220+ torch .testing .assert_close (global_gate [0 ], torch .tensor ([[1.0 ]], dtype = adapter .dtype ))
221+ torch .testing .assert_close (global_gate [1 ], torch .tensor ([[2.0 ]], dtype = adapter .dtype ))
222+ torch .testing .assert_close (global_gate [2 ], torch .tensor ([[3.0 ]], dtype = adapter .dtype ))
223+ torch .testing .assert_close (global_gate [3 ], torch .tensor ([[4.0 ]], dtype = adapter .dtype ))
224+
184225
185226class TestFromHF :
186227 def test_detects_model_prefix (self , adapter ):
0 commit comments