Conversation

@akacmazz commented on Aug 12, 2025

  • Replace data-dependent .nonzero() operation with static expert loop
  • Resolve GuardOnDataDependentSymNode error during torch.export
  • Maintain identical functionality while enabling export compatibility
  • Fix issue introduced in PR #32429 (Skip non-selected experts for mixtral and qwen2_moe)
  • Add tests for torch.export compatibility

What does this PR do?

This PR fixes the torch.export compatibility issue #38518 for Mixtral MoE models, which was introduced by PR #32429.

Problem

The optimization in PR #32429 introduced a .nonzero() operation that creates data-dependent tensor shapes, causing torch.export to fail with:
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression
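
For intuition, a toy module with the same pattern (a .nonzero() result driving a Python loop) reproduces this class of failure under torch.export. The module and tensor names below are made up for illustration, and the exact error text varies with the PyTorch version:

import torch

class ToyRouter(torch.nn.Module):
    def forward(self, scores):
        # The number of elements returned by .nonzero() depends on the *values*
        # in scores, so the Python loop below has a data-dependent bound.
        hit = (scores > 0).nonzero().flatten()
        total = scores.new_zeros(())
        for i in hit:
            total = total + scores[i]
        return total

# Expected to raise a data-dependent guard error during tracing
# (GuardOnDataDependentSymNode or similar, depending on the PyTorch version):
# torch.export.export(ToyRouter(), (torch.randn(8),))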

Solution

Replace the dynamic expert selection loop:

expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:

With a static loop over all experts:
for expert_idx in range(self.num_experts):
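
For reference, a minimal self-contained sketch of a static-loop dispatch; the function name, toy usage, and tensor shapes below are illustrative (following the usual Mixtral routing layout), not the exact repository code:

import torch
import torch.nn.functional as F

def static_expert_dispatch(hidden_states, expert_mask, routing_weights, experts):
    # hidden_states:   (num_tokens, hidden_dim)
    # expert_mask:     (num_experts, top_k, num_tokens) one-hot routing mask
    # routing_weights: (num_tokens, top_k) normalized gate weights
    final_hidden_states = torch.zeros_like(hidden_states)
    for expert_idx in range(len(experts)):  # static bound, no .nonzero()
        idx, top_x = torch.where(expert_mask[expert_idx])
        # If this expert received no tokens, top_x is empty and index_add_ adds nothing.
        routed = experts[expert_idx](hidden_states[top_x]) * routing_weights[top_x, idx, None]
        final_hidden_states.index_add_(0, top_x, routed)
    return final_hidden_states

# Toy usage with Linear layers standing in for expert MLPs:
num_experts, top_k, num_tokens, hidden_dim = 8, 2, 5, 16
experts = torch.nn.ModuleList(torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts))
hidden_states = torch.randn(num_tokens, hidden_dim)
routing_weights, selected = torch.topk(F.softmax(torch.randn(num_tokens, num_experts), dim=-1), top_k, dim=-1)
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
expert_mask = F.one_hot(selected, num_classes=num_experts).permute(2, 1, 0)
out = static_expert_dispatch(hidden_states, expert_mask, routing_weights, experts)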

Impact

  • ✅ Enables torch.export compatibility for Mixtral models
  • ✅ Maintains identical functionality (empty experts contribute 0 naturally)
  • ✅ Minimal performance impact (same computation, different loop structure)
  • ✅ Consistent with other MoE implementations (Jamba, DBRX)

Testing

  • Verified torch.export works without errors
  • Confirmed functionality preservation with identical outputs
  • Tested with various input configurations

Fixes torch.export compatibility issues reported for Mixtral-8x7B models.
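
As a rough illustration of the kind of check described above, an export round-trip on a tiny randomly initialized Mixtral could look like this; the config sizes, the use of strict=False, and the output comparison are illustrative choices, not the actual tests added in this PR:

import torch
from transformers import MixtralConfig, MixtralForCausalLM

# Tiny config so the check runs quickly; sizes are arbitrary.
config = MixtralConfig(
    vocab_size=1000, hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2,
    num_local_experts=4, num_experts_per_tok=2, use_cache=False,
)
model = MixtralForCausalLM(config).eval()
input_ids = torch.randint(0, config.vocab_size, (1, 8))

with torch.no_grad():
    eager_logits = model(input_ids).logits

# Before the fix, tracing here raised GuardOnDataDependentSymNode.
exported = torch.export.export(model, (input_ids,), strict=False)
exported_logits = exported.module()(input_ids).logits
torch.testing.assert_close(eager_logits, exported_logits)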

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@Cyrilvallez
@ArthurZucker
@gante

- Replace data-dependent .nonzero() operation with static expert loop
- Resolves GuardOnDataDependentSymNode error during torch.export
- Maintains identical functionality while enabling export compatibility
- Fixes issue introduced in PR huggingface#32429
- Add tests for torch.export compatibility
@akacmazz changed the title to Fix torch.export compatibility for Mixtral MoE models on Aug 12, 2025
- Auto-generate modeling_mixtral.py with the same fix
- Apply black formatting
- Fix repository consistency check
@akacmazz closed this Aug 12, 2025
@akacmazz reopened this Aug 12, 2025
@akacmazz force-pushed the fix-mixtral-torch-export-compatibility branch 3 times, most recently from 9b41625 to c3e3c5e on August 12, 2025 at 20:23
@ArthurZucker (Collaborator) left a comment

Thanks for the PR, but the reason we cannot do that is that it is really a lot less efficient! I would recommend instead having two paths, one for training and one for inference; the inference path doesn't need to loop at all and can just repeat the inputs.

@akacmazz (Author)

Thanks for the feedback, I will work on this.

…ility

  - Training path: Keep efficient .nonzero() for performance
  - Inference path: Use static loop for torch.export compatibility
  - Add conditional check to skip empty experts in inference
  - Update tests to validate inference mode export
  - Addresses maintainer feedback on performance concerns
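
The two-path structure described in this commit might look roughly like the following; this is a sketch of the branching idea using the same illustrative dispatch shapes as earlier in the thread, not the merged code:

import torch

def dispatch_experts(hidden_states, expert_mask, routing_weights, experts, training):
    final_hidden_states = torch.zeros_like(hidden_states)
    if training:
        # Dynamic path: visit only experts that actually received tokens.
        # Data-dependent, so fine in eager training but not for torch.export.
        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        expert_indices = expert_hit.flatten().tolist()
    else:
        # Static path: fixed loop bound, no data-dependent control flow to guard on.
        expert_indices = range(len(experts))
    for expert_idx in expert_indices:
        idx, top_x = torch.where(expert_mask[expert_idx])
        routed = experts[expert_idx](hidden_states[top_x]) * routing_weights[top_x, idx, None]
        final_hidden_states.index_add_(0, top_x, routed)
    return final_hidden_states

Since torch.export typically traces an eval-mode model, only the static branch needs to be export-friendly, which matches the split described in the commit bullets above.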
@akacmazz force-pushed the fix-mixtral-torch-export-compatibility branch from 952a181 to 0aa9de7 on August 13, 2025 at 08:51
akacmazz and others added 9 commits August 13, 2025 12:23
- Apply black formatting to fix code style
- Fix import sorting with isort
- Address CI code quality checks
- Fix import organization in modeling_mixtral.py
- Fix import organization in modular_mixtral.py
- Address ruff I001 import sorting warnings
- Remove manually edited modeling_mixtral.py
- Auto-generate from modular_mixtral.py using proper tool
- Ensure consistency between modular and generated files
- Fix check_repository_consistency CI failure
- Remove 'if top_x.shape[0] == 0: continue' check that causes GuardOnDataDependentSymNode error
- Empty expert tensors naturally contribute 0, no explicit check needed
- Update test error message for clarity
- Fixes tests_processors CI failure

Co-authored-by: ArthurZucker <[email protected]>
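
A quick way to see why the removed check is unnecessary (as the commit above notes, empty expert tensors contribute 0 naturally): index_add_ with an empty index tensor is a no-op. A small standalone check:

import torch

out = torch.zeros(4, 3)
empty_idx = torch.empty(0, dtype=torch.long)
empty_src = torch.empty(0, 3)
out.index_add_(0, empty_idx, empty_src)  # adds nothing
assert torch.equal(out, torch.zeros(4, 3))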

[For maintainers] Suggested jobs to run (before merge)

run-slow: mixtral
