Skip to content

Commit 0aa9de7

Implement separate training/inference paths for torch.export compatibility
- Training path: keep the efficient `.nonzero()` for performance
- Inference path: use a static loop for torch.export compatibility
- Add a conditional check to skip empty experts in inference
- Update tests to validate inference-mode export
- Addresses maintainer feedback on performance concerns
1 parent c3e3c5e commit 0aa9de7

3 files changed: +89 -37 lines changed

src/transformers/models/mixtral/modeling_mixtral.py

Lines changed: 43 additions & 17 deletions
@@ -134,23 +134,49 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             selected_experts, num_classes=self.num_experts
         ).permute(2, 1, 0)

-        # Loop over all available experts in the model and perform the computation on each expert
-        for expert_idx in range(self.num_experts):
-            expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx])
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
-            current_hidden_states = (
-                expert_layer(current_state) * routing_weights[top_x, idx, None]
-            )
-
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(
-                0, top_x, current_hidden_states.to(hidden_states.dtype)
-            )
+        # Separate paths for training (with .nonzero()) and inference (without .nonzero())
+        if self.training:
+            # Training path: use .nonzero() for efficiency (skip non-selected experts)
+            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+            for expert_idx in expert_hit:
+                expert_layer = self.experts[expert_idx]
+                idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+                # Index the correct hidden states and compute the expert hidden state for
+                # the current expert. We need to make sure to multiply the output hidden
+                # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+                current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+                current_hidden_states = (
+                    expert_layer(current_state) * routing_weights[top_x, idx, None]
+                )
+
+                # However `index_add_` only support torch tensors for indexing so we'll use
+                # the `top_x` tensor here.
+                final_hidden_states.index_add_(
+                    0, top_x, current_hidden_states.to(hidden_states.dtype)
+                )
+        else:
+            # Inference path: loop over all experts for torch.export compatibility
+            for expert_idx in range(self.num_experts):
+                expert_layer = self.experts[expert_idx]
+                idx, top_x = torch.where(expert_mask[expert_idx])
+
+                # Skip if no tokens are assigned to this expert
+                if top_x.shape[0] == 0:
+                    continue
+
+                # Index the correct hidden states and compute the expert hidden state for
+                # the current expert. We need to make sure to multiply the output hidden
+                # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+                current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+                current_hidden_states = (
+                    expert_layer(current_state) * routing_weights[top_x, idx, None]
+                )
+
+                # However `index_add_` only support torch tensors for indexing so we'll use
+                # the `top_x` tensor here.
+                final_hidden_states.index_add_(
+                    0, top_x, current_hidden_states.to(hidden_states.dtype)
+                )
         final_hidden_states = final_hidden_states.reshape(
             batch_size, sequence_length, hidden_dim
         )
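
To make the routing bookkeeping in the training path easier to follow, here is a small, self-contained sketch (not part of the commit; the token/expert counts and routing values are illustrative assumptions). It mirrors the `one_hot(...).permute(2, 1, 0)` context above and applies equally to the identical change in modular_mixtral.py below.

import torch

# Toy sizes (assumptions for illustration only): 3 tokens, top-2 routing, 4 experts.
num_experts = 4
selected_experts = torch.tensor([[0, 2], [0, 1], [2, 1]])  # (num_tokens, top_k)

# expert_mask has shape (num_experts, top_k, num_tokens) after the permute,
# matching the context lines in the diff above.
expert_mask = torch.nn.functional.one_hot(
    selected_experts, num_classes=num_experts
).permute(2, 1, 0)

# Training path: keep only experts that received at least one token.
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
print(expert_hit.squeeze(-1).tolist())  # [0, 1, 2] -- expert 3 got no tokens and is skipped

# For a hit expert, torch.where recovers (routing slot, token index) pairs,
# which index routing_weights and hidden_states respectively.
idx, top_x = torch.where(expert_mask[0])
print(idx.tolist(), top_x.tolist())  # [0, 0] [0, 1] -- expert 0 serves tokens 0 and 1 via slot 0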

src/transformers/models/mixtral/modular_mixtral.py

Lines changed: 43 additions & 17 deletions
@@ -220,23 +220,49 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             selected_experts, num_classes=self.num_experts
         ).permute(2, 1, 0)

-        # Loop over all available experts in the model and perform the computation on each expert
-        for expert_idx in range(self.num_experts):
-            expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx])
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
-            current_hidden_states = (
-                expert_layer(current_state) * routing_weights[top_x, idx, None]
-            )
-
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(
-                0, top_x, current_hidden_states.to(hidden_states.dtype)
-            )
+        # Separate paths for training (with .nonzero()) and inference (without .nonzero())
+        if self.training:
+            # Training path: use .nonzero() for efficiency (skip non-selected experts)
+            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+            for expert_idx in expert_hit:
+                expert_layer = self.experts[expert_idx]
+                idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+                # Index the correct hidden states and compute the expert hidden state for
+                # the current expert. We need to make sure to multiply the output hidden
+                # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+                current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+                current_hidden_states = (
+                    expert_layer(current_state) * routing_weights[top_x, idx, None]
+                )
+
+                # However `index_add_` only support torch tensors for indexing so we'll use
+                # the `top_x` tensor here.
+                final_hidden_states.index_add_(
+                    0, top_x, current_hidden_states.to(hidden_states.dtype)
+                )
+        else:
+            # Inference path: loop over all experts for torch.export compatibility
+            for expert_idx in range(self.num_experts):
+                expert_layer = self.experts[expert_idx]
+                idx, top_x = torch.where(expert_mask[expert_idx])
+
+                # Skip if no tokens are assigned to this expert
+                if top_x.shape[0] == 0:
+                    continue
+
+                # Index the correct hidden states and compute the expert hidden state for
+                # the current expert. We need to make sure to multiply the output hidden
+                # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+                current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+                current_hidden_states = (
+                    expert_layer(current_state) * routing_weights[top_x, idx, None]
+                )
+
+                # However `index_add_` only support torch tensors for indexing so we'll use
+                # the `top_x` tensor here.
+                final_hidden_states.index_add_(
+                    0, top_x, current_hidden_states.to(hidden_states.dtype)
+                )
         final_hidden_states = final_hidden_states.reshape(
             batch_size, sequence_length, hidden_dim
         )

tests/models/mixtral/test_mixtral_torch_export.py

Lines changed: 3 additions & 3 deletions
@@ -39,10 +39,10 @@ def setUp(self):
         )

     def test_moe_block_torch_export(self):
-        """Test that MixtralSparseMoeBlock can be exported with torch.export."""
+        """Test that MixtralSparseMoeBlock can be exported with torch.export in inference mode."""
         # Create MoE block
         moe_block = MixtralSparseMoeBlock(self.config)
-        moe_block.eval()
+        moe_block.eval()  # Set to eval mode for inference path
 
         # Move to meta device for export testing
         moe_block = moe_block.to("meta")
@@ -69,7 +69,7 @@ def test_moe_block_torch_export(self):
             ):
                 self.fail(
                     f"torch.export failed with data-dependent operation error: {error_msg}\n"
-                    "This suggests the .nonzero() fix is not working properly."
+                    "This suggests the inference path fix is not working properly."
                 )
            else:
                # Re-raise other unexpected errors
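
Roughly, the updated test exercises the eval-mode (inference) path as in the sketch below. This is a hedged illustration, not the test itself: the config values and input shape are assumptions, and the real test additionally moves the block to the meta device before exporting.

import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

# Small illustrative config (values are assumptions, not the test's actual setup).
config = MixtralConfig(
    hidden_size=64,
    intermediate_size=128,
    num_local_experts=4,
    num_experts_per_tok=2,
)

moe_block = MixtralSparseMoeBlock(config)
moe_block.eval()  # eval mode selects the static-loop inference path added in this commit

hidden_states = torch.randn(1, 8, config.hidden_size)

# With the inference path active, torch.export is expected to trace the block
# without hitting the data-dependent .nonzero() call used during training.
exported = torch.export.export(moe_block, (hidden_states,))
final_hidden_states, router_logits = exported.module()(hidden_states)
print(final_hidden_states.shape)  # torch.Size([1, 8, 64])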
