
Commit e297322

Fix code quality - apply black and ruff formatting
1 parent 5201904 commit e297322

File tree

2 files changed: +99 -56 lines

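The changes below are formatting-only: long expressions are rewrapped to fit the configured line length, trailing whitespace on blank lines is stripped, and the unused `torch_device` import in the test module is dropped; no behavior changes. As a minimal sketch of how such a pass could be reproduced locally (the invocation, tool versions, and project configuration are assumptions, not taken from this commit):

# Hypothetical reproduction of the formatting pass; line-length and rule
# configuration are assumed to come from the repository and are not shown here.
import subprocess

files = [
    "src/transformers/models/mixtral/modular_mixtral.py",
    "tests/models/mixtral/test_mixtral_torch_export.py",
]

# black rewraps long lines and strips trailing whitespace.
subprocess.run(["black", *files], check=True)

# ruff applies lint autofixes, e.g. dropping unused imports.
subprocess.run(["ruff", "check", "--fix", *files], check=True)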

src/transformers/models/mixtral/modular_mixtral.py

Lines changed: 73 additions & 27 deletions
@@ -86,7 +86,9 @@ def load_balancing_loss_func(
 
     if isinstance(gate_logits, tuple):
         compute_device = gate_logits[0].device
-        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+        concatenated_gate_logits = torch.cat(
+            [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
+        )
 
     routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
 
@@ -102,20 +104,24 @@ def load_balancing_loss_func(
         router_prob_per_expert = torch.mean(routing_weights, dim=0)
     else:
         batch_size, sequence_length = attention_mask.shape
-        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+        num_hidden_layers = concatenated_gate_logits.shape[0] // (
+            batch_size * sequence_length
+        )
 
         # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
         expert_attention_mask = (
             attention_mask[None, :, :, None, None]
-            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+            .expand(
+                (num_hidden_layers, batch_size, sequence_length, top_k, num_experts)
+            )
             .reshape(-1, top_k, num_experts)
             .to(compute_device)
         )
 
         # Compute the percentage of tokens routed to each experts
-        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
-            expert_attention_mask, dim=0
-        )
+        tokens_per_expert = torch.sum(
+            expert_mask.float() * expert_attention_mask, dim=0
+        ) / torch.sum(expert_attention_mask, dim=0)
 
         # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
         router_per_expert_attention_mask = (
@@ -126,9 +132,9 @@ def load_balancing_loss_func(
         )
 
         # Compute the average probability of routing to these experts
-        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
-            router_per_expert_attention_mask, dim=0
-        )
+        router_prob_per_expert = torch.sum(
+            routing_weights * router_per_expert_attention_mask, dim=0
+        ) / torch.sum(router_per_expert_attention_mask, dim=0)
 
     overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
     return overall_loss * num_experts
@@ -147,7 +153,9 @@ def __init__(self, config: MixtralConfig):
         self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, hidden_states):
-        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
+        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
+            hidden_states
+        )
         current_hidden_states = self.w2(current_hidden_states)
         return current_hidden_states
 
@@ -174,7 +182,9 @@ def __init__(self, config):
         # gating
         self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
 
-        self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
+        self.experts = nn.ModuleList(
+            [MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]
+        )
 
         # Jitter parameters
         self.jitter_noise = config.router_jitter_noise
@@ -183,24 +193,32 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """ """
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         if self.training and self.jitter_noise > 0:
-            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
+            hidden_states *= torch.empty_like(hidden_states).uniform_(
+                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
+            )
         hidden_states = hidden_states.view(-1, hidden_dim)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
 
         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.top_k, dim=-1
+        )
         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
         # we cast back to the input dtype
         routing_weights = routing_weights.to(hidden_states.dtype)
 
         final_hidden_states = torch.zeros(
-            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+            (batch_size * sequence_length, hidden_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
         )
 
         # One hot encode the selected experts to create an expert mask
         # this will be used to easily index which expert is going to be sollicitated
-        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_experts
+        ).permute(2, 1, 0)
 
         # Loop over all available experts in the model and perform the computation on each expert
         for expert_idx in range(self.num_experts):
@@ -210,12 +228,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
             current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
-            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+            current_hidden_states = (
+                expert_layer(current_state) * routing_weights[top_x, idx, None]
+            )
 
             # However `index_add_` only support torch tensors for indexing so we'll use
             # the `top_x` tensor here.
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
-        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+            final_hidden_states.index_add_(
+                0, top_x, current_hidden_states.to(hidden_states.dtype)
+            )
+        final_hidden_states = final_hidden_states.reshape(
+            batch_size, sequence_length, hidden_dim
+        )
         return final_hidden_states, router_logits
 
 
@@ -235,8 +259,12 @@ def __init__(self, config: MixtralConfig, layer_idx: int):
         self.self_attn = MixtralAttention(config, layer_idx)
 
         self.block_sparse_moe = MixtralSparseMoeBlock(config)
-        self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
 
     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
     def forward(
@@ -300,7 +328,9 @@ def forward(
         **kwargs: Unpack[TransformersKwargs],
     ) -> MoeModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+            raise ValueError(
+                "You must specify exactly one of input_ids or inputs_embeds"
+            )
 
         if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)
@@ -309,14 +339,22 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)
 
         if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            past_seen_tokens = (
+                past_key_values.get_seq_length() if past_key_values is not None else 0
+            )
             cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device,
             )
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
+        mask_function = (
+            create_causal_mask
+            if self.config.sliding_window is None
+            else create_sliding_window_causal_mask
+        )
         causal_mask = mask_function(
             config=self.config,
             input_embeds=inputs_embeds,
@@ -399,7 +437,9 @@ def forward(
         ```"""
 
         output_router_logits = (
-            output_router_logits if output_router_logits is not None else self.config.output_router_logits
+            output_router_logits
+            if output_router_logits is not None
+            else self.config.output_router_logits
         )
 
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
@@ -417,7 +457,11 @@ def forward(
 
         hidden_states = outputs.last_hidden_state
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        slice_indices = (
+            slice(-logits_to_keep, None)
+            if isinstance(logits_to_keep, int)
+            else logits_to_keep
+        )
         logits = self.lm_head(hidden_states[:, slice_indices, :])
 
         loss = None
@@ -433,7 +477,9 @@ def forward(
                 attention_mask,
             )
             if labels is not None:
-                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
+                loss += self.router_aux_loss_coef * aux_loss.to(
+                    loss.device
+                )  # make sure to reside in the same device
 
         return MoeCausalLMOutputWithPast(
             loss=loss,

tests/models/mixtral/test_mixtral_torch_export.py

Lines changed: 26 additions & 29 deletions
@@ -21,7 +21,7 @@
 
 from transformers import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
-from transformers.testing_utils import require_torch, torch_device
+from transformers.testing_utils import require_torch
 
 
 @require_torch
@@ -43,31 +43,30 @@ def test_moe_block_torch_export(self):
         # Create MoE block
         moe_block = MixtralSparseMoeBlock(self.config)
         moe_block.eval()
-
+
         # Move to meta device for export testing
         moe_block = moe_block.to("meta")
-
+
         # Create test input
         batch_size, seq_len = 2, 8
         hidden_states = torch.randn(
-            batch_size, seq_len, self.config.hidden_size,
-            device="meta"
+            batch_size, seq_len, self.config.hidden_size, device="meta"
         )
-
+
         # Test torch.export - should not raise GuardOnDataDependentSymNode error
         try:
             exported_program = te.export(
-                moe_block,
-                args=(hidden_states,),
-                kwargs={},
-                strict=False
+                moe_block, args=(hidden_states,), kwargs={}, strict=False
             )
             # If export succeeds, the test passes
             self.assertIsNotNone(exported_program)
         except Exception as e:
             # Check if it's the specific error we're trying to avoid
             error_msg = str(e)
-            if "GuardOnDataDependentSymNode" in error_msg or "nonzero" in error_msg.lower():
+            if (
+                "GuardOnDataDependentSymNode" in error_msg
+                or "nonzero" in error_msg.lower()
+            ):
                 self.fail(
                     f"torch.export failed with data-dependent operation error: {error_msg}\n"
                     "This suggests the .nonzero() fix is not working properly."
@@ -81,30 +80,29 @@ def test_moe_block_functionality(self):
         # Create MoE block
         moe_block = MixtralSparseMoeBlock(self.config)
         moe_block.eval()
-
+
         # Create test input
         batch_size, seq_len = 2, 4
         hidden_states = torch.randn(batch_size, seq_len, self.config.hidden_size)
-
+
         # Forward pass
         with torch.no_grad():
             output, router_logits = moe_block(hidden_states)
-
+
         # Verify output shapes
         self.assertEqual(output.shape, hidden_states.shape)
         self.assertEqual(
-            router_logits.shape,
-            (batch_size * seq_len, self.config.num_local_experts)
+            router_logits.shape, (batch_size * seq_len, self.config.num_local_experts)
         )
-
+
         # Verify that outputs are not all zeros (computation happened)
         self.assertFalse(torch.allclose(output, torch.zeros_like(output)))
-
+
         # Test with different input to ensure different outputs
         hidden_states2 = torch.randn(batch_size, seq_len, self.config.hidden_size)
         with torch.no_grad():
             output2, _ = moe_block(hidden_states2)
-
+
         # Outputs should be different for different inputs
         self.assertFalse(torch.allclose(output, output2))
 
@@ -117,7 +115,7 @@ def test_moe_block_export_with_different_configs(self):
             (16, 2),
             (8, 4),
         ]
-
+
         for num_experts, top_k in test_configs:
             with self.subTest(num_experts=num_experts, top_k=top_k):
                 config = MixtralConfig(
@@ -127,28 +125,27 @@ def test_moe_block_export_with_different_configs(self):
                     num_experts_per_tok=top_k,
                     router_jitter_noise=0.0,
                 )
-
+
                 moe_block = MixtralSparseMoeBlock(config)
                 moe_block.eval()
                 moe_block = moe_block.to("meta")
-
+
                 hidden_states = torch.randn(1, 4, config.hidden_size, device="meta")
-
+
                 # Should export without errors
                 try:
                     exported_program = te.export(
-                        moe_block,
-                        args=(hidden_states,),
-                        kwargs={},
-                        strict=False
+                        moe_block, args=(hidden_states,), kwargs={}, strict=False
                     )
                     self.assertIsNotNone(exported_program)
                 except Exception as e:
                     if "GuardOnDataDependentSymNode" in str(e):
-                        self.fail(f"Export failed for config ({num_experts}, {top_k}): {e}")
+                        self.fail(
+                            f"Export failed for config ({num_experts}, {top_k}): {e}"
+                        )
                     else:
                         raise
 
 
 if __name__ == "__main__":
-    unittest.main()
+    unittest.main()

0 commit comments
