@@ -1159,7 +1159,7 @@ def _get_name_mappings(cls, config: DeepseekV2Config) -> list[StateDictNameMappi
11591159 ["norm.weight" ],
11601160 ]
11611161 # the trailing `config.num_nextn_predict_layers` layers hold the MTP (eagle) parameters used for inference
1162- for layer_index in range (config .num_hidden_layers + 1 ):
1162+ for layer_index in range (config .num_hidden_layers + config . num_nextn_predict_layers ):
11631163 layer_mappings = [
11641164 [f"layers.{ layer_index } .self_attn.q_proj.weight" , None , "transpose" ],
11651165 [f"layers.{ layer_index } .self_attn.q_a_proj.weight" , None , "transpose" ],
@@ -1192,7 +1192,7 @@ def _get_name_mappings(cls, config: DeepseekV2Config) -> list[StateDictNameMappi
11921192 model_mappings .append ([f"layers.{ layer_index } .mlp.shared_experts.down_proj.weight" , None , "transpose" ])
11931193
11941194 # MTP (eagle) parameters for inference
1195- if layer_index = = config .num_hidden_layers :
1195+ if layer_index > = config .num_hidden_layers :
11961196 model_mappings .append ([f"layers.{ layer_index } .embed_tokens.weight" ])
11971197 model_mappings .append ([f"layers.{ layer_index } .enorm.weight" ])
11981198 model_mappings .append ([f"layers.{ layer_index } .hnorm.weight" ])
@@ -1270,7 +1270,10 @@ def get_tensor_parallel_split_mappings(num_layers):
12701270 base_actions ["layers.0.shared_head.head.weight" ] = partial (fn , is_column = True )
12711271 for key , action in base_actions .items ():
12721272 if "layers.0." in key :
1273- final_actions [key .replace ("layers.0." , f"layers.{ config .num_hidden_layers } ." )] = action
1273+ for i in range (
1274+ config .num_hidden_layers , config .num_hidden_layers + config .num_nextn_predict_layers
1275+ ):
1276+ final_actions [key .replace ("layers.0." , f"layers.{ i } ." )] = action
12741277 else :
12751278 final_actions [key ] = action
12761279
0 commit comments