Skip to content

Commit 5a2753d

Browse files
wavy-jung authored and BernardZach committed
manual head_dim for mixtral model (huggingface#34281)
1 parent 0dd38e9 commit 5a2753d

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

src/transformers/models/mixtral/configuration_mixtral.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,8 @@ class MixtralConfig(PretrainedConfig):
5353
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
5454
by meanpooling all the original heads within that group. For more details, check out [this
5555
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
56+
head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
57+
The attention head dimension.
5658
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
5759
The non-linear activation function (function or string) in the decoder.
5860
max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
@@ -116,6 +118,7 @@ def __init__(
116118
num_hidden_layers=32,
117119
num_attention_heads=32,
118120
num_key_value_heads=8,
121+
head_dim=None,
119122
hidden_act="silu",
120123
max_position_embeddings=4096 * 32,
121124
initializer_range=0.02,
@@ -154,6 +157,7 @@ def __init__(
154157
self.use_cache = use_cache
155158
self.rope_theta = rope_theta
156159
self.attention_dropout = attention_dropout
160+
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
157161

158162
self.num_experts_per_tok = num_experts_per_tok
159163
self.num_local_experts = num_local_experts

src/transformers/models/mixtral/modeling_mixtral.py

Lines changed: 4 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -283,19 +283,14 @@ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
283283

284284
self.hidden_size = config.hidden_size
285285
self.num_heads = config.num_attention_heads
286-
self.head_dim = self.hidden_size // self.num_heads
286+
self.head_dim = config.head_dim
287287
self.num_key_value_heads = config.num_key_value_heads
288288
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
289289
self.max_position_embeddings = config.max_position_embeddings
290290
self.rope_theta = config.rope_theta
291291
self.is_causal = True
292292
self.attention_dropout = config.attention_dropout
293293

294-
if (self.head_dim * self.num_heads) != self.hidden_size:
295-
raise ValueError(
296-
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
297-
f" and `num_heads`: {self.num_heads})."
298-
)
299294
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
300295
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
301296
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
@@ -374,7 +369,7 @@ def forward(
374369
)
375370

376371
attn_output = attn_output.transpose(1, 2).contiguous()
377-
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
372+
attn_output = attn_output.reshape(bsz, q_len, -1)
378373

379374
attn_output = self.o_proj(attn_output)
380375

@@ -481,7 +476,7 @@ def forward(
481476
is_causal=self.is_causal,
482477
)
483478

484-
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
479+
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
485480
attn_output = self.o_proj(attn_output)
486481

487482
if not output_attentions:
@@ -575,7 +570,7 @@ def forward(
575570
)
576571

577572
attn_output = attn_output.transpose(1, 2).contiguous()
578-
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
573+
attn_output = attn_output.view(bsz, q_len, -1)
579574

580575
attn_output = self.o_proj(attn_output)
581576

0 commit comments

Comments
 (0)