From 52019042c2037f1cb3712bd14dacf0db5a6c7621 Mon Sep 17 00:00:00 2001 From: akacmazz Date: Tue, 12 Aug 2025 19:17:36 +0300 Subject: [PATCH 1/8] Fix torch.export compatibility for Mixtral MoE models - Replace data-dependent .nonzero() operation with static expert loop - Resolves GuardOnDataDependentSymNode error during torch.export - Maintains identical functionality while enabling export compatibility - Fixes issue introduced in PR #32429 - Add tests for torch.export compatibility --- .../models/mixtral/modular_mixtral.py | 6 +- .../mixtral/test_mixtral_torch_export.py | 154 ++++++++++++++++++ 2 files changed, 157 insertions(+), 3 deletions(-) create mode 100644 tests/models/mixtral/test_mixtral_torch_export.py diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index ffcf8224353f..c42fc69781d1 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -202,10 +202,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # this will be used to easily index which expert is going to be sollicitated expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hit: + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): expert_layer = self.experts[expert_idx] - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) + idx, top_x = torch.where(expert_mask[expert_idx]) # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) diff --git a/tests/models/mixtral/test_mixtral_torch_export.py b/tests/models/mixtral/test_mixtral_torch_export.py new file mode 100644 index 000000000000..e10520c94ca6 --- /dev/null +++ b/tests/models/mixtral/test_mixtral_torch_export.py @@ -0,0 +1,154 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Testing torch.export compatibility for Mixtral models.""" + +import unittest + +import torch +import torch.export as te + +from transformers import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock +from transformers.testing_utils import require_torch, torch_device + + +@require_torch +class MixtralTorchExportTest(unittest.TestCase): + """Test torch.export compatibility for Mixtral MoE components.""" + + def setUp(self): + """Set up test configuration.""" + self.config = MixtralConfig( + hidden_size=128, + intermediate_size=256, + num_local_experts=8, + num_experts_per_tok=2, + router_jitter_noise=0.0, + ) + + def test_moe_block_torch_export(self): + """Test that MixtralSparseMoeBlock can be exported with torch.export.""" + # Create MoE block + moe_block = MixtralSparseMoeBlock(self.config) + moe_block.eval() + + # Move to meta device for export testing + moe_block = moe_block.to("meta") + + # Create test input + batch_size, seq_len = 2, 8 + hidden_states = torch.randn( + batch_size, seq_len, self.config.hidden_size, + device="meta" + ) + + # Test torch.export - should not raise GuardOnDataDependentSymNode error + try: + exported_program = te.export( + moe_block, + args=(hidden_states,), + kwargs={}, + strict=False + ) + # If export succeeds, the test passes + self.assertIsNotNone(exported_program) + except Exception as e: + # Check if it's the specific error we're trying to avoid + error_msg = str(e) + if "GuardOnDataDependentSymNode" in error_msg or "nonzero" in error_msg.lower(): + self.fail( + f"torch.export failed with data-dependent operation error: {error_msg}\n" + "This suggests the .nonzero() fix is not working properly." + ) + else: + # Re-raise other unexpected errors + raise + + def test_moe_block_functionality(self): + """Test that MoE block maintains correct functionality after the fix.""" + # Create MoE block + moe_block = MixtralSparseMoeBlock(self.config) + moe_block.eval() + + # Create test input + batch_size, seq_len = 2, 4 + hidden_states = torch.randn(batch_size, seq_len, self.config.hidden_size) + + # Forward pass + with torch.no_grad(): + output, router_logits = moe_block(hidden_states) + + # Verify output shapes + self.assertEqual(output.shape, hidden_states.shape) + self.assertEqual( + router_logits.shape, + (batch_size * seq_len, self.config.num_local_experts) + ) + + # Verify that outputs are not all zeros (computation happened) + self.assertFalse(torch.allclose(output, torch.zeros_like(output))) + + # Test with different input to ensure different outputs + hidden_states2 = torch.randn(batch_size, seq_len, self.config.hidden_size) + with torch.no_grad(): + output2, _ = moe_block(hidden_states2) + + # Outputs should be different for different inputs + self.assertFalse(torch.allclose(output, output2)) + + def test_moe_block_export_with_different_configs(self): + """Test torch.export with various expert configurations.""" + test_configs = [ + # (num_experts, top_k) + (4, 2), + (8, 2), + (16, 2), + (8, 4), + ] + + for num_experts, top_k in test_configs: + with self.subTest(num_experts=num_experts, top_k=top_k): + config = MixtralConfig( + hidden_size=64, + intermediate_size=128, + num_local_experts=num_experts, + num_experts_per_tok=top_k, + router_jitter_noise=0.0, + ) + + moe_block = MixtralSparseMoeBlock(config) + moe_block.eval() + moe_block = moe_block.to("meta") + + hidden_states = torch.randn(1, 4, config.hidden_size, device="meta") + + # Should export without errors + try: + exported_program = te.export( + 
moe_block, + args=(hidden_states,), + kwargs={}, + strict=False + ) + self.assertIsNotNone(exported_program) + except Exception as e: + if "GuardOnDataDependentSymNode" in str(e): + self.fail(f"Export failed for config ({num_experts}, {top_k}): {e}") + else: + raise + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From e2973228087a5b6226fc524287b808f1a1c3645c Mon Sep 17 00:00:00 2001 From: akacmazz Date: Tue, 12 Aug 2025 20:53:37 +0300 Subject: [PATCH 2/8] Fix code quality - apply black and ruff formatting --- .../models/mixtral/modular_mixtral.py | 100 +++++++++++++----- .../mixtral/test_mixtral_torch_export.py | 55 +++++----- 2 files changed, 99 insertions(+), 56 deletions(-) diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index c42fc69781d1..2d47b0624f59 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -86,7 +86,9 @@ def load_balancing_loss_func( if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device - concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + concatenated_gate_logits = torch.cat( + [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0 + ) routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) @@ -102,20 +104,24 @@ def load_balancing_loss_func( router_prob_per_expert = torch.mean(routing_weights, dim=0) else: batch_size, sequence_length = attention_mask.shape - num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + num_hidden_layers = concatenated_gate_logits.shape[0] // ( + batch_size * sequence_length + ) # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask expert_attention_mask = ( attention_mask[None, :, :, None, None] - .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .expand( + (num_hidden_layers, batch_size, sequence_length, top_k, num_experts) + ) .reshape(-1, top_k, num_experts) .to(compute_device) ) # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( - expert_attention_mask, dim=0 - ) + tokens_per_expert = torch.sum( + expert_mask.float() * expert_attention_mask, dim=0 + ) / torch.sum(expert_attention_mask, dim=0) # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( @@ -126,9 +132,9 @@ def load_balancing_loss_func( ) # Compute the average probability of routing to these experts - router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( - router_per_expert_attention_mask, dim=0 - ) + router_prob_per_expert = torch.sum( + routing_weights * router_per_expert_attention_mask, dim=0 + ) / torch.sum(router_per_expert_attention_mask, dim=0) overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) return overall_loss * num_experts @@ -147,7 +153,9 @@ def __init__(self, config: MixtralConfig): self.act_fn = ACT2FN[config.hidden_act] def forward(self, hidden_states): - current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3( + hidden_states + ) current_hidden_states = self.w2(current_hidden_states) return current_hidden_states @@ -174,7 +182,9 
@@ def __init__(self, config): # gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) + self.experts = nn.ModuleList( + [MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)] + ) # Jitter parameters self.jitter_noise = config.router_jitter_noise @@ -183,24 +193,32 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ """ batch_size, sequence_length, hidden_dim = hidden_states.shape if self.training and self.jitter_noise > 0: - hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) + hidden_states *= torch.empty_like(hidden_states).uniform_( + 1.0 - self.jitter_noise, 1.0 + self.jitter_noise + ) hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states) routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + (batch_size * sequence_length, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, ) # One hot encode the selected experts to create an expert mask # this will be used to easily index which expert is going to be sollicitated - expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + expert_mask = torch.nn.functional.one_hot( + selected_experts, num_classes=self.num_experts + ).permute(2, 1, 0) # Loop over all available experts in the model and perform the computation on each expert for expert_idx in range(self.num_experts): @@ -210,12 +228,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + current_hidden_states = ( + expert_layer(current_state) * routing_weights[top_x, idx, None] + ) # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. 
- final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) - final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) + final_hidden_states = final_hidden_states.reshape( + batch_size, sequence_length, hidden_dim + ) return final_hidden_states, router_logits @@ -235,8 +259,12 @@ def __init__(self, config: MixtralConfig, layer_idx: int): self.self_attn = MixtralAttention(config, layer_idx) self.block_sparse_moe = MixtralSparseMoeBlock(config) - self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = MixtralRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = MixtralRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( @@ -300,7 +328,9 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> MoeModelOutputWithPast: if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) if use_cache and past_key_values is None: past_key_values = DynamicCache(config=self.config) @@ -309,14 +339,22 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, ) if position_ids is None: position_ids = cache_position.unsqueeze(0) - mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + mask_function = ( + create_causal_mask + if self.config.sliding_window is None + else create_sliding_window_causal_mask + ) causal_mask = mask_function( config=self.config, input_embeds=inputs_embeds, @@ -399,7 +437,9 @@ def forward( ```""" output_router_logits = ( - output_router_logits if output_router_logits is not None else self.config.output_router_logits + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits ) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) @@ -417,7 +457,11 @@ def forward( hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None @@ -433,7 +477,9 @@ def forward( attention_mask, ) if labels is not None: - loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + loss += self.router_aux_loss_coef * aux_loss.to( + loss.device + ) # make sure to reside in the same device return MoeCausalLMOutputWithPast( loss=loss, diff --git 
a/tests/models/mixtral/test_mixtral_torch_export.py b/tests/models/mixtral/test_mixtral_torch_export.py index e10520c94ca6..b0aa7699ab14 100644 --- a/tests/models/mixtral/test_mixtral_torch_export.py +++ b/tests/models/mixtral/test_mixtral_torch_export.py @@ -21,7 +21,7 @@ from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock -from transformers.testing_utils import require_torch, torch_device +from transformers.testing_utils import require_torch @require_torch @@ -43,31 +43,30 @@ def test_moe_block_torch_export(self): # Create MoE block moe_block = MixtralSparseMoeBlock(self.config) moe_block.eval() - + # Move to meta device for export testing moe_block = moe_block.to("meta") - + # Create test input batch_size, seq_len = 2, 8 hidden_states = torch.randn( - batch_size, seq_len, self.config.hidden_size, - device="meta" + batch_size, seq_len, self.config.hidden_size, device="meta" ) - + # Test torch.export - should not raise GuardOnDataDependentSymNode error try: exported_program = te.export( - moe_block, - args=(hidden_states,), - kwargs={}, - strict=False + moe_block, args=(hidden_states,), kwargs={}, strict=False ) # If export succeeds, the test passes self.assertIsNotNone(exported_program) except Exception as e: # Check if it's the specific error we're trying to avoid error_msg = str(e) - if "GuardOnDataDependentSymNode" in error_msg or "nonzero" in error_msg.lower(): + if ( + "GuardOnDataDependentSymNode" in error_msg + or "nonzero" in error_msg.lower() + ): self.fail( f"torch.export failed with data-dependent operation error: {error_msg}\n" "This suggests the .nonzero() fix is not working properly." @@ -81,30 +80,29 @@ def test_moe_block_functionality(self): # Create MoE block moe_block = MixtralSparseMoeBlock(self.config) moe_block.eval() - + # Create test input batch_size, seq_len = 2, 4 hidden_states = torch.randn(batch_size, seq_len, self.config.hidden_size) - + # Forward pass with torch.no_grad(): output, router_logits = moe_block(hidden_states) - + # Verify output shapes self.assertEqual(output.shape, hidden_states.shape) self.assertEqual( - router_logits.shape, - (batch_size * seq_len, self.config.num_local_experts) + router_logits.shape, (batch_size * seq_len, self.config.num_local_experts) ) - + # Verify that outputs are not all zeros (computation happened) self.assertFalse(torch.allclose(output, torch.zeros_like(output))) - + # Test with different input to ensure different outputs hidden_states2 = torch.randn(batch_size, seq_len, self.config.hidden_size) with torch.no_grad(): output2, _ = moe_block(hidden_states2) - + # Outputs should be different for different inputs self.assertFalse(torch.allclose(output, output2)) @@ -117,7 +115,7 @@ def test_moe_block_export_with_different_configs(self): (16, 2), (8, 4), ] - + for num_experts, top_k in test_configs: with self.subTest(num_experts=num_experts, top_k=top_k): config = MixtralConfig( @@ -127,28 +125,27 @@ def test_moe_block_export_with_different_configs(self): num_experts_per_tok=top_k, router_jitter_noise=0.0, ) - + moe_block = MixtralSparseMoeBlock(config) moe_block.eval() moe_block = moe_block.to("meta") - + hidden_states = torch.randn(1, 4, config.hidden_size, device="meta") - + # Should export without errors try: exported_program = te.export( - moe_block, - args=(hidden_states,), - kwargs={}, - strict=False + moe_block, args=(hidden_states,), kwargs={}, strict=False ) self.assertIsNotNone(exported_program) except Exception as e: if 
"GuardOnDataDependentSymNode" in str(e): - self.fail(f"Export failed for config ({num_experts}, {top_k}): {e}") + self.fail( + f"Export failed for config ({num_experts}, {top_k}): {e}" + ) else: raise if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From c3e3c5e6938d5b8975058d5eecaa5c7f93985db3 Mon Sep 17 00:00:00 2001 From: akacmazz Date: Tue, 12 Aug 2025 21:10:35 +0300 Subject: [PATCH 3/8] Generate modeling_mixtral.py from modular_mixtral.py - Auto-generate modeling_mixtral.py with the same fix - Apply black formatting - Fix repository consistency check --- .../models/mixtral/modeling_mixtral.py | 197 +++++++++++++----- 1 file changed, 146 insertions(+), 51 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index b6b5883f4a77..3a651cbb2225 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -67,7 +67,9 @@ def __init__(self, config: MixtralConfig): self.act_fn = ACT2FN[config.hidden_act] def forward(self, hidden_states): - current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3( + hidden_states + ) current_hidden_states = self.w2(current_hidden_states) return current_hidden_states @@ -94,7 +96,9 @@ def __init__(self, config): # gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) + self.experts = nn.ModuleList( + [MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)] + ) # Jitter parameters self.jitter_noise = config.router_jitter_noise @@ -103,39 +107,53 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ """ batch_size, sequence_length, hidden_dim = hidden_states.shape if self.training and self.jitter_noise > 0: - hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) + hidden_states *= torch.empty_like(hidden_states).uniform_( + 1.0 - self.jitter_noise, 1.0 + self.jitter_noise + ) hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states) routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + (batch_size * sequence_length, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, ) # One hot encode the selected experts to create an expert mask # this will be used to easily index which expert is going to be sollicitated - expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + expert_mask = torch.nn.functional.one_hot( + selected_experts, num_classes=self.num_experts + ).permute(2, 1, 0) - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hit: + # Loop over all available experts in the model and perform the 
computation on each expert + for expert_idx in range(self.num_experts): expert_layer = self.experts[expert_idx] - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) + idx, top_x = torch.where(expert_mask[expert_idx]) # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + current_hidden_states = ( + expert_layer(current_state) * routing_weights[top_x, idx, None] + ) # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) - final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) + final_hidden_states = final_hidden_states.reshape( + batch_size, sequence_length, hidden_dim + ) return final_hidden_states, router_logits @@ -202,7 +220,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) @@ -224,8 +244,12 @@ def eager_attention_forward( causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query.dtype + ) + attn_weights = nn.functional.dropout( + attn_weights, p=dropout, training=module.training + ) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() @@ -239,15 +263,28 @@ def __init__(self, config: MixtralConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.head_dim = ( + getattr(config, "head_dim", None) + or config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True - self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=False + ) + 
self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=False + ) @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( @@ -267,16 +304,22 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] attn_output, attn_weights = attention_interface( self, @@ -286,7 +329,9 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama + sliding_window=getattr( + self.config, "sliding_window", None + ), # main diff with Llama **kwargs, ) @@ -303,8 +348,12 @@ def __init__(self, config: MixtralConfig, layer_idx: int): self.self_attn = MixtralAttention(config, layer_idx) self.block_sparse_moe = MixtralSparseMoeBlock(config) - self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = MixtralRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = MixtralRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( @@ -349,7 +398,9 @@ def __init__(self, config: MixtralConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + self.rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type") + ) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings @@ -365,12 +416,23 @@ def __init__(self, config: MixtralConfig, device=None): @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + inv_freq_expanded = ( + self.inv_freq[None, :, None] + .float() + .expand(position_ids.shape[0], -1, 1) + .to(x.device) + ) position_ids_expanded = position_ids[:, None, :].float() - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + device_type = ( + x.device.type + if isinstance(x.device.type, str) and x.device.type != "mps" + else "cpu" + ) with torch.autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling @@ -404,9 +466,14 @@ def __init__(self, config: MixtralConfig): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) self.layers = nn.ModuleList( - [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + [ + MixtralDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] ) self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = MixtralRotaryEmbedding(config=config) @@ -429,7 +496,9 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> MoeModelOutputWithPast: if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) if use_cache and past_key_values is None: past_key_values = DynamicCache(config=self.config) @@ -438,14 +507,22 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, ) if position_ids is None: position_ids = cache_position.unsqueeze(0) - mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + mask_function = ( + create_causal_mask + if self.config.sliding_window is None + else create_sliding_window_causal_mask + ) causal_mask = mask_function( config=self.config, input_embeds=inputs_embeds, @@ -514,7 +591,9 @@ def load_balancing_loss_func( if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device - concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + concatenated_gate_logits = torch.cat( + [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0 + ) routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) @@ -530,20 +609,24 @@ def load_balancing_loss_func( router_prob_per_expert = torch.mean(routing_weights, dim=0) else: batch_size, sequence_length = attention_mask.shape - num_hidden_layers = 
concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + num_hidden_layers = concatenated_gate_logits.shape[0] // ( + batch_size * sequence_length + ) # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask expert_attention_mask = ( attention_mask[None, :, :, None, None] - .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .expand( + (num_hidden_layers, batch_size, sequence_length, top_k, num_experts) + ) .reshape(-1, top_k, num_experts) .to(compute_device) ) # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( - expert_attention_mask, dim=0 - ) + tokens_per_expert = torch.sum( + expert_mask.float() * expert_attention_mask, dim=0 + ) / torch.sum(expert_attention_mask, dim=0) # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( @@ -554,9 +637,9 @@ def load_balancing_loss_func( ) # Compute the average probability of routing to these experts - router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( - router_per_expert_attention_mask, dim=0 - ) + router_prob_per_expert = torch.sum( + routing_weights * router_per_expert_attention_mask, dim=0 + ) / torch.sum(router_per_expert_attention_mask, dim=0) overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) return overall_loss * num_experts @@ -626,7 +709,9 @@ def forward( ```""" output_router_logits = ( - output_router_logits if output_router_logits is not None else self.config.output_router_logits + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits ) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) @@ -644,7 +729,11 @@ def forward( hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None @@ -660,7 +749,9 @@ def forward( attention_mask, ) if labels is not None: - loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + loss += self.router_aux_loss_coef * aux_loss.to( + loss.device + ) # make sure to reside in the same device return MoeCausalLMOutputWithPast( loss=loss, @@ -673,11 +764,15 @@ def forward( ) -class MixtralForSequenceClassification(GenericForSequenceClassification, MixtralPreTrainedModel): +class MixtralForSequenceClassification( + GenericForSequenceClassification, MixtralPreTrainedModel +): pass -class MixtralForTokenClassification(GenericForTokenClassification, MixtralPreTrainedModel): +class MixtralForTokenClassification( + GenericForTokenClassification, MixtralPreTrainedModel +): pass From 0aa9de7090553c24bd69707eecec55ee0028ad52 Mon Sep 17 00:00:00 2001 From: akacmazz Date: Wed, 13 Aug 2025 11:24:59 +0300 Subject: [PATCH 4/8] Implement separate training/inference paths for torch.export compatibility - Training path: Keep efficient .nonzero() for performance - Inference path: Use static loop for torch.export compatibility - Add conditional check to skip empty experts in inference - Update tests to validate inference 
mode export - Addresses maintainer feedback on performance concerns --- .../models/mixtral/modeling_mixtral.py | 60 +++++++++++++------ .../models/mixtral/modular_mixtral.py | 60 +++++++++++++------ .../mixtral/test_mixtral_torch_export.py | 6 +- 3 files changed, 89 insertions(+), 37 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 3a651cbb2225..fd8d6d8a57f9 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -134,23 +134,49 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: selected_experts, num_classes=self.num_experts ).permute(2, 1, 0) - # Loop over all available experts in the model and perform the computation on each expert - for expert_idx in range(self.num_experts): - expert_layer = self.experts[expert_idx] - idx, top_x = torch.where(expert_mask[expert_idx]) - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - current_hidden_states = ( - expert_layer(current_state) * routing_weights[top_x, idx, None] - ) - - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. - final_hidden_states.index_add_( - 0, top_x, current_hidden_states.to(hidden_states.dtype) - ) + # Separate paths for training (with .nonzero()) and inference (without .nonzero()) + if self.training: + # Training path: use .nonzero() for efficiency (skip non-selected experts) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit: + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = ( + expert_layer(current_state) * routing_weights[top_x, idx, None] + ) + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) + else: + # Inference path: loop over all experts for torch.export compatibility + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + # Skip if no tokens are assigned to this expert + if top_x.shape[0] == 0: + continue + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = ( + expert_layer(current_state) * routing_weights[top_x, idx, None] + ) + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. 
+ final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) final_hidden_states = final_hidden_states.reshape( batch_size, sequence_length, hidden_dim ) diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index 2d47b0624f59..b2736901b432 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -220,23 +220,49 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: selected_experts, num_classes=self.num_experts ).permute(2, 1, 0) - # Loop over all available experts in the model and perform the computation on each expert - for expert_idx in range(self.num_experts): - expert_layer = self.experts[expert_idx] - idx, top_x = torch.where(expert_mask[expert_idx]) - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - current_hidden_states = ( - expert_layer(current_state) * routing_weights[top_x, idx, None] - ) - - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. - final_hidden_states.index_add_( - 0, top_x, current_hidden_states.to(hidden_states.dtype) - ) + # Separate paths for training (with .nonzero()) and inference (without .nonzero()) + if self.training: + # Training path: use .nonzero() for efficiency (skip non-selected experts) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit: + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = ( + expert_layer(current_state) * routing_weights[top_x, idx, None] + ) + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) + else: + # Inference path: loop over all experts for torch.export compatibility + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + # Skip if no tokens are assigned to this expert + if top_x.shape[0] == 0: + continue + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = ( + expert_layer(current_state) * routing_weights[top_x, idx, None] + ) + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. 
+ final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) final_hidden_states = final_hidden_states.reshape( batch_size, sequence_length, hidden_dim ) diff --git a/tests/models/mixtral/test_mixtral_torch_export.py b/tests/models/mixtral/test_mixtral_torch_export.py index b0aa7699ab14..9bf00a5f5111 100644 --- a/tests/models/mixtral/test_mixtral_torch_export.py +++ b/tests/models/mixtral/test_mixtral_torch_export.py @@ -39,10 +39,10 @@ def setUp(self): ) def test_moe_block_torch_export(self): - """Test that MixtralSparseMoeBlock can be exported with torch.export.""" + """Test that MixtralSparseMoeBlock can be exported with torch.export in inference mode.""" # Create MoE block moe_block = MixtralSparseMoeBlock(self.config) - moe_block.eval() + moe_block.eval() # Set to eval mode for inference path # Move to meta device for export testing moe_block = moe_block.to("meta") @@ -69,7 +69,7 @@ def test_moe_block_torch_export(self): ): self.fail( f"torch.export failed with data-dependent operation error: {error_msg}\n" - "This suggests the .nonzero() fix is not working properly." + "This suggests the inference path fix is not working properly." ) else: # Re-raise other unexpected errors From 6c22f4540bfcd072e88e2416428a4de9f4c0cf8a Mon Sep 17 00:00:00 2001 From: akacmazz Date: Wed, 13 Aug 2025 12:23:03 +0300 Subject: [PATCH 5/8] Fix code quality - apply black and isort formatting - Apply black formatting to fix code style - Fix import sorting with isort - Address CI code quality checks --- .../models/mixtral/modeling_mixtral.py | 20 ++++++------- .../models/mixtral/modular_mixtral.py | 28 ++++++++----------- 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index fd8d6d8a57f9..ed8e970c1108 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -36,15 +36,15 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...masking_utils import (create_causal_mask, + create_sliding_window_causal_mask) from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_layers import ( - GenericForQuestionAnswering, - GenericForSequenceClassification, - GenericForTokenClassification, - GradientCheckpointingLayer, -) -from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_layers import (GenericForQuestionAnswering, + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer) +from ...modeling_outputs import (MoeCausalLMOutputWithPast, + MoeModelOutputWithPast) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack @@ -159,11 +159,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: for expert_idx in range(self.num_experts): expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx]) - + # Skip if no tokens are assigned to this expert if top_x.shape[0] == 0: continue - + # Index the correct hidden states and compute the expert hidden state for # the current expert. 
We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index b2736901b432..2636b0526c06 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -28,27 +28,23 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache -from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...masking_utils import (create_causal_mask, + create_sliding_window_causal_mask) from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_outputs import (MoeCausalLMOutputWithPast, + MoeModelOutputWithPast) from ...processing_utils import Unpack from ...utils import TransformersKwargs, logging from ...utils.deprecation import deprecate_kwarg from ...utils.generic import OutputRecorder -from ..mistral.modeling_mistral import ( - MistralAttention, - MistralForCausalLM, - MistralForQuestionAnswering, - MistralForSequenceClassification, - MistralForTokenClassification, - MistralModel, - MistralPreTrainedModel, - MistralRMSNorm, - MistralRotaryEmbedding, -) +from ..mistral.modeling_mistral import (MistralAttention, MistralForCausalLM, + MistralForQuestionAnswering, + MistralForSequenceClassification, + MistralForTokenClassification, + MistralModel, MistralPreTrainedModel, + MistralRMSNorm, MistralRotaryEmbedding) from .configuration_mixtral import MixtralConfig - logger = logging.get_logger(__name__) @@ -245,11 +241,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: for expert_idx in range(self.num_experts): expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx]) - + # Skip if no tokens are assigned to this expert if top_x.shape[0] == 0: continue - + # Index the correct hidden states and compute the expert hidden state for # the current expert. 
We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) From 960451137c5618f45498b9ef802f7143b371ad78 Mon Sep 17 00:00:00 2001 From: akacmazz Date: Wed, 13 Aug 2025 13:44:51 +0300 Subject: [PATCH 6/8] Fix ruff import sorting issues - Fix import organization in modeling_mixtral.py - Fix import organization in modular_mixtral.py - Address ruff I001 import sorting warnings --- .../models/mixtral/modeling_mixtral.py | 16 ++++++------- .../models/mixtral/modular_mixtral.py | 24 +++++++++++-------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index ed8e970c1108..f10bfb5cfd9e 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -36,15 +36,15 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...masking_utils import (create_causal_mask, - create_sliding_window_causal_mask) +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_layers import (GenericForQuestionAnswering, - GenericForSequenceClassification, - GenericForTokenClassification, - GradientCheckpointingLayer) -from ...modeling_outputs import (MoeCausalLMOutputWithPast, - MoeModelOutputWithPast) +from ...modeling_layers import ( + GenericForQuestionAnswering, + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index 2636b0526c06..847bd56c2016 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -28,23 +28,27 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache -from ...masking_utils import (create_causal_mask, - create_sliding_window_causal_mask) +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import (MoeCausalLMOutputWithPast, - MoeModelOutputWithPast) +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...processing_utils import Unpack from ...utils import TransformersKwargs, logging from ...utils.deprecation import deprecate_kwarg from ...utils.generic import OutputRecorder -from ..mistral.modeling_mistral import (MistralAttention, MistralForCausalLM, - MistralForQuestionAnswering, - MistralForSequenceClassification, - MistralForTokenClassification, - MistralModel, MistralPreTrainedModel, - MistralRMSNorm, MistralRotaryEmbedding) +from ..mistral.modeling_mistral import ( + MistralAttention, + MistralForCausalLM, + MistralForQuestionAnswering, + MistralForSequenceClassification, + MistralForTokenClassification, + MistralModel, + MistralPreTrainedModel, + MistralRMSNorm, + MistralRotaryEmbedding, +) from .configuration_mixtral import MixtralConfig + logger = 
logging.get_logger(__name__) From 068f99da2d291be9962357b71a0779b09031d6b7 Mon Sep 17 00:00:00 2001 From: akacmazz Date: Wed, 13 Aug 2025 14:01:46 +0300 Subject: [PATCH 7/8] Fix repository consistency - auto-generate modeling_mixtral.py - Remove manually edited modeling_mixtral.py - Auto-generate from modular_mixtral.py using proper tool - Ensure consistency between modular and generated files - Fix check_repository_consistency CI failure --- .../models/mixtral/modeling_mixtral.py | 498 ++++++++++++------ 1 file changed, 330 insertions(+), 168 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index f10bfb5cfd9e..7bfc9162eec1 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -30,30 +30,33 @@ import torch.nn.functional as F from torch import nn -from transformers.utils.generic import check_model_inputs - from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_layers import ( - GenericForQuestionAnswering, - GenericForSequenceClassification, - GenericForTokenClassification, - GradientCheckpointingLayer, +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPast, + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, ) -from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.deprecation import deprecate_kwarg from ...utils.generic import OutputRecorder from .configuration_mixtral import MixtralConfig +logger = logging.get_logger(__name__) + + class MixtralBlockSparseTop2MLP(nn.Module): def __init__(self, config: MixtralConfig): super().__init__() @@ -67,9 +70,7 @@ def __init__(self, config: MixtralConfig): self.act_fn = ACT2FN[config.hidden_act] def forward(self, hidden_states): - current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3( - hidden_states - ) + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) current_hidden_states = self.w2(current_hidden_states) return current_hidden_states @@ -96,9 +97,7 @@ def __init__(self, config): # gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - self.experts = nn.ModuleList( - [MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)] - ) + self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) # Jitter parameters self.jitter_noise = config.router_jitter_noise @@ -107,17 +106,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ """ batch_size, sequence_length, hidden_dim = hidden_states.shape if self.training and self.jitter_noise > 0: - hidden_states *= torch.empty_like(hidden_states).uniform_( - 1.0 - self.jitter_noise, 1.0 + 
self.jitter_noise - ) + hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states) routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk( - routing_weights, self.top_k, dim=-1 - ) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) @@ -130,9 +125,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # One hot encode the selected experts to create an expert mask # this will be used to easily index which expert is going to be sollicitated - expert_mask = torch.nn.functional.one_hot( - selected_experts, num_classes=self.num_experts - ).permute(2, 1, 0) + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) # Separate paths for training (with .nonzero()) and inference (without .nonzero()) if self.training: @@ -145,15 +138,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - current_hidden_states = ( - expert_layer(current_state) * routing_weights[top_x, idx, None] - ) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. - final_hidden_states.index_add_( - 0, top_x, current_hidden_states.to(hidden_states.dtype) - ) + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) else: # Inference path: loop over all experts for torch.export compatibility for expert_idx in range(self.num_experts): @@ -168,18 +157,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - current_hidden_states = ( - expert_layer(current_state) * routing_weights[top_x, idx, None] - ) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. 
- final_hidden_states.index_add_( - 0, top_x, current_hidden_states.to(hidden_states.dtype) - ) - final_hidden_states = final_hidden_states.reshape( - batch_size, sequence_length, hidden_dim - ) + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states, router_logits @@ -246,9 +229,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) @@ -260,7 +241,7 @@ def eager_attention_forward( attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, - **kwargs: Unpack[TransformersKwargs], + **kwargs, ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) @@ -270,12 +251,8 @@ def eager_attention_forward( causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( - query.dtype - ) - attn_weights = nn.functional.dropout( - attn_weights, p=dropout, training=module.training - ) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() @@ -289,39 +266,25 @@ def __init__(self, config: MixtralConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - self.head_dim = ( - getattr(config, "head_dim", None) - or config.hidden_size // config.num_attention_heads - ) - self.num_key_value_groups = ( - config.num_attention_heads // config.num_key_value_heads - ) + self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=False - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=False - ) + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) - @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], 
attention_mask: Optional[torch.Tensor], - past_key_values: Optional[Cache] = None, + past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -330,22 +293,16 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin - ) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_values is not None: + if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_values.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[ - self.config._attn_implementation - ] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -355,9 +312,7 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - sliding_window=getattr( - self.config, "sliding_window", None - ), # main diff with Llama + sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama **kwargs, ) @@ -374,12 +329,8 @@ def __init__(self, config: MixtralConfig, layer_idx: int): self.self_attn = MixtralAttention(config, layer_idx) self.block_sparse_moe = MixtralSparseMoeBlock(config) - self.input_layernorm = MixtralRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - self.post_attention_layernorm = MixtralRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) + self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( @@ -418,15 +369,11 @@ def forward( class MixtralRotaryEmbedding(nn.Module): - inv_freq: torch.Tensor # fix linting for `register_buffer` - def __init__(self, config: MixtralConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): - self.rope_type = config.rope_scaling.get( - "rope_type", config.rope_scaling.get("type") - ) + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings @@ -442,23 +389,12 @@ def __init__(self, config: MixtralConfig, device=None): @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) def forward(self, x, position_ids): - inv_freq_expanded = ( - self.inv_freq[None, :, None] - .float() - .expand(position_ids.shape[0], -1, 1) - .to(x.device) - ) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() - device_type = ( - x.device.type - if isinstance(x.device.type, str) and x.device.type != "mps" - else "cpu" - ) + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32 - freqs = ( - inv_freq_expanded.float() @ position_ids_expanded.float() - ).transpose(1, 2) + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling @@ -468,22 +404,39 @@ def forward(self, x, position_ids): @auto_docstring class MixtralPreTrainedModel(PreTrainedModel): - config: MixtralConfig + config_class = MixtralConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["MixtralDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn = True + _supports_flash_attn_3 = True + _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True - _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _can_record_outputs = { "router_logits": OutputRecorder(MixtralSparseMoeBlock, index=1), "hidden_states": MixtralDecoderLayer, "attentions": MixtralAttention, } + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, MixtralRMSNorm): + module.weight.data.fill_(1.0) + @auto_docstring class MixtralModel(MixtralPreTrainedModel): @@ -492,14 +445,9 @@ def __init__(self, config: MixtralConfig): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding( - config.vocab_size, config.hidden_size, self.padding_idx - ) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( - [ - MixtralDecoderLayer(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ] + [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = MixtralRotaryEmbedding(config=config) @@ -508,7 +456,13 @@ def __init__(self, config: MixtralConfig): # Initialize weights and apply final processing self.post_init() - @check_model_inputs + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple @auto_docstring def forward( self, @@ -522,9 +476,7 @@ def forward( **kwargs: 
Unpack[TransformersKwargs], ) -> MoeModelOutputWithPast: if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You must specify exactly one of input_ids or inputs_embeds" - ) + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if use_cache and past_key_values is None: past_key_values = DynamicCache(config=self.config) @@ -533,9 +485,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 - ) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], @@ -544,11 +494,7 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - mask_function = ( - create_causal_mask - if self.config.sliding_window is None - else create_sliding_window_causal_mask - ) + mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask causal_mask = mask_function( config=self.config, input_embeds=inputs_embeds, @@ -617,9 +563,7 @@ def load_balancing_loss_func( if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device - concatenated_gate_logits = torch.cat( - [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0 - ) + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) @@ -635,24 +579,20 @@ def load_balancing_loss_func( router_prob_per_expert = torch.mean(routing_weights, dim=0) else: batch_size, sequence_length = attention_mask.shape - num_hidden_layers = concatenated_gate_logits.shape[0] // ( - batch_size * sequence_length - ) + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask expert_attention_mask = ( attention_mask[None, :, :, None, None] - .expand( - (num_hidden_layers, batch_size, sequence_length, top_k, num_experts) - ) + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) .reshape(-1, top_k, num_experts) .to(compute_device) ) # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.sum( - expert_mask.float() * expert_attention_mask, dim=0 - ) / torch.sum(expert_attention_mask, dim=0) + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( @@ -663,9 +603,9 @@ def load_balancing_loss_func( ) # Compute the average probability of routing to these experts - router_prob_per_expert = torch.sum( - routing_weights * router_per_expert_attention_mask, dim=0 - ) / torch.sum(router_per_expert_attention_mask, dim=0) + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) return overall_loss * num_experts @@ -689,6 +629,18 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, 
value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + def set_decoder(self, decoder): self.model = decoder @@ -735,9 +687,7 @@ def forward( ```""" output_router_logits = ( - output_router_logits - if output_router_logits is not None - else self.config.output_router_logits + output_router_logits if output_router_logits is not None else self.config.output_router_logits ) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) @@ -755,11 +705,7 @@ def forward( hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None @@ -775,9 +721,7 @@ def forward( attention_mask, ) if labels is not None: - loss += self.router_aux_loss_coef * aux_loss.to( - loss.device - ) # make sure to reside in the same device + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device return MoeCausalLMOutputWithPast( loss=loss, @@ -790,20 +734,238 @@ def forward( ) -class MixtralForSequenceClassification( - GenericForSequenceClassification, MixtralPreTrainedModel -): - pass +@auto_docstring( + custom_intro=""" + The Mixtral Model transformer with a sequence classification head on top (linear layer). + [`MixtralForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. -class MixtralForTokenClassification( - GenericForTokenClassification, MixtralPreTrainedModel -): - pass + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """ +) +class MixtralForSequenceClassification(MixtralPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MixtralModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value -class MixtralForQuestionAnswering(GenericForQuestionAnswering, MixtralPreTrainedModel): - pass + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> SequenceClassifierOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + transformer_outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + hidden_states = transformer_outputs.last_hidden_state + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@auto_docstring +class MixtralForTokenClassification(MixtralPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MixtralModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> TokenClassifierOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + + outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = outputs.last_hidden_state + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class MixtralForQuestionAnswering(MixtralPreTrainedModel): + base_model_prefix = "model" + + def __init__(self, config): + super().__init__(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + self.model = MixtralModel(config) # diff with Llama: transformer->model + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + **kwargs, + ) -> QuestionAnsweringModelOutput: + outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + sequence_output = outputs.last_hidden_state + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) __all__ = [ From 7dcc373d9f0a231847d0eb5270efe4f9f704bd27 Mon Sep 17 00:00:00 2001 From: akacmazz Date: Wed, 13 Aug 2025 14:12:16 +0300 Subject: [PATCH 8/8] Fix torch.export data-dependent condition in inference path - Remove 'if top_x.shape[0] == 0: continue' check that causes GuardOnDataDependentSymNode error - Empty expert tensors naturally contribute 0, no explicit check needed - Update test error message for clarity - Fixes tests_processors CI failure Co-authored-by: ArthurZucker --- src/transformers/models/mixtral/modeling_mixtral.py | 4 ---- src/transformers/models/mixtral/modular_mixtral.py | 4 ---- tests/models/mixtral/test_mixtral_torch_export.py | 2 +- 3 files changed, 1 insertion(+), 9 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 7bfc9162eec1..b1d995b27474 100644 --- 
a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -149,10 +149,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx]) - # Skip if no tokens are assigned to this expert - if top_x.shape[0] == 0: - continue - # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index 847bd56c2016..ba7d957a60c1 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -246,10 +246,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx]) - # Skip if no tokens are assigned to this expert - if top_x.shape[0] == 0: - continue - # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) diff --git a/tests/models/mixtral/test_mixtral_torch_export.py b/tests/models/mixtral/test_mixtral_torch_export.py index 9bf00a5f5111..0aa39baa61d8 100644 --- a/tests/models/mixtral/test_mixtral_torch_export.py +++ b/tests/models/mixtral/test_mixtral_torch_export.py @@ -69,7 +69,7 @@ def test_moe_block_torch_export(self): ): self.fail( f"torch.export failed with data-dependent operation error: {error_msg}\n" - "This suggests the inference path fix is not working properly." + "This suggests the inference path has data-dependent operations that need to be removed." ) else: # Re-raise other unexpected errors
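
For reference, the export-friendly routing pattern introduced in these patches can be reduced to the standalone sketch below. It is not part of the patch: toy_moe_forward, gate_weight, and the experts list are illustrative stand-ins for the real MixtralSparseMoeBlock internals. The sketch shows why the data-dependent guards can be dropped: the expert loop has a static trip count taken from the config, and when an expert receives no tokens, top_x is empty, so index_add_ is a no-op and the result matches the masked version.

# Standalone sketch (illustrative only): a toy top-2 MoE routing loop written the same
# way as the inference path in the patch -- a static `range(num_experts)` loop with no
# `.nonzero()` and no `if top_x.shape[0] == 0: continue` guard.
import torch
import torch.nn.functional as F


def toy_moe_forward(hidden_states, gate_weight, experts, top_k=2):
    """hidden_states: (batch, seq, hidden); gate_weight: (num_experts, hidden);
    experts: list of callables mapping (tokens, hidden) -> (tokens, hidden)."""
    batch, seq, hidden = hidden_states.shape
    num_experts = len(experts)
    flat = hidden_states.view(-1, hidden)

    # Router: per-token logits over experts, then normalized top-k weights.
    router_logits = flat @ gate_weight.t()  # (tokens, num_experts)
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
    routing_weights, selected = torch.topk(routing_weights, top_k, dim=-1)
    routing_weights = (routing_weights / routing_weights.sum(-1, keepdim=True)).to(flat.dtype)

    out = torch.zeros_like(flat)
    # (num_experts, top_k, tokens) mask of which tokens chose which expert.
    expert_mask = F.one_hot(selected, num_classes=num_experts).permute(2, 1, 0)

    # Static loop: trip count depends only on the config, never on the routing result.
    for expert_idx in range(num_experts):
        idx, top_x = torch.where(expert_mask[expert_idx])
        current = flat[top_x]  # empty (0, hidden) tensor if no tokens routed here
        current = experts[expert_idx](current) * routing_weights[top_x, idx, None]
        out.index_add_(0, top_x, current)  # no-op when top_x is empty
    return out.view(batch, seq, hidden), router_logits


if __name__ == "__main__":
    torch.manual_seed(0)
    hidden, num_experts = 16, 4
    experts = [torch.nn.Linear(hidden, hidden) for _ in range(num_experts)]
    gate = torch.randn(num_experts, hidden)
    x = torch.randn(2, 3, hidden)
    with torch.no_grad():
        y, logits = toy_moe_forward(x, gate, experts)
    print(y.shape, logits.shape)  # torch.Size([2, 3, 16]) torch.Size([6, 4])

Wrapping a loop like this in an nn.Module and calling torch.export.export(..., strict=False) should trace without raising GuardOnDataDependentSymNode, since no branch or shape in the graph depends on the routing result.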