@@ -86,7 +86,9 @@ def load_balancing_loss_func(
 
     if isinstance(gate_logits, tuple):
         compute_device = gate_logits[0].device
-        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+        concatenated_gate_logits = torch.cat(
+            [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
+        )
 
     routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
 
@@ -102,20 +104,24 @@ def load_balancing_loss_func(
         router_prob_per_expert = torch.mean(routing_weights, dim=0)
     else:
         batch_size, sequence_length = attention_mask.shape
-        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+        num_hidden_layers = concatenated_gate_logits.shape[0] // (
+            batch_size * sequence_length
+        )
 
         # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
         expert_attention_mask = (
             attention_mask[None, :, :, None, None]
-            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+            .expand(
+                (num_hidden_layers, batch_size, sequence_length, top_k, num_experts)
+            )
             .reshape(-1, top_k, num_experts)
             .to(compute_device)
         )
 
         # Compute the percentage of tokens routed to each experts
-        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
-            expert_attention_mask, dim=0
-        )
+        tokens_per_expert = torch.sum(
+            expert_mask.float() * expert_attention_mask, dim=0
+        ) / torch.sum(expert_attention_mask, dim=0)
 
         # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
         router_per_expert_attention_mask = (
@@ -126,9 +132,9 @@ def load_balancing_loss_func(
         )
 
         # Compute the average probability of routing to these experts
-        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
-            router_per_expert_attention_mask, dim=0
-        )
+        router_prob_per_expert = torch.sum(
+            routing_weights * router_per_expert_attention_mask, dim=0
+        ) / torch.sum(router_per_expert_attention_mask, dim=0)
 
     overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
     return overall_loss * num_experts
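
The three hunks above only re-wrap `load_balancing_loss_func`; the math is unchanged: for each expert, the fraction of tokens routed to it is multiplied by its mean router probability, the products are summed, and the result is scaled by `num_experts`. A minimal sketch of the branch without an `attention_mask`, with toy shapes (the tensor sizes and names below are illustrative, not taken from the file):

import torch

num_experts, top_k = 8, 2
gate_logits = torch.randn(16, num_experts)  # (tokens, num_experts) for a single layer

routing_weights = torch.softmax(gate_logits, dim=-1)
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)  # (tokens, top_k, num_experts)

tokens_per_expert = expert_mask.float().mean(dim=0)   # fraction of tokens hitting each expert, per top-k slot
router_prob_per_expert = routing_weights.mean(dim=0)  # mean router probability per expert

aux_loss = num_experts * torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
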
@@ -147,7 +153,9 @@ def __init__(self, config: MixtralConfig):
         self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, hidden_states):
-        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
+        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
+            hidden_states
+        )
         current_hidden_states = self.w2(current_hidden_states)
         return current_hidden_states
 
@@ -174,7 +182,9 @@ def __init__(self, config):
         # gating
         self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
 
-        self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
+        self.experts = nn.ModuleList(
+            [MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]
+        )
 
         # Jitter parameters
         self.jitter_noise = config.router_jitter_noise
@@ -183,24 +193,32 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """ """
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         if self.training and self.jitter_noise > 0:
-            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
+            hidden_states *= torch.empty_like(hidden_states).uniform_(
+                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
+            )
         hidden_states = hidden_states.view(-1, hidden_dim)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
 
         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.top_k, dim=-1
+        )
         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
         # we cast back to the input dtype
         routing_weights = routing_weights.to(hidden_states.dtype)
 
         final_hidden_states = torch.zeros(
-            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+            (batch_size * sequence_length, hidden_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
         )
 
         # One hot encode the selected experts to create an expert mask
         # this will be used to easily index which expert is going to be sollicitated
-        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_experts
+        ).permute(2, 1, 0)
 
         # Loop over all available experts in the model and perform the computation on each expert
         for expert_idx in range(self.num_experts):
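
A note on the shapes in the `expert_mask` construction above, which the expert loop in the next hunk relies on: `one_hot(selected_experts, num_classes)` has shape `(tokens, top_k, num_experts)`, and `permute(2, 1, 0)` turns it into `(num_experts, top_k, tokens)`, so `torch.where(expert_mask[expert_idx])` recovers, for one expert, which top-k slot and which token indices selected it. A toy sketch (sizes made up for illustration):

import torch

num_tokens, top_k, num_experts = 6, 2, 4
selected_experts = torch.randint(0, num_experts, (num_tokens, top_k))

expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts)  # (tokens, top_k, experts)
expert_mask = expert_mask.permute(2, 1, 0)                                            # (experts, top_k, tokens)

# For expert 0: `idx` holds the top-k slot, `top_x` the token indices routed to it.
idx, top_x = torch.where(expert_mask[0])
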
@@ -210,12 +228,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
             current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
-            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+            current_hidden_states = (
+                expert_layer(current_state) * routing_weights[top_x, idx, None]
+            )
 
             # However `index_add_` only support torch tensors for indexing so we'll use
             # the `top_x` tensor here.
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
-        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+            final_hidden_states.index_add_(
+                0, top_x, current_hidden_states.to(hidden_states.dtype)
+            )
+        final_hidden_states = final_hidden_states.reshape(
+            batch_size, sequence_length, hidden_dim
+        )
         return final_hidden_states, router_logits
 
 
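
The loop reformatted above follows a gather/compute/scatter pattern: pull out the hidden states of the tokens routed to one expert, run the expert, scale by the routing weight, and `index_add_` the result back at the original token positions. A self-contained toy version of just that pattern (the identity-style "expert" and the literal indices and weights here are placeholders, not the real module):

import torch

hidden_dim = 4
hidden_states = torch.randn(6, hidden_dim)             # flattened (tokens, hidden_dim)
final_hidden_states = torch.zeros_like(hidden_states)

top_x = torch.tensor([0, 3, 5])                         # tokens routed to this expert
weights = torch.tensor([0.7, 0.6, 0.9]).unsqueeze(-1)   # their routing weights

current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)   # gather
expert_output = current_state * weights                              # stand-in for expert_layer(current_state)
final_hidden_states.index_add_(0, top_x, expert_output)              # scatter-add back by token index
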
@@ -235,8 +259,12 @@ def __init__(self, config: MixtralConfig, layer_idx: int):
         self.self_attn = MixtralAttention(config, layer_idx)
 
         self.block_sparse_moe = MixtralSparseMoeBlock(config)
-        self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
 
     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
     def forward(
@@ -300,7 +328,9 @@ def forward(
         **kwargs: Unpack[TransformersKwargs],
     ) -> MoeModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+            raise ValueError(
+                "You must specify exactly one of input_ids or inputs_embeds"
+            )
 
         if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)
@@ -309,14 +339,22 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)
 
         if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            past_seen_tokens = (
+                past_key_values.get_seq_length() if past_key_values is not None else 0
+            )
             cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device,
             )
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
+        mask_function = (
+            create_causal_mask
+            if self.config.sliding_window is None
+            else create_sliding_window_causal_mask
+        )
         causal_mask = mask_function(
             config=self.config,
             input_embeds=inputs_embeds,
@@ -399,7 +437,9 @@ def forward(
         ```"""
 
         output_router_logits = (
-            output_router_logits if output_router_logits is not None else self.config.output_router_logits
+            output_router_logits
+            if output_router_logits is not None
+            else self.config.output_router_logits
         )
 
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
@@ -417,7 +457,11 @@ def forward(
 
         hidden_states = outputs.last_hidden_state
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        slice_indices = (
+            slice(-logits_to_keep, None)
+            if isinstance(logits_to_keep, int)
+            else logits_to_keep
+        )
         logits = self.lm_head(hidden_states[:, slice_indices, :])
 
         loss = None
@@ -433,7 +477,9 @@ def forward(
                 attention_mask,
             )
             if labels is not None:
-                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
+                loss += self.router_aux_loss_coef * aux_loss.to(
+                    loss.device
+                )  # make sure to reside in the same device
 
         return MoeCausalLMOutputWithPast(
             loss=loss,
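
For context on the last hunk: when router logits are returned and labels are provided, the auxiliary loss is folded into the language-modeling loss with the configured coefficient, after being moved to the same device. Schematically (all values below are made up; the real coefficient comes from `MixtralConfig.router_aux_loss_coef`):

import torch

ce_loss = torch.tensor(2.31)    # stand-in for the cross-entropy LM loss
aux_loss = torch.tensor(0.08)   # stand-in for load_balancing_loss_func(...)
router_aux_loss_coef = 0.02     # illustrative value only

loss = ce_loss + router_aux_loss_coef * aux_loss.to(ce_loss.device)
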