Commit 5201904

Fix torch.export compatibility for Mixtral MoE models
- Replace data-dependent .nonzero() operation with static expert loop
- Resolves GuardOnDataDependentSymNode error during torch.export
- Maintains identical functionality while enabling export compatibility
- Fixes issue introduced in PR #32429
- Add tests for torch.export compatibility
1 parent a07b5e9 commit 5201904
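
To make the failure mode named in the commit message concrete, here is a minimal sketch (not part of the commit; the module names NonzeroLoop and StaticLoop are illustrative, and it assumes a recent PyTorch whose torch.export.export accepts strict=False). It contrasts the old pattern, which iterates in Python over a .nonzero() result whose length is only known at runtime, with the static loop used by the fix:

import torch


class NonzeroLoop(torch.nn.Module):
    """Old pattern: iterate only over experts that actually received tokens."""

    def forward(self, expert_mask):
        out = torch.zeros(())
        # .nonzero() has a data-dependent shape, so iterating over its rows in
        # Python requires a concrete length; under torch.export this surfaces
        # as a GuardOnDataDependentSymNode error.
        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        for expert_idx in expert_hit:
            out = out + expert_mask[expert_idx].sum()
        return out


class StaticLoop(torch.nn.Module):
    """Fixed pattern: visit every expert; unused experts contribute nothing."""

    def __init__(self, num_experts):
        super().__init__()
        self.num_experts = num_experts

    def forward(self, expert_mask):
        out = torch.zeros(())
        for expert_idx in range(self.num_experts):  # loop bound is a Python constant
            out = out + expert_mask[expert_idx].sum()
        return out


selected = torch.tensor([[0, 2], [2, 1]])  # 2 tokens, top-2 routing over 4 experts
mask = torch.nn.functional.one_hot(selected, num_classes=4).permute(2, 1, 0).float()

# torch.export.export(NonzeroLoop(), (mask,), strict=False)  # expected to hit the data-dependent guard
program = torch.export.export(StaticLoop(num_experts=4), (mask,), strict=False)
print(program.module()(mask))  # exports and runs: tensor(4.)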

File tree

2 files changed: +157 / -3 lines changed


src/transformers/models/mixtral/modular_mixtral.py

Lines changed: 3 additions & 3 deletions

@@ -202,10 +202,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # this will be used to easily index which expert is going to be sollicitated
         expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

-        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-        for expert_idx in expert_hit:
+        # Loop over all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.num_experts):
             expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+            idx, top_x = torch.where(expert_mask[expert_idx])
             # Index the correct hidden states and compute the expert hidden state for
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
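
As a sanity check on the new control flow, the following standalone sketch (illustrative routing values, not taken from the repository) shows that the static range(self.num_experts) loop combined with torch.where over the one-hot expert_mask still selects exactly the tokens routed to each expert, while an expert that received no tokens simply yields empty index tensors instead of being skipped via .nonzero():

import torch

# Illustrative routing: 3 tokens, top-2 routing over 4 experts; expert 3 is never selected.
num_experts = 4
selected_experts = torch.tensor([[0, 2], [2, 1], [0, 1]])  # (num_tokens, top_k)

# Same construction as in the diff above: mask of shape (num_experts, top_k, num_tokens).
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)

for expert_idx in range(num_experts):
    # idx is the top-k slot, top_x the token index; both come back empty when the
    # expert was not selected, so visiting every expert changes no results.
    idx, top_x = torch.where(expert_mask[expert_idx])
    print(f"expert {expert_idx}: tokens {top_x.tolist()} via slots {idx.tolist()}")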
Lines changed: 154 additions & 0 deletions (new file)

@@ -0,0 +1,154 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing torch.export compatibility for Mixtral models."""

import unittest

import torch
import torch.export as te

from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from transformers.testing_utils import require_torch, torch_device


@require_torch
class MixtralTorchExportTest(unittest.TestCase):
    """Test torch.export compatibility for Mixtral MoE components."""

    def setUp(self):
        """Set up test configuration."""
        self.config = MixtralConfig(
            hidden_size=128,
            intermediate_size=256,
            num_local_experts=8,
            num_experts_per_tok=2,
            router_jitter_noise=0.0,
        )

    def test_moe_block_torch_export(self):
        """Test that MixtralSparseMoeBlock can be exported with torch.export."""
        # Create MoE block
        moe_block = MixtralSparseMoeBlock(self.config)
        moe_block.eval()

        # Move to meta device for export testing
        moe_block = moe_block.to("meta")

        # Create test input
        batch_size, seq_len = 2, 8
        hidden_states = torch.randn(
            batch_size, seq_len, self.config.hidden_size,
            device="meta"
        )

        # Test torch.export - should not raise GuardOnDataDependentSymNode error
        try:
            exported_program = te.export(
                moe_block,
                args=(hidden_states,),
                kwargs={},
                strict=False
            )
            # If export succeeds, the test passes
            self.assertIsNotNone(exported_program)
        except Exception as e:
            # Check if it's the specific error we're trying to avoid
            error_msg = str(e)
            if "GuardOnDataDependentSymNode" in error_msg or "nonzero" in error_msg.lower():
                self.fail(
                    f"torch.export failed with data-dependent operation error: {error_msg}\n"
                    "This suggests the .nonzero() fix is not working properly."
                )
            else:
                # Re-raise other unexpected errors
                raise

    def test_moe_block_functionality(self):
        """Test that MoE block maintains correct functionality after the fix."""
        # Create MoE block
        moe_block = MixtralSparseMoeBlock(self.config)
        moe_block.eval()

        # Create test input
        batch_size, seq_len = 2, 4
        hidden_states = torch.randn(batch_size, seq_len, self.config.hidden_size)

        # Forward pass
        with torch.no_grad():
            output, router_logits = moe_block(hidden_states)

        # Verify output shapes
        self.assertEqual(output.shape, hidden_states.shape)
        self.assertEqual(
            router_logits.shape,
            (batch_size * seq_len, self.config.num_local_experts)
        )

        # Verify that outputs are not all zeros (computation happened)
        self.assertFalse(torch.allclose(output, torch.zeros_like(output)))

        # Test with different input to ensure different outputs
        hidden_states2 = torch.randn(batch_size, seq_len, self.config.hidden_size)
        with torch.no_grad():
            output2, _ = moe_block(hidden_states2)

        # Outputs should be different for different inputs
        self.assertFalse(torch.allclose(output, output2))

    def test_moe_block_export_with_different_configs(self):
        """Test torch.export with various expert configurations."""
        test_configs = [
            # (num_experts, top_k)
            (4, 2),
            (8, 2),
            (16, 2),
            (8, 4),
        ]

        for num_experts, top_k in test_configs:
            with self.subTest(num_experts=num_experts, top_k=top_k):
                config = MixtralConfig(
                    hidden_size=64,
                    intermediate_size=128,
                    num_local_experts=num_experts,
                    num_experts_per_tok=top_k,
                    router_jitter_noise=0.0,
                )

                moe_block = MixtralSparseMoeBlock(config)
                moe_block.eval()
                moe_block = moe_block.to("meta")

                hidden_states = torch.randn(1, 4, config.hidden_size, device="meta")

                # Should export without errors
                try:
                    exported_program = te.export(
                        moe_block,
                        args=(hidden_states,),
                        kwargs={},
                        strict=False
                    )
                    self.assertIsNotNone(exported_program)
                except Exception as e:
                    if "GuardOnDataDependentSymNode" in str(e):
                        self.fail(f"Export failed for config ({num_experts}, {top_k}): {e}")
                    else:
                        raise


if __name__ == "__main__":
    unittest.main()
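
For completeness, here is a sketch (not part of the commit) of producing and running the exported MoE block outside the test suite, on CPU with real tensors rather than the meta device. The configuration values are illustrative, and it assumes a PyTorch version whose torch.export.export accepts strict=False:

import torch

from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

# Small illustrative configuration, mirroring the one used in the tests above.
config = MixtralConfig(
    hidden_size=64,
    intermediate_size=128,
    num_local_experts=4,
    num_experts_per_tok=2,
    router_jitter_noise=0.0,
)

moe_block = MixtralSparseMoeBlock(config).eval()
hidden_states = torch.randn(1, 4, config.hidden_size)

# Export once, then call the resulting ExportedProgram like a regular module.
exported = torch.export.export(moe_block, args=(hidden_states,), strict=False)
output, router_logits = exported.module()(hidden_states)
print(output.shape, router_logits.shape)  # (1, 4, 64) and (4, 4)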
