@@ -283,19 +283,14 @@ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
283283
284284 self .hidden_size = config .hidden_size
285285 self .num_heads = config .num_attention_heads
286- self .head_dim = self . hidden_size // self . num_heads
286+ self .head_dim = config . head_dim
287287 self .num_key_value_heads = config .num_key_value_heads
288288 self .num_key_value_groups = self .num_heads // self .num_key_value_heads
289289 self .max_position_embeddings = config .max_position_embeddings
290290 self .rope_theta = config .rope_theta
291291 self .is_causal = True
292292 self .attention_dropout = config .attention_dropout
293293
294- if (self .head_dim * self .num_heads ) != self .hidden_size :
295- raise ValueError (
296- f"hidden_size must be divisible by num_heads (got `hidden_size`: { self .hidden_size } "
297- f" and `num_heads`: { self .num_heads } )."
298- )
299294 self .q_proj = nn .Linear (self .hidden_size , self .num_heads * self .head_dim , bias = False )
300295 self .k_proj = nn .Linear (self .hidden_size , self .num_key_value_heads * self .head_dim , bias = False )
301296 self .v_proj = nn .Linear (self .hidden_size , self .num_key_value_heads * self .head_dim , bias = False )
@@ -374,7 +369,7 @@ def forward(
374369 )
375370
376371 attn_output = attn_output .transpose (1 , 2 ).contiguous ()
377- attn_output = attn_output .reshape (bsz , q_len , self . hidden_size )
372+ attn_output = attn_output .reshape (bsz , q_len , - 1 )
378373
379374 attn_output = self .o_proj (attn_output )
380375
@@ -481,7 +476,7 @@ def forward(
481476 is_causal = self .is_causal ,
482477 )
483478
484- attn_output = attn_output .reshape (bsz , q_len , self . hidden_size ).contiguous ()
479+ attn_output = attn_output .reshape (bsz , q_len , - 1 ).contiguous ()
485480 attn_output = self .o_proj (attn_output )
486481
487482 if not output_attentions :
@@ -575,7 +570,7 @@ def forward(
575570 )
576571
577572 attn_output = attn_output .transpose (1 , 2 ).contiguous ()
578- attn_output = attn_output .view (bsz , q_len , self . hidden_size )
573+ attn_output = attn_output .view (bsz , q_len , - 1 )
579574
580575 attn_output = self .o_proj (attn_output )
581576
0 commit comments