 
 class VllmEplbAdaptor(EplbAdaptor):
 
-    def __init__(self, model, **args):
+    def __init__(self, model, mtp_instance=None, num_mtp_layers=0, **args):
         super().__init__(**args)
         self.model = model
         self.rank_id = dist.get_rank()
         self.world_size = dist.get_world_size()
         self.param_dict = dict(self.model.named_parameters())
+        self.mtp_instance = mtp_instance
+        self.num_mtp_layers = num_mtp_layers
         if self.model.config.model_type == "qwen3_moe":
             self.num_dense_layers = 0
             self.global_expert_num = self.model.config.num_experts
         else:
             self.num_dense_layers = self.model.config.first_k_dense_replace
             self.global_expert_num = self.model.config.n_routed_experts
-        self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
+        self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers  # MTP not included
         self.init_redundancy_expert = get_ascend_config(
         ).init_redundancy_expert
 
@@ -64,6 +66,18 @@ def __init__(self, model, **args):
         else:
             self.expert_weight_names = ["w13_weight", "w2_weight"]
 
+        if self.mtp_instance is not None:
+            if any("w13_weight_offset" in name
+                   for name, _ in self.mtp_instance.named_parameters()):
+                self.mtp_expert_weight_names = [
+                    "w13_weight", "w2_weight", "w13_weight_scale",
+                    "w13_weight_offset", "w2_weight_scale", "w2_weight_offset"
+                ]
+            else:
+                self.mtp_expert_weight_names = ["w13_weight", "w2_weight"]
+        else:
+            self.mtp_expert_weight_names = []
+
         self.expert_map_per_layer = dict(
         )  # reference to expert map on device for expert map update
         self.expert_map_per_layer_cpu = dict(
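
Aside: the `w13_weight_offset` probe in the hunk above is a name-based capability check — the presence of that substring among the MTP parameters is taken to mean the checkpoint is quantized, so the per-expert scale/offset tensors must travel with the weights when experts move. A standalone sketch of the same idea (the helper name is hypothetical, not part of this diff):

def pick_expert_weight_names(param_names):
    # Quantized checkpoints carry per-expert scale/offset tensors
    # alongside the raw weights.
    if any("w13_weight_offset" in name for name in param_names):
        return ["w13_weight", "w2_weight", "w13_weight_scale",
                "w13_weight_offset", "w2_weight_scale", "w2_weight_offset"]
    # Unquantized checkpoints only need the raw weights.
    return ["w13_weight", "w2_weight"]

assert pick_expert_weight_names(["experts.w13_weight"]) == ["w13_weight", "w2_weight"]
assert len(pick_expert_weight_names(["experts.w13_weight_offset"])) == 6
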
@@ -72,6 +86,12 @@ def __init__(self, model, **args):
             self.expert_map_per_layer[self.num_dense_layers + layer_idx] = \
                 self.model.get_expert_map(self.num_dense_layers + layer_idx)
 
+        # Currently, MTP only supports one layer.
+        if self.mtp_instance is not None:
+            for mtp_layer_idx in range(self.num_mtp_layers):
+                self.expert_map_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx] = \
+                    self.mtp_instance.model.get_expert_map(self.num_dense_layers + self.num_moe_layers + mtp_layer_idx)
+
         # TODO: here we set the number of buffer tensors equal to the number of experts in each layer, which can be improved
         num_buffer_tensor = torch.where(
             self.expert_map_per_layer[self.num_dense_layers] != -1)[0].numel()
@@ -88,6 +108,11 @@ def __init__(self, model, **args):
             self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = \
                 self.model.get_log2phy_map(self.num_dense_layers + layer_idx)
 
+        if self.mtp_instance is not None:
+            for mtp_layer_idx in range(self.num_mtp_layers):
+                self.log2phy_map_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx] = \
+                    self.mtp_instance.model.get_log2phy_map(self.num_dense_layers + self.num_moe_layers + mtp_layer_idx)
+
         self.all_topk_ids = []
 
     def init_buffer_tensor(self, num_buffer_tensor):
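
Aside: the two hunks above extend `expert_map_per_layer` and `log2phy_map_per_layer` with keys that continue past the main model's layers, so MTP layers slot in directly after the last MoE layer. A minimal sketch of that indexing, using made-up DeepSeek-style layer counts (the numbers are illustrative, not from this diff):

# Hypothetical layer counts, for illustration only.
num_dense_layers = 3   # e.g. first_k_dense_replace
num_moe_layers = 58    # num_hidden_layers - num_dense_layers (MTP not included)
num_mtp_layers = 1     # MTP currently supports a single layer

# Main-model MoE layers occupy keys 3..60 in the per-layer dicts...
moe_keys = [num_dense_layers + i for i in range(num_moe_layers)]
# ...and MTP layers continue right after them, here at key 61.
mtp_keys = [num_dense_layers + num_moe_layers + i for i in range(num_mtp_layers)]
assert moe_keys[-1] + 1 == mtp_keys[0]
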
@@ -131,12 +156,46 @@ def init_expert_param_per_layer(self):
                         name][0].data[local_expert_id])
                 self.expert_param_per_layer[layer_idx].append(per_expert_param)
 
+        if self.mtp_instance is not None:
+            mtp_param_dict = dict(self.mtp_instance.named_parameters())
+            for mtp_layer_idx in range(self.num_mtp_layers):
+                self.expert_param_per_layer[self.num_dense_layers +
+                                            self.num_moe_layers +
+                                            mtp_layer_idx] = list()
+            for local_expert_id in range(num_local_expert):
+                for mtp_layer_idx in range(self.num_mtp_layers):
+                    self.expert_param_per_layer[
+                        self.num_dense_layers + self.num_moe_layers +
+                        mtp_layer_idx].append([
+                            mtp_param_dict["model.layers." +
+                                           str(self.num_dense_layers +
+                                               self.num_moe_layers +
+                                               mtp_layer_idx) +
+                                           ".mtp_block.mlp.experts." +
+                                           name].data[local_expert_id]
+                            for name in self.mtp_expert_weight_names
+                        ])
+
     def get_rank_expert_workload(self) -> torch.Tensor:
         self.moe_load = self.model.get_all_moe_loads()
+        if self.mtp_instance is not None:
+            self.moe_load = torch.cat([
+                self.moe_load,
+                self.mtp_instance.model.get_all_moe_loads().to(
+                    device=self.moe_load.device)
+            ], dim=0)
         return self.moe_load
 
     def get_init_expert_map(self, num_moe_layers):
         expert_map = self.model.get_all_expert_map(num_moe_layers)
+        if self.mtp_instance is not None:
+            expert_map = torch.cat([
+                expert_map,
+                self.mtp_instance.model.get_all_expert_map().to(
+                    device=expert_map.device)
+            ], dim=0)
         if dist.is_initialized():
             world_size = dist.get_world_size()
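
Aside: both `get_rank_expert_workload` and `get_init_expert_map` splice the MTP rows onto the main model's tensors along dim 0, so downstream EPLB code sees a single tensor covering all balanceable layers. A runnable sketch of the same pattern with made-up shapes (58 MoE layers, 1 MTP layer, 256 global experts):

import torch

main_load = torch.zeros(58, 256)  # stands in for model.get_all_moe_loads()
mtp_load = torch.ones(1, 256)     # stands in for mtp_instance.model.get_all_moe_loads()

# MTP rows are appended after the main-model rows, matching the layer-key
# convention established in __init__.
moe_load = torch.cat([main_load, mtp_load.to(device=main_load.device)], dim=0)
assert moe_load.shape == (59, 256)
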
@@ -288,7 +347,9 @@ def determine_expert_map_all(self):
         local_num_experts = self.global_expert_num // self.world_size
 
         expert_map_all = torch.full(
-            (self.num_moe_layers, self.world_size, self.global_expert_num),
+            (self.num_moe_layers if self.mtp_instance is None else
+             (self.num_moe_layers + self.num_mtp_layers), self.world_size,
+             self.global_expert_num),
             -1,
             dtype=torch.int32)
@@ -311,6 +372,7 @@ def determine_expert_map_all(self):
 
         local_ids = torch.arange(local_count, dtype=torch.int32)
         expert_map_all[:, r, start:end] = local_ids.unsqueeze(0).expand(
-            self.num_moe_layers, -1)
+            self.num_moe_layers if self.mtp_instance is None else
+            (self.num_moe_layers + self.num_mtp_layers), -1)
 
-        return expert_map_all
+        return expert_map_all
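
Aside: with MTP enabled, `determine_expert_map_all` simply builds one extra row per MTP layer; every layer receives the same contiguous block partition of experts across ranks. A toy reproduction with made-up sizes (2 layers, i.e. 1 MoE + 1 MTP; 4 ranks; 8 global experts), assuming the even-division case:

import torch

num_layers, world_size, global_expert_num = 2, 4, 8
local_num_experts = global_expert_num // world_size  # 2 experts per rank

expert_map_all = torch.full((num_layers, world_size, global_expert_num),
                            -1, dtype=torch.int32)
for r in range(world_size):
    start, end = r * local_num_experts, (r + 1) * local_num_experts
    local_ids = torch.arange(end - start, dtype=torch.int32)
    expert_map_all[:, r, start:end] = local_ids.unsqueeze(0).expand(num_layers, -1)

# Rank 1 owns global experts 2 and 3 (as local experts 0 and 1) in every
# layer, including the appended MTP layer.
assert expert_map_all[0, 1].tolist() == [-1, -1, 0, 1, -1, -1, -1, -1]
assert torch.equal(expert_map_all[0], expert_map_all[1])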