@@ -208,11 +208,16 @@ def __init__(
208208 self .embed_size_per_head = self .config .head_dim
209209 elif self .config .model_type == "gpt_bigcode" :
210210 self .embed_size_per_head = self .config .hidden_size // self .config .num_attention_heads * 2
211+ elif self .config .model_type == "deepseek_v3" :
212+ # For deepseek_v3, keys and values have different head dimensions
213+ self .qk_head_dim = self .config .qk_rope_head_dim + self .config .qk_nope_head_dim
214+ self .v_head_dim = self .config .v_head_dim
211215 else :
212216 self .embed_size_per_head = self .config .hidden_size // self .config .num_attention_heads
213217
214218 if self .config .model_type in {
215219 "arcee" ,
220+ "deepseek_v3" ,
216221 "cohere" ,
217222 "gemma" ,
218223 "helium" ,
@@ -345,6 +350,10 @@ def forward(
345350 v_shape = (batch_size * self .num_key_value_heads , 0 , self .embed_size_per_head )
346351 elif self .config .model_type == "gpt_bigcode" and self .config .multi_query :
347352 k_shape = v_shape = (batch_size , 0 , self .embed_size_per_head )
353+ elif self .config .model_type == "deepseek_v3" :
354+ # For deepseek_v3, keys and values have different head dimensions
355+ k_shape = (batch_size , self .num_key_value_heads , 0 , self .qk_head_dim )
356+ v_shape = (batch_size , self .num_key_value_heads , 0 , self .v_head_dim )
348357 else :
349358 k_shape = v_shape = (batch_size , self .num_key_value_heads , 0 , self .embed_size_per_head )
350359 k_tensor = torch .zeros (k_shape , dtype = self .dtype , device = self .device )
@@ -375,6 +384,10 @@ def forward(
375384 elif self .config .model_type == "gpt_bigcode" and self .config .multi_query :
376385 embed_size_per_head = past_key_values [0 ].shape [- 1 ]
377386 k_shape = v_shape = (batch_size , pkv_seq_len + seq_len , embed_size_per_head )
387+ elif self .config .model_type == "deepseek_v3" :
388+ # For deepseek_v3, keys and values have different head dimensions
389+ k_shape = (batch_size , self .num_key_value_heads , pkv_seq_len + seq_len , self .qk_head_dim )
390+ v_shape = (batch_size , self .num_key_value_heads , pkv_seq_len + seq_len , self .v_head_dim )
378391 else :
379392 embed_size_per_head = past_key_values [0 ].shape [- 1 ]
380393 k_shape = v_shape = (batch_size , self .num_key_value_heads , pkv_seq_len + seq_len , embed_size_per_head )
0 commit comments