@@ -856,6 +856,9 @@ def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False):

             self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, self.hidden_size, has_bias=config.attention_bias, input_is_parallel=True)

+            assert self.num_heads % config.tensor_parallel_degree == 0, f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
+            self.num_heads = self.num_heads // config.tensor_parallel_degree
+
         else:
             # for without tensor parallel
             if self.q_lora_rank is None:
@@ -1228,12 +1231,15 @@ def get_tensor_parallel_split_mappings(num_layers):
             # Column Linear
             base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
             base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True)
+            base_actions["layers.0.self_attn.q_b_proj.weight"] = partial(fn, is_column=True)
+
             # if we have enough num_key_value_heads to split, then split it.
             if config.num_key_value_heads % config.tensor_parallel_degree == 0:
                 base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
                 base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
                 base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True)
                 base_actions["layers.0.self_attn.v_proj.bias"] = partial(fn, is_column=True)
+                base_actions["layers.0.self_attn.kv_b_proj.weight"] = partial(fn, is_column=True)

             base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True)
             base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True)
@@ -1625,9 +1631,7 @@ def forward(self, hidden_states, tensor_parallel_output=None):
         if tensor_parallel_output is None:
             tensor_parallel_output = self.config.tensor_parallel_output

-        logits = parallel_matmul(
-            hidden_states, self.weight, transpose_y=False, tensor_parallel_output=tensor_parallel_output
-        )
+        logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
         return logits

@@ -1639,7 +1643,7 @@ def __init__(self, config: DeepseekV2Config):
         self.config = config
         self.deepseek_v2 = DeepseekV2Model(config)
         self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias_attr=False)
+        self.lm_head = DeepSeekV2LMHead(config)
         self.criterion = DeepSeekV2PretrainingCriterion(config)

     def get_input_embeddings(self):
0 commit comments