@@ -280,7 +280,14 @@ def _export_tensor_to_file(self, expert_maps, expert_map_record_path: str):
280280 json .dump (record , f , indent = 4 )
281281
def do_update_expert_map(self, layer_id, updated_expert_map):
    """Overwrite the cached expert maps for one layer with a new mapping.

    The stored device-side tensor may be wider than the incoming map, so
    the update is right-padded with -1 (the "no expert" marker) before
    the in-place copy.

    NOTE(review): the CPU mirror below is copied *unpadded*; this assumes
    ``expert_map_per_layer_cpu[layer_id]`` matches the incoming map's
    length — confirm against the allocation site.
    """
    target = self.expert_map_per_layer[layer_id]
    tail = target.shape[0] - updated_expert_map.shape[0]
    padded = torch.nn.functional.pad(updated_expert_map,
                                     pad=(0, tail),
                                     mode='constant',
                                     value=-1)
    target.copy_(padded)
    self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map)
285292
286293 def do_update_expert_weight (self , layer_id , local_expert_to_replace ,
@@ -293,7 +300,14 @@ def do_update_expert_weight(self, layer_id, local_expert_to_replace,
293300
def do_update_log2phy_map(self, layer_id, updated_log2phy_map):
    """Refresh the logical-to-physical expert map for one layer, if present.

    Layers whose slot holds ``None`` are skipped.  Because the stored
    tensor can be wider than the incoming map, the update is right-padded
    with -1 before being copied in place.
    """
    current = self.log2phy_map_per_layer[layer_id]
    if current is None:
        return
    tail = current.shape[0] - updated_log2phy_map.shape[0]
    current.copy_(
        torch.nn.functional.pad(updated_log2phy_map,
                                pad=(0, tail),
                                mode='constant',
                                value=-1))
297311
298312 def global2local (self , placement : torch .Tensor ,
299313 E_local : int ) -> torch .Tensor :
0 commit comments