Afmoe flash attention #1626
base: main
Changes from all commits
8850ab3
94ff2aa
9886a41
fd031b0
63fc89b
3653483
d37a047
```diff
@@ -1,3 +1,4 @@
+import functools
 from dataclasses import dataclass
 from typing import Optional, Union
```
```diff
@@ -36,6 +37,16 @@
     convert_tt_to_hf_moe,
 )
 
+try:
+    from flash_attn import flash_attn_varlen_func
+except ImportError:
+    flash_attn_varlen_func = None
+
+try:
+    from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+except ImportError:
+    flash_attn_3_varlen_func = None
+
 
 @dataclass
 class AfmoeAttentionConfig:
```
```diff
@@ -135,6 +146,8 @@ def forward(
         hidden_states: torch.Tensor,
         position_embeddings: tuple[torch.Tensor, torch.Tensor],
         attention_mask: torch.Tensor | None = None,
+        cu_seqlens: torch.LongTensor | None = None,
+        max_seqlen: int | None = None,
     ) -> tuple[torch.Tensor, None]:
         query_states, key_states, value_states, gate_states, input_shape = self._project_states(
             hidden_states, position_embeddings
```
```diff
@@ -154,8 +167,75 @@ def forward(
         return self._finalize_output(attn_output, gate_states, input_shape)
 
 
+class AfmoeFlashAttention(AfmoeAttentionBase):
+    """AFMoE attention using Flash Attention."""
+
+    def __init__(self, config: AfmoeAttentionConfig, flash_attn_version: int = 2):
+        super().__init__(config)
+        self.flash_attn_version = flash_attn_version
+        self.func = flash_attn_3_varlen_func if flash_attn_version == 3 else flash_attn_varlen_func
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: torch.Tensor | None = None,
+        cu_seqlens: torch.LongTensor | None = None,
+        max_seqlen: int | None = None,
+    ) -> tuple[torch.Tensor, None]:
+        if self.func is None:
+            raise ImportError("flash-attn is not installed but flash attention was requested.")
+
+        query_states, key_states, value_states, gate_states, input_shape = self._project_states(
+            hidden_states, position_embeddings
+        )
+
+        query_states = query_states.transpose(1, 2).contiguous()
+        key_states = key_states.transpose(1, 2).contiguous()
+        value_states = value_states.transpose(1, 2).contiguous()
+
+        q = query_states.reshape(-1, self.num_heads, self.head_dim)
+        k = key_states.reshape(-1, self.num_heads, self.head_dim)
+        v = value_states.reshape(-1, self.num_heads, self.head_dim)
+
+        if cu_seqlens is None:
+            batch_size, seq_len = input_shape
+            lengths = torch.full((batch_size,), seq_len, dtype=torch.int32, device=hidden_states.device)
+            cu_seqlens = torch.zeros((batch_size + 1,), dtype=torch.int32, device=hidden_states.device)
+            cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
+            max_seqlen = seq_len
+        elif max_seqlen is None:
+            max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())
+
+        attn_kwargs = {"causal": True}
+        if self.is_local_attention and self.sliding_window is not None:
+            attn_kwargs["window_size"] = (self.sliding_window, 0)
+
+        out = self.func(
+            q,
+            k,
+            v,
+            cu_seqlens,
+            cu_seqlens,
+            max_seqlen,
+            max_seqlen,
+            **attn_kwargs,
+        )
+
+        if isinstance(out, tuple):
+            out = out[0]
+
+        attn_output = out.view(*input_shape, self.num_heads, self.head_dim)
+        attn_output = attn_output.view(*input_shape, -1)
+        attn_output = attn_output * torch.sigmoid(gate_states)
+        attn_output = self.o_proj(attn_output)
+        return attn_output, None
+
+
 AFMOE_ATTN_IMPL2CLASS = {
     "sdpa": AfmoeSDPAAttention,
+    "flash_attention_2": functools.partial(AfmoeFlashAttention, flash_attn_version=2),
+    "flash_attention_3": functools.partial(AfmoeFlashAttention, flash_attn_version=3),
 }
```
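For context on the varlen interface the new class targets: `flash_attn_varlen_func` consumes tokens packed along a single dimension, shaped `(total_tokens, num_heads, head_dim)`, with `cu_seqlens` giving cumulative sequence boundaries. A minimal illustration of that layout (illustrative values, not part of the PR):

```python
import torch

# Two unpadded sequences of lengths 4 and 3, packed along the token dimension.
lengths = torch.tensor([4, 3], dtype=torch.int32)
cu_seqlens = torch.zeros(lengths.numel() + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(lengths, dim=0)   # tensor([0, 4, 7], dtype=torch.int32)
max_seqlen = int(lengths.max())                 # 4
# q, k, v for the varlen call would then be shaped (7, num_heads, head_dim).
```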
```diff
@@ -194,9 +274,7 @@ def _get_afmoe_attention(config: AfmoeConfig, layer_idx: int) -> nn.Module:
     if attn_impl not in AFMOE_ATTN_IMPL2CLASS:
         supported = list(AFMOE_ATTN_IMPL2CLASS.keys())
         raise ValueError(
-            f"AFMoE attention does not support '{config._attn_implementation}'. "
-            f"Supported implementations: {supported}. "
-            f"Note: flash_attention is not supported for AFMoE due to sliding window + RoPE constraints."
+            f"AFMoE attention does not support '{config._attn_implementation}'. Supported implementations: {supported}."
         )
 
     return AFMOE_ATTN_IMPL2CLASS[attn_impl](attn_config)
```
```diff
@@ -245,6 +323,8 @@ def forward(
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        cu_seqlens: Optional[torch.LongTensor] = None,
+        max_seqlen: Optional[int] = None,
     ) -> torch.FloatTensor:
         residual = hidden_states
```
```diff
@@ -253,6 +333,8 @@ def forward(
             hidden_states=hidden_states,
             position_embeddings=position_embeddings,
             attention_mask=attention_mask,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
         )
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = residual + hidden_states
```
```diff
@@ -280,7 +362,7 @@ class AfmoePreTrainedModel(PreTrainedModelPrimeRL):
         "norm",
     ]
     _supports_sdpa = True
-    _supports_flash_attn = False
+    _supports_flash_attn = True
     _supports_flex_attn = False
     _supports_attention_backend = True
     _can_compile_fullgraph = False
```
```diff
@@ -359,19 +441,41 @@ def forward(
         position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
         cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
 
+        if self.config._attn_implementation in ("flash_attention_2", "flash_attention_3"):
+            flat_position_ids = position_ids.view(-1)
+            seqlens = torch.cat(
+                [
+                    flat_position_ids[0:1],
+                    flat_position_ids[:-1][(flat_position_ids == 0)[1:]] + 1,
+                    flat_position_ids[-1:] + 1,
+                ]
+            )
+            max_seqlen = seqlens.max().item()
+            cu_seqlens = seqlens.cumsum(dim=0, dtype=torch.int32)
+            torch._dynamo.mark_dynamic(cu_seqlens, 0)
```
Review comment: cu_seqlens computation ignores actual batch size (High Severity, 1 additional location). The `seqlens`/`cu_seqlens` above are derived from `position_ids`, which defaults to `torch.arange(inputs_embeds.shape[1]).unsqueeze(0)` with shape `(1, seq_len)` regardless of batch size, so they describe only `seq_len` tokens while `q`, `k`, and `v` contain `batch_size * seq_len` rows.
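A minimal sketch of one way to build batch-aware `cu_seqlens` (not the PR's code; it assumes a 2D `attention_mask` of shape `(batch_size, seq_len)` with 1 marking real tokens, and that `q`/`k`/`v` are unpadded so their token count matches `cu_seqlens[-1]`):

```python
import torch

def build_cu_seqlens(attention_mask: torch.Tensor) -> tuple[torch.Tensor, int]:
    """Cumulative sequence lengths for varlen flash attention, one boundary per batch row."""
    lengths = attention_mask.sum(dim=-1, dtype=torch.int32)  # (batch_size,)
    cu_seqlens = torch.zeros(lengths.numel() + 1, dtype=torch.int32, device=attention_mask.device)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
    return cu_seqlens, int(lengths.max().item())
```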
```diff
+        else:
+            max_seqlen = None
+            cu_seqlens = None
+
         if not isinstance(causal_mask_mapping := attention_mask, dict):
-            mask_kwargs = {
-                "config": self.config,
-                "input_embeds": inputs_embeds,
-                "attention_mask": attention_mask,
-                "cache_position": cache_position,
-                "past_key_values": None,
-                "position_ids": position_ids,
-            }
-            causal_mask_mapping = {
-                "full_attention": create_causal_mask(**mask_kwargs),
-                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
-            }
+            if self.config._attn_implementation in ("flash_attention_2", "flash_attention_3"):
+                causal_mask_mapping = {
+                    "full_attention": None,
+                    "sliding_attention": None,
+                }
+            else:
+                mask_kwargs = {
+                    "config": self.config,
+                    "input_embeds": inputs_embeds,
+                    "attention_mask": attention_mask,
+                    "cache_position": cache_position,
+                    "past_key_values": None,
+                    "position_ids": position_ids,
+                }
+                causal_mask_mapping = {
+                    "full_attention": create_causal_mask(**mask_kwargs),
+                    "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+                }
 
         hidden_states = inputs_embeds
```
```diff
@@ -387,6 +491,8 @@ def forward(
                 hidden_states,
                 attention_mask=mask,
                 position_embeddings=position_embeddings,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
             )
 
         hidden_states = self.norm(hidden_states)
```
Review comment: Flash Attention missing dropout_p parameter (Medium Severity). The `AfmoeFlashAttention` class does not pass `dropout_p` to the flash attention function, while `AfmoeSDPAAttention` properly handles `attention_dropout` with `dropout_p = self.attention_dropout if self.training else 0.0`. If the model config has non-zero `attention_dropout`, SDPA will apply dropout during training but Flash Attention will not, causing inconsistent training behavior between the two attention implementations.
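A hedged sketch of how the kwargs construction in `AfmoeFlashAttention.forward` could mirror the SDPA dropout behavior (illustrative helper, not the PR's code; it assumes the flash-attn 2 `flash_attn_varlen_func`, which accepts `dropout_p`, and that the FA3 interface does not expose dropout):

```python
def build_flash_attn_kwargs(
    attention_dropout: float,
    training: bool,
    flash_attn_version: int,
    sliding_window: int | None,
    is_local_attention: bool,
) -> dict:
    kwargs: dict = {"causal": True}
    if flash_attn_version == 2:
        # Match the SDPA path: apply dropout only while training.
        kwargs["dropout_p"] = attention_dropout if training else 0.0
    if is_local_attention and sliding_window is not None:
        kwargs["window_size"] = (sliding_window, 0)
    return kwargs
```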