
Commit c3e3c5e

Generate modeling_mixtral.py from modular_mixtral.py
- Auto-generate modeling_mixtral.py with the same fix
- Apply black formatting
- Fix repository consistency check
1 parent e297322 commit c3e3c5e

File tree

1 file changed: +146 -51 lines changed


src/transformers/models/mixtral/modeling_mixtral.py

Lines changed: 146 additions & 51 deletions
@@ -67,7 +67,9 @@ def __init__(self, config: MixtralConfig):
         self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, hidden_states):
-        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
+        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
+            hidden_states
+        )
         current_hidden_states = self.w2(current_hidden_states)
         return current_hidden_states
 
@@ -94,7 +96,9 @@ def __init__(self, config):
         # gating
         self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
 
-        self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
+        self.experts = nn.ModuleList(
+            [MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]
+        )
 
         # Jitter parameters
         self.jitter_noise = config.router_jitter_noise
@@ -103,39 +107,53 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """ """
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         if self.training and self.jitter_noise > 0:
-            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
+            hidden_states *= torch.empty_like(hidden_states).uniform_(
+                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
+            )
         hidden_states = hidden_states.view(-1, hidden_dim)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
 
         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.top_k, dim=-1
+        )
         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
         # we cast back to the input dtype
         routing_weights = routing_weights.to(hidden_states.dtype)
 
         final_hidden_states = torch.zeros(
-            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+            (batch_size * sequence_length, hidden_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
         )
 
         # One hot encode the selected experts to create an expert mask
         # this will be used to easily index which expert is going to be sollicitated
-        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_experts
+        ).permute(2, 1, 0)
 
-        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-        for expert_idx in expert_hit:
+        # Loop over all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.num_experts):
             expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+            idx, top_x = torch.where(expert_mask[expert_idx])
             # Index the correct hidden states and compute the expert hidden state for
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
             current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
-            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+            current_hidden_states = (
+                expert_layer(current_state) * routing_weights[top_x, idx, None]
+            )
 
             # However `index_add_` only support torch tensors for indexing so we'll use
             # the `top_x` tensor here.
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
-        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+            final_hidden_states.index_add_(
+                0, top_x, current_hidden_states.to(hidden_states.dtype)
+            )
+        final_hidden_states = final_hidden_states.reshape(
+            batch_size, sequence_length, hidden_dim
+        )
         return final_hidden_states, router_logits
 
 
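Note on the hunk above: apart from the black-style line wrapping, this is the behavioral part of the change, going back to looping over every expert instead of only the experts selected by at least one token (the removed `expert_hit` variant). A minimal, self-contained sketch of the same top-2 routing and dispatch pattern, with made-up sizes and plain `nn.Linear` layers standing in for the expert MLPs (an illustration, not the actual MixtralSparseMoeBlock):

import torch
import torch.nn.functional as F
from torch import nn

# Toy sketch of top-2 expert routing as in the hunk above; all sizes are made up.
torch.manual_seed(0)
num_experts, top_k, hidden_dim = 8, 2, 16
tokens = torch.randn(12, hidden_dim)  # (batch * sequence_length, hidden_dim)
gate = nn.Linear(hidden_dim, num_experts, bias=False)
experts = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)])

router_logits = gate(tokens)  # (tokens, num_experts)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  # renormalize the kept top-2
routing_weights = routing_weights.to(tokens.dtype)

final_hidden_states = torch.zeros_like(tokens)
# (num_experts, top_k, tokens): lets each expert look up the tokens routed to it
expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
for expert_idx in range(num_experts):  # visit every expert, hit or not
    idx, top_x = torch.where(expert_mask[expert_idx])
    current_state = tokens[top_x]
    current_hidden_states = experts[expert_idx](current_state) * routing_weights[top_x, idx, None]
    final_hidden_states.index_add_(0, top_x, current_hidden_states)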

@@ -202,7 +220,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
     if n_rep == 1:
         return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
@@ -224,8 +244,12 @@ def eager_attention_forward(
         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         attn_weights = attn_weights + causal_mask
 
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+        query.dtype
+    )
+    attn_weights = nn.functional.dropout(
+        attn_weights, p=dropout, training=module.training
+    )
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()
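The two hunks above only re-wrap the grouped-query attention helpers: `repeat_kv` tiles each key/value head `n_rep` times, and the eager path is a scaled dot-product with a float32 softmax and dropout. A self-contained sketch of that path with toy shapes (no attention mask, no module, dropout disabled; purely illustrative, not the library functions themselves):

import torch
from torch import nn

# Toy sketch of repeat_kv + eager attention as re-wrapped above; shapes are made up.
batch, num_heads, num_kv_heads, seq_len, head_dim = 1, 8, 2, 5, 16
n_rep = num_heads // num_kv_heads
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_kv_heads, seq_len, head_dim)
value = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# repeat_kv: (batch, num_kv_heads, seq, dim) -> (batch, num_kv_heads * n_rep, seq, dim)
key_states = key[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
key_states = key_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
value_states = value[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
value_states = value_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

scaling = head_dim**-0.5
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=0.0, training=False)
attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()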

@@ -239,15 +263,28 @@ def __init__(self, config: MixtralConfig, layer_idx: int):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
-        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
-        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.head_dim = (
+            getattr(config, "head_dim", None)
+            or config.hidden_size // config.num_attention_heads
+        )
+        self.num_key_value_groups = (
+            config.num_attention_heads // config.num_key_value_heads
+        )
         self.scaling = self.head_dim**-0.5
         self.attention_dropout = config.attention_dropout
         self.is_causal = True
-        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
-        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
-        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
-        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+        self.q_proj = nn.Linear(
+            config.hidden_size, config.num_attention_heads * self.head_dim, bias=False
+        )
+        self.k_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False
+        )
+        self.v_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False
+        )
+        self.o_proj = nn.Linear(
+            config.num_attention_heads * self.head_dim, config.hidden_size, bias=False
+        )
 
     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
     def forward(
@@ -267,16 +304,22 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
         cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin
+        )
 
         if past_key_values is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+            key_states, value_states = past_key_values.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
-            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[
+                self.config._attn_implementation
+            ]
 
         attn_output, attn_weights = attention_interface(
             self,
@@ -286,7 +329,9 @@ def forward(
             attention_mask,
             dropout=0.0 if not self.training else self.attention_dropout,
             scaling=self.scaling,
-            sliding_window=getattr(self.config, "sliding_window", None),  # main diff with Llama
+            sliding_window=getattr(
+                self.config, "sliding_window", None
+            ),  # main diff with Llama
             **kwargs,
         )
 
@@ -303,8 +348,12 @@ def __init__(self, config: MixtralConfig, layer_idx: int):
         self.self_attn = MixtralAttention(config, layer_idx)
 
         self.block_sparse_moe = MixtralSparseMoeBlock(config)
-        self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
 
     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
     def forward(
@@ -349,7 +398,9 @@ def __init__(self, config: MixtralConfig, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
-            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+            self.rope_type = config.rope_scaling.get(
+                "rope_type", config.rope_scaling.get("type")
+            )
         else:
             self.rope_type = "default"
         self.max_seq_len_cached = config.max_position_embeddings
@@ -365,12 +416,23 @@ def __init__(self, config: MixtralConfig, device=None):
     @torch.no_grad()
     @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
     def forward(self, x, position_ids):
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        inv_freq_expanded = (
+            self.inv_freq[None, :, None]
+            .float()
+            .expand(position_ids.shape[0], -1, 1)
+            .to(x.device)
+        )
         position_ids_expanded = position_ids[:, None, :].float()
 
-        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        device_type = (
+            x.device.type
+            if isinstance(x.device.type, str) and x.device.type != "mps"
+            else "cpu"
+        )
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            freqs = (
+                inv_freq_expanded.float() @ position_ids_expanded.float()
+            ).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
             sin = emb.sin() * self.attention_scaling
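For context, the re-wrapped expression above is the standard RoPE angle table: the outer product of inverse frequencies and positions, duplicated along the last dimension before taking cos and sin. A minimal sketch with made-up sizes, assuming the default rope type and an attention scaling of 1.0 (illustrative only):

import torch

# Toy sketch of the RoPE cos/sin computation re-wrapped above; sizes are made up.
head_dim, base = 8, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim // 2,)
position_ids = torch.arange(5)[None, :]  # (batch=1, seq_len=5)

inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)  # (batch, seq_len, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)  # (batch, seq_len, head_dim)
cos, sin = emb.cos(), emb.sin()  # what gets fed to apply_rotary_pos_emb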
@@ -404,9 +466,14 @@ def __init__(self, config: MixtralConfig):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.embed_tokens = nn.Embedding(
+            config.vocab_size, config.hidden_size, self.padding_idx
+        )
         self.layers = nn.ModuleList(
-            [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+            [
+                MixtralDecoderLayer(config, layer_idx)
+                for layer_idx in range(config.num_hidden_layers)
+            ]
         )
         self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.rotary_emb = MixtralRotaryEmbedding(config=config)
@@ -429,7 +496,9 @@ def forward(
         **kwargs: Unpack[TransformersKwargs],
     ) -> MoeModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+            raise ValueError(
+                "You must specify exactly one of input_ids or inputs_embeds"
+            )
 
         if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)
@@ -438,14 +507,22 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)
 
         if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            past_seen_tokens = (
+                past_key_values.get_seq_length() if past_key_values is not None else 0
+            )
             cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device,
             )
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
+        mask_function = (
+            create_causal_mask
+            if self.config.sliding_window is None
+            else create_sliding_window_causal_mask
+        )
         causal_mask = mask_function(
             config=self.config,
             input_embeds=inputs_embeds,
@@ -514,7 +591,9 @@ def load_balancing_loss_func(
 
     if isinstance(gate_logits, tuple):
         compute_device = gate_logits[0].device
-        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+        concatenated_gate_logits = torch.cat(
+            [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
+        )
 
     routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
 
@@ -530,20 +609,24 @@ def load_balancing_loss_func(
         router_prob_per_expert = torch.mean(routing_weights, dim=0)
     else:
         batch_size, sequence_length = attention_mask.shape
-        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+        num_hidden_layers = concatenated_gate_logits.shape[0] // (
+            batch_size * sequence_length
+        )
 
         # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
         expert_attention_mask = (
             attention_mask[None, :, :, None, None]
-            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+            .expand(
+                (num_hidden_layers, batch_size, sequence_length, top_k, num_experts)
+            )
             .reshape(-1, top_k, num_experts)
            .to(compute_device)
         )
 
         # Compute the percentage of tokens routed to each experts
-        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
-            expert_attention_mask, dim=0
-        )
+        tokens_per_expert = torch.sum(
+            expert_mask.float() * expert_attention_mask, dim=0
+        ) / torch.sum(expert_attention_mask, dim=0)
 
         # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
         router_per_expert_attention_mask = (
@@ -554,9 +637,9 @@ def load_balancing_loss_func(
         )
 
         # Compute the average probability of routing to these experts
-        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
-            router_per_expert_attention_mask, dim=0
-        )
+        router_prob_per_expert = torch.sum(
+            routing_weights * router_per_expert_attention_mask, dim=0
+        ) / torch.sum(router_per_expert_attention_mask, dim=0)
 
     overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
     return overall_loss * num_experts
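The function re-wrapped in these hunks is the Switch Transformers style auxiliary loss: the fraction of tokens dispatched to each expert times the mean router probability for that expert, summed over experts and scaled by the number of experts. A minimal sketch of the unmasked branch (attention_mask is None), with made-up shapes and a single fake layer of gate logits:

import torch
import torch.nn.functional as F

# Toy sketch of the load-balancing auxiliary loss reformatted above (unmasked case only).
num_experts, top_k = 8, 2
gate_logits = torch.randn(2 * 64, num_experts)  # (layers * batch * seq_len, num_experts), made up

routing_weights = F.softmax(gate_logits, dim=-1)
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
expert_mask = F.one_hot(selected_experts, num_experts)  # (tokens, top_k, num_experts)

tokens_per_expert = torch.mean(expert_mask.float(), dim=0)  # fraction of tokens routed per expert
router_prob_per_expert = torch.mean(routing_weights, dim=0)  # mean router probability per expert
aux_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) * num_experts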
@@ -626,7 +709,9 @@ def forward(
         ```"""
 
         output_router_logits = (
-            output_router_logits if output_router_logits is not None else self.config.output_router_logits
+            output_router_logits
+            if output_router_logits is not None
+            else self.config.output_router_logits
         )
 
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
@@ -644,7 +729,11 @@ def forward(
 
         hidden_states = outputs.last_hidden_state
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        slice_indices = (
+            slice(-logits_to_keep, None)
+            if isinstance(logits_to_keep, int)
+            else logits_to_keep
+        )
         logits = self.lm_head(hidden_states[:, slice_indices, :])
 
         loss = None
@@ -660,7 +749,9 @@ def forward(
                 attention_mask,
             )
             if labels is not None:
-                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
+                loss += self.router_aux_loss_coef * aux_loss.to(
+                    loss.device
+                )  # make sure to reside in the same device
 
         return MoeCausalLMOutputWithPast(
             loss=loss,
@@ -673,11 +764,15 @@ def forward(
         )
 
 
-class MixtralForSequenceClassification(GenericForSequenceClassification, MixtralPreTrainedModel):
+class MixtralForSequenceClassification(
+    GenericForSequenceClassification, MixtralPreTrainedModel
+):
     pass
 
 
-class MixtralForTokenClassification(GenericForTokenClassification, MixtralPreTrainedModel):
+class MixtralForTokenClassification(
+    GenericForTokenClassification, MixtralPreTrainedModel
+):
     pass