Afmoe flash attention #1626
base: main
Conversation
Sort AFMoE imports for ruff
Align AFMoE flash attn with sliding window
Co-authored-by: sami <[email protected]>
Force-pushed 9886a41 to fd031b0
Cursor Bugbot has reviewed your changes and found 2 potential issues.
    max_seqlen,
    max_seqlen,
    **attn_kwargs,
)
Flash Attention missing dropout_p parameter
Medium Severity
The AfmoeFlashAttention class does not pass dropout_p to the flash attention function, while AfmoeSDPAAttention properly handles attention_dropout with dropout_p = self.attention_dropout if self.training else 0.0. If the model config has non-zero attention_dropout, SDPA will apply dropout during training but Flash Attention will not, causing inconsistent training behavior between the two attention implementations.
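A minimal sketch of the suggested fix for the FA2 path, assuming the attribute names (self.attention_dropout, self.training) mentioned in the review rather than the PR's actual code; note that the FA3 varlen kernel does not expose dropout, so that path would still need a guard or a warning:

```python
# Hypothetical sketch, not the PR's actual code: pass dropout_p on the FA2 path
# so it matches AfmoeSDPAAttention's behavior.
from flash_attn import flash_attn_varlen_func

def _fa2_attention(self, q, k, v, cu_seqlens, max_seqlen, **attn_kwargs):
    # Apply dropout only while training, mirroring the SDPA implementation.
    dropout_p = self.attention_dropout if self.training else 0.0
    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        dropout_p=dropout_p,  # previously omitted, silently defaulting to 0.0
        causal=True,
        **attn_kwargs,
    )
```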
)
max_seqlen = seqlens.max().item()
cu_seqlens = seqlens.cumsum(dim=0, dtype=torch.int32)
torch._dynamo.mark_dynamic(cu_seqlens, 0)
cu_seqlens computation ignores actual batch size
High Severity
When position_ids is None and batch_size > 1, the cu_seqlens computation produces incorrect results. The default position_ids has shape [1, seq_len], so flat_position_ids.view(-1) yields only seq_len elements. The resulting cu_seqlens (e.g., [0, seq_len]) indicates a single sequence, but hidden_states contains batch_size * seq_len tokens. Flash attention then processes only the first seq_len tokens, silently ignoring the rest of the batch. The fallback in AfmoeFlashAttention.forward() would correctly handle this, but it's bypassed because cu_seqlens is already provided.
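One way to make the varlen metadata honor the real batch size; build_cu_seqlens is a hypothetical helper, sketched under the assumption that a default position_ids of shape [1, seq_len] means every row of the batch is a full, unpadded sequence:

```python
import torch
import torch.nn.functional as F

def build_cu_seqlens(position_ids, batch_size, seq_len, device):
    # Hypothetical helper (not from the PR): derive cu_seqlens/max_seqlen for
    # flash_attn_varlen_func without dropping rows when position_ids is the default.
    if position_ids is None or position_ids.shape[0] != batch_size:
        # Default positions describe one row only; every sequence spans seq_len.
        seqlens = torch.full((batch_size,), seq_len, dtype=torch.int32, device=device)
    else:
        flat = position_ids.reshape(-1)
        # A new sequence starts wherever the position counter resets to zero.
        starts = (flat == 0).nonzero().view(-1)
        ends = torch.cat([starts[1:], torch.tensor([flat.numel()], device=device)])
        seqlens = (ends - starts).to(torch.int32)
    cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))  # leading zero
    max_seqlen = int(seqlens.max())
    return cu_seqlens, max_seqlen
```

With batch_size=2, seq_len=8 and the default position_ids this yields cu_seqlens=[0, 8, 16] and max_seqlen=8, so flash attention sees both rows instead of only the first.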
Implement Flash Attention (FA2/FA3) for the AFMoE model with sliding window support.
This enables Flash Attention for AFMoE, which was previously unsupported due to perceived conflicts with the sliding-window and RoPE constraints, and opens up potential performance improvements.
Note
Introduces Flash Attention for AFMoE and wires it into the model, including sliding window support via the varlen API.
- New AfmoeFlashAttention implementing FA2/FA3 using flash_attn_varlen_func / flash_attn_3_varlen_func; added to AFMOE_ATTN_IMPL2CLASS
- seqlens, cu_seqlens, and max_seqlen computed from position_ids when using flash attention; disables explicit causal masks for FA paths
- cu_seqlens / max_seqlen plumbed through AfmoeModel → AfmoeDecoderLayer → attention forward
- window_size passed in FA kwargs; preserves RoPE application (see the sketch after this list)
- _supports_flash_attn = True set and the unsupported-impl error message adjusted

Written by Cursor Bugbot for commit d37a047. This will update automatically on new commits.
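For illustration, a sketch (not the PR's exact code) of how a sliding window maps onto the FA2 varlen API: with causal=True, window_size=(left, 0) limits each query to the previous `left` keys, so no explicit causal or local mask is needed on this path. The exact off-by-one depends on whether the config's sliding_window counts the current token.

```python
from flash_attn import flash_attn_varlen_func

def sliding_window_varlen_attn(q, k, v, cu_seqlens, max_seqlen, sliding_window=None):
    # (-1, -1) disables the window, i.e. plain causal attention for global layers.
    window = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        causal=True,
        window_size=window,  # look back at most `sliding_window - 1` positions
    )
```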