@@ -67,7 +67,9 @@ def __init__(self, config: MixtralConfig):
         self.act_fn = ACT2FN[config.hidden_act]

     def forward(self, hidden_states):
-        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
+        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
+            hidden_states
+        )
         current_hidden_states = self.w2(current_hidden_states)
         return current_hidden_states
@@ -94,7 +96,9 @@ def __init__(self, config):
         # gating
         self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

-        self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
+        self.experts = nn.ModuleList(
+            [MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]
+        )

         # Jitter parameters
         self.jitter_noise = config.router_jitter_noise
@@ -103,39 +107,53 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """ """
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         if self.training and self.jitter_noise > 0:
-            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
+            hidden_states *= torch.empty_like(hidden_states).uniform_(
+                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
+            )
         hidden_states = hidden_states.view(-1, hidden_dim)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)

         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.top_k, dim=-1
+        )
         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
         # we cast back to the input dtype
         routing_weights = routing_weights.to(hidden_states.dtype)

         final_hidden_states = torch.zeros(
-            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+            (batch_size * sequence_length, hidden_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
         )

         # One hot encode the selected experts to create an expert mask
         # this will be used to easily index which expert is going to be sollicitated
-        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_experts
+        ).permute(2, 1, 0)

-        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-        for expert_idx in expert_hit:
+        # Loop over all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.num_experts):
             expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+            idx, top_x = torch.where(expert_mask[expert_idx])
             # Index the correct hidden states and compute the expert hidden state for
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
             current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
-            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+            current_hidden_states = (
+                expert_layer(current_state) * routing_weights[top_x, idx, None]
+            )

             # However `index_add_` only support torch tensors for indexing so we'll use
             # the `top_x` tensor here.
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
-        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+            final_hidden_states.index_add_(
+                0, top_x, current_hidden_states.to(hidden_states.dtype)
+            )
+        final_hidden_states = final_hidden_states.reshape(
+            batch_size, sequence_length, hidden_dim
+        )
         return final_hidden_states, router_logits
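For readers of this hunk, a minimal standalone sketch of the top-2 routing and dispatch-mask arithmetic used above, assuming hypothetical toy sizes (4 tokens, 8 experts) and plain torch tensors; it is illustrative only, not part of the change:

# Sketch only: mirrors the routing math in the hunk with toy shapes.
import torch
import torch.nn.functional as F

num_experts, top_k = 8, 2
router_logits = torch.randn(4, num_experts)  # (tokens, experts)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  # the two picks sum to 1 per token

# one_hot(...).permute(2, 1, 0) yields an (experts, top_k, tokens) mask, so
# expert_mask[e] marks which tokens chose expert e and in which top-k slot.
expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
for expert_idx in range(num_experts):
    idx, top_x = torch.where(expert_mask[expert_idx])
    # top_x: token indices routed to this expert; idx: 0 or 1 (top-1 vs top-2 slot)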
@@ -202,7 +220,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
     if n_rep == 1:
         return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
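The wrapped call above only changes formatting; functionally, repeat_kv duplicates each key/value head n_rep times so grouped-query attention can share KV heads across query heads. A small sketch with assumed toy shapes, equivalent to torch.repeat_interleave on the head dimension:

# Sketch only: toy shapes, checks the expand+reshape trick against repeat_interleave.
import torch

batch, num_kv_heads, slen, head_dim, n_rep = 1, 2, 3, 4, 3
kv = torch.randn(batch, num_kv_heads, slen, head_dim)
expanded = kv[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
out = expanded.reshape(batch, num_kv_heads * n_rep, slen, head_dim)
assert torch.equal(out, torch.repeat_interleave(kv, n_rep, dim=1))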
@@ -224,8 +244,12 @@ def eager_attention_forward(
         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         attn_weights = attn_weights + causal_mask

-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+        query.dtype
+    )
+    attn_weights = nn.functional.dropout(
+        attn_weights, p=dropout, training=module.training
+    )
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()
@@ -239,15 +263,28 @@ def __init__(self, config: MixtralConfig, layer_idx: int):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
-        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
-        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.head_dim = (
+            getattr(config, "head_dim", None)
+            or config.hidden_size // config.num_attention_heads
+        )
+        self.num_key_value_groups = (
+            config.num_attention_heads // config.num_key_value_heads
+        )
         self.scaling = self.head_dim**-0.5
         self.attention_dropout = config.attention_dropout
         self.is_causal = True
-        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
-        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
-        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
-        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+        self.q_proj = nn.Linear(
+            config.hidden_size, config.num_attention_heads * self.head_dim, bias=False
+        )
+        self.k_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False
+        )
+        self.v_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False
+        )
+        self.o_proj = nn.Linear(
+            config.num_attention_heads * self.head_dim, config.hidden_size, bias=False
+        )

     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
     def forward(
@@ -267,16 +304,22 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

         cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin
+        )

         if past_key_values is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+            key_states, value_states = past_key_values.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )

         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
-            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[
+                self.config._attn_implementation
+            ]

         attn_output, attn_weights = attention_interface(
             self,
@@ -286,7 +329,9 @@ def forward(
             attention_mask,
             dropout=0.0 if not self.training else self.attention_dropout,
             scaling=self.scaling,
-            sliding_window=getattr(self.config, "sliding_window", None),  # main diff with Llama
+            sliding_window=getattr(
+                self.config, "sliding_window", None
+            ),  # main diff with Llama
             **kwargs,
         )
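The `sliding_window` argument flagged above as the main diff with Llama restricts each position to a recent window of keys. An illustrative boolean-mask sketch under assumed toy sizes (the exact window-boundary convention of the real mask helpers may differ):

# Sketch only: plain causal mask vs. sliding-window causal mask on 6 positions.
import torch

seq_len, window = 6, 3
i = torch.arange(seq_len)[:, None]
j = torch.arange(seq_len)[None, :]
causal = j <= i                      # every position sees itself and the past
sliding = causal & (i - j < window)  # additionally drop keys older than the window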
@@ -303,8 +348,12 @@ def __init__(self, config: MixtralConfig, layer_idx: int):
         self.self_attn = MixtralAttention(config, layer_idx)

         self.block_sparse_moe = MixtralSparseMoeBlock(config)
-        self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )

     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
     def forward(
@@ -349,7 +398,9 @@ def __init__(self, config: MixtralConfig, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
-            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+            self.rope_type = config.rope_scaling.get(
+                "rope_type", config.rope_scaling.get("type")
+            )
         else:
             self.rope_type = "default"
         self.max_seq_len_cached = config.max_position_embeddings
@@ -365,12 +416,23 @@ def __init__(self, config: MixtralConfig, device=None):
     @torch.no_grad()
     @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
     def forward(self, x, position_ids):
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        inv_freq_expanded = (
+            self.inv_freq[None, :, None]
+            .float()
+            .expand(position_ids.shape[0], -1, 1)
+            .to(x.device)
+        )
         position_ids_expanded = position_ids[:, None, :].float()

-        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        device_type = (
+            x.device.type
+            if isinstance(x.device.type, str) and x.device.type != "mps"
+            else "cpu"
+        )
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            freqs = (
+                inv_freq_expanded.float() @ position_ids_expanded.float()
+            ).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
             sin = emb.sin() * self.attention_scaling
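For context, the reshaped matmul above is just an outer product of inverse frequencies and positions; a standalone sketch with assumed toy sizes and default (unscaled) RoPE:

# Sketch only: builds the rotary cos/sin cache for head_dim=8 over 5 positions.
import torch

head_dim, seq_len, base = 8, 5, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))   # (head_dim // 2,)
position_ids = torch.arange(seq_len).float()[None, :]                          # (1, seq_len)
freqs = (inv_freq[None, :, None] @ position_ids[:, None, :]).transpose(1, 2)   # (1, seq_len, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)                                         # (1, seq_len, head_dim)
cos, sin = emb.cos(), emb.sin()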
@@ -404,9 +466,14 @@ def __init__(self, config: MixtralConfig):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size

-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.embed_tokens = nn.Embedding(
+            config.vocab_size, config.hidden_size, self.padding_idx
+        )
         self.layers = nn.ModuleList(
-            [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+            [
+                MixtralDecoderLayer(config, layer_idx)
+                for layer_idx in range(config.num_hidden_layers)
+            ]
         )
         self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.rotary_emb = MixtralRotaryEmbedding(config=config)
@@ -429,7 +496,9 @@ def forward(
         **kwargs: Unpack[TransformersKwargs],
     ) -> MoeModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+            raise ValueError(
+                "You must specify exactly one of input_ids or inputs_embeds"
+            )

         if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)
@@ -438,14 +507,22 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)

         if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            past_seen_tokens = (
+                past_key_values.get_seq_length() if past_key_values is not None else 0
+            )
             cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device,
             )
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
+        mask_function = (
+            create_causal_mask
+            if self.config.sliding_window is None
+            else create_sliding_window_causal_mask
+        )
         causal_mask = mask_function(
             config=self.config,
             input_embeds=inputs_embeds,
@@ -514,7 +591,9 @@ def load_balancing_loss_func(
     if isinstance(gate_logits, tuple):
         compute_device = gate_logits[0].device
-        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+        concatenated_gate_logits = torch.cat(
+            [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
+        )

     routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
@@ -530,20 +609,24 @@ def load_balancing_loss_func(
         router_prob_per_expert = torch.mean(routing_weights, dim=0)
     else:
         batch_size, sequence_length = attention_mask.shape
-        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+        num_hidden_layers = concatenated_gate_logits.shape[0] // (
+            batch_size * sequence_length
+        )

         # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
         expert_attention_mask = (
             attention_mask[None, :, :, None, None]
-            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+            .expand(
+                (num_hidden_layers, batch_size, sequence_length, top_k, num_experts)
+            )
             .reshape(-1, top_k, num_experts)
             .to(compute_device)
         )

         # Compute the percentage of tokens routed to each experts
-        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
-            expert_attention_mask, dim=0
-        )
+        tokens_per_expert = torch.sum(
+            expert_mask.float() * expert_attention_mask, dim=0
+        ) / torch.sum(expert_attention_mask, dim=0)

         # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
         router_per_expert_attention_mask = (
@@ -554,9 +637,9 @@ def load_balancing_loss_func(
         )

         # Compute the average probability of routing to these experts
-        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
-            router_per_expert_attention_mask, dim=0
-        )
+        router_prob_per_expert = torch.sum(
+            routing_weights * router_per_expert_attention_mask, dim=0
+        ) / torch.sum(router_per_expert_attention_mask, dim=0)

     overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
     return overall_loss * num_experts
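For reference, the unmasked branch of this auxiliary loss reduces to the Switch-Transformers recipe: the fraction of tokens dispatched to each expert times the mean router probability for that expert, summed and scaled by the number of experts. A standalone sketch with assumed toy sizes:

# Sketch only: toy tensors standing in for concatenated per-layer router logits.
import torch
import torch.nn.functional as F

num_experts, top_k = 8, 2
gate_logits = torch.randn(16, num_experts)               # (tokens across layers, experts)
routing_weights = F.softmax(gate_logits, dim=-1)
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
expert_mask = F.one_hot(selected_experts, num_experts)   # (tokens, top_k, experts)
tokens_per_expert = expert_mask.float().mean(dim=0)      # (top_k, experts)
router_prob_per_expert = routing_weights.mean(dim=0)     # (experts,)
aux_loss = (tokens_per_expert * router_prob_per_expert.unsqueeze(0)).sum() * num_experts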
@@ -626,7 +709,9 @@ def forward(
         ```"""

         output_router_logits = (
-            output_router_logits if output_router_logits is not None else self.config.output_router_logits
+            output_router_logits
+            if output_router_logits is not None
+            else self.config.output_router_logits
         )

         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
@@ -644,7 +729,11 @@ def forward(
         hidden_states = outputs.last_hidden_state
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        slice_indices = (
+            slice(-logits_to_keep, None)
+            if isinstance(logits_to_keep, int)
+            else logits_to_keep
+        )
         logits = self.lm_head(hidden_states[:, slice_indices, :])

         loss = None
@@ -660,7 +749,9 @@ def forward(
                 attention_mask,
             )
             if labels is not None:
-                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
+                loss += self.router_aux_loss_coef * aux_loss.to(
+                    loss.device
+                )  # make sure to reside in the same device

         return MoeCausalLMOutputWithPast(
             loss=loss,
@@ -673,11 +764,15 @@ def forward(
         )


-class MixtralForSequenceClassification(GenericForSequenceClassification, MixtralPreTrainedModel):
+class MixtralForSequenceClassification(
+    GenericForSequenceClassification, MixtralPreTrainedModel
+):
     pass


-class MixtralForTokenClassification(GenericForTokenClassification, MixtralPreTrainedModel):
+class MixtralForTokenClassification(
+    GenericForTokenClassification, MixtralPreTrainedModel
+):
     pass