
Conversation

@samsja samsja commented Jan 20, 2026

Implement Flash Attention (FA2/FA3) for the AFMoE model with sliding window support.

Flash Attention was previously unsupported for AFMoE because of perceived conflicts with its sliding-window and RoPE constraints; enabling it opens the door to performance improvements.


Note

Introduces Flash Attention for AFMoE and wires it into the model, including sliding window support via the varlen API.

  • New AfmoeFlashAttention implementing FA2/FA3 using flash_attn_varlen_func/flash_attn_3_varlen_func; added to AFMOE_ATTN_IMPL2CLASS
  • Computes seqlens, cu_seqlens, and max_seqlen from position_ids when using flash attention (see the sketch after this list); disables explicit causal masks for FA paths
  • Plumbs cu_seqlens/max_seqlen through AfmoeModel → AfmoeDecoderLayer → attention forward
  • Supports local/sliding attention by setting window_size in FA kwargs; preserves RoPE application
  • Marks _supports_flash_attn = True and adjusts unsupported-impl error message
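
A minimal sketch of the varlen path described above, assuming packed sequences whose position_ids restart at 0 at each sequence boundary. flash_attn_varlen_func is the actual flash-attn API; the helper name, the seqlen-recovery logic, and the window_size convention are illustrative and may differ from the PR's code.

# Minimal sketch (not the PR's exact code): derive cu_seqlens/max_seqlen from
# packed position_ids and dispatch to the flash-attn varlen kernel.
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func

def varlen_flash_attention(q, k, v, position_ids, sliding_window=None):
    # q, k, v: (total_tokens, num_heads, head_dim), all sequences packed together
    # position_ids: (total_tokens,) restarting at 0 at each sequence boundary
    starts = (position_ids == 0).nonzero(as_tuple=True)[0]
    seqlens = torch.diff(starts, append=position_ids.new_tensor([position_ids.numel()]))
    max_seqlen = int(seqlens.max())
    cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))  # leading 0

    # (-1, -1) = unrestricted; a finite left window gives local/sliding attention.
    # The exact off-by-one convention should match the model's eager/SDPA mask.
    window_size = (-1, -1) if sliding_window is None else (sliding_window - 1, 0)

    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
        causal=True,          # replaces the explicit causal mask on the FA path
        window_size=window_size,
    )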

Written by Cursor Bugbot for commit d37a047. This will update automatically on new commits.


cursor bot commented Jan 20, 2026

Cursor Agent can help with this pull request. Just @cursor in comments and I'll start working on changes in this branch.

cursoragent and others added 3 commits January 20, 2026 23:14
Co-authored-by: sami <[email protected]>

Sort AFMoE imports for ruff

Co-authored-by: sami <[email protected]>

Align AFMoE flash attn with sliding window

Co-authored-by: sami <[email protected]>
@samsja samsja force-pushed the cursor/afmoe-flash-attention-96f0 branch from 9886a41 to fd031b0 on January 21, 2026 00:06
@samsja samsja marked this pull request as ready for review January 21, 2026 17:29

@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 2 potential issues.

Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable Autofix in the Cursor dashboard.

max_seqlen,
max_seqlen,
**attn_kwargs,
)

Flash Attention missing dropout_p parameter

Medium Severity

The AfmoeFlashAttention class does not pass dropout_p to the flash attention function, while AfmoeSDPAAttention properly handles attention_dropout with dropout_p = self.attention_dropout if self.training else 0.0. If the model config has non-zero attention_dropout, SDPA will apply dropout during training but Flash Attention will not, causing inconsistent training behavior between the two attention implementations.
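
A sketch of the suggested fix, mirroring the SDPA path's dropout handling; the attribute and variable names follow the quoted snippet and the issue text, the rest is illustrative:

dropout_p = self.attention_dropout if self.training else 0.0

attn_output = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens, cu_seqlens,
    max_seqlen, max_seqlen,
    dropout_p=dropout_p,  # flash-attn applies no dropout unless this is passed
    **attn_kwargs,
)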


)
max_seqlen = seqlens.max().item()
cu_seqlens = seqlens.cumsum(dim=0, dtype=torch.int32)
torch._dynamo.mark_dynamic(cu_seqlens, 0)

cu_seqlens computation ignores actual batch size

High Severity

When position_ids is None and batch_size > 1, the cu_seqlens computation produces incorrect results. The default position_ids has shape [1, seq_len], so flat_position_ids.view(-1) yields only seq_len elements. The resulting cu_seqlens (e.g., [0, seq_len]) indicates a single sequence, but hidden_states contains batch_size * seq_len tokens. Flash attention then processes only the first seq_len tokens, silently ignoring the rest of the batch. The fallback in AfmoeFlashAttention.forward() would correctly handle this, but it's bypassed because cu_seqlens is already provided.

Additional Locations (1)
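
A sketch of one possible batch-aware fallback, assuming the default position_ids has shape [1, seq_len]: when position_ids does not cover the batch, build one uniform-length entry per batch row instead of treating the whole input as a single sequence. Names here are illustrative, not the PR's.

import torch
import torch.nn.functional as F

def build_cu_seqlens(position_ids, batch_size, seq_len, device):
    if position_ids is None or position_ids.shape[0] != batch_size:
        # Uniform rows: cu_seqlens = [0, S, 2S, ..., B*S], one sequence per batch row.
        cu_seqlens = torch.arange(
            0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32, device=device
        )
        max_seqlen = seq_len
    else:
        # Packed case: sequence boundaries are wherever position_ids restarts at 0.
        flat = position_ids.reshape(-1)
        starts = (flat == 0).nonzero(as_tuple=True)[0]
        seqlens = torch.diff(starts, append=flat.new_tensor([flat.numel()]))
        cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
        max_seqlen = int(seqlens.max())
    return cu_seqlens, max_seqlen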

