Commit fd031b0

cursoragent and sami committed
Add flash attention support to AFMoE
Co-authored-by: sami <sami@primeintellect.ai>

Sort AFMoE imports for ruff

Co-authored-by: sami <sami@primeintellect.ai>

Align AFMoE flash attn with sliding window

Co-authored-by: sami <sami@primeintellect.ai>
1 parent c3de771 commit fd031b0

File tree

1 file changed: +121 −15 lines


src/prime_rl/trainer/models/afmoe/modeling_afmoe.py

Lines changed: 121 additions & 15 deletions
@@ -1,3 +1,4 @@
+import functools
 from dataclasses import dataclass
 from typing import Optional, Union
 
@@ -36,6 +37,16 @@
     convert_tt_to_hf_moe,
 )
 
+try:
+    from flash_attn import flash_attn_varlen_func
+except ImportError:
+    flash_attn_varlen_func = None
+
+try:
+    from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+except ImportError:
+    flash_attn_3_varlen_func = None
+
 
 @dataclass
 class AfmoeAttentionConfig:
@@ -135,6 +146,8 @@ def forward(
         hidden_states: torch.Tensor,
         position_embeddings: tuple[torch.Tensor, torch.Tensor],
         attention_mask: torch.Tensor | None = None,
+        cu_seqlens: torch.LongTensor | None = None,
+        max_seqlen: int | None = None,
     ) -> tuple[torch.Tensor, None]:
         query_states, key_states, value_states, gate_states, input_shape = self._project_states(
             hidden_states, position_embeddings
@@ -154,8 +167,75 @@ def forward(
         return self._finalize_output(attn_output, gate_states, input_shape)
 
 
+class AfmoeFlashAttention(AfmoeAttentionBase):
+    """AFMoE attention using Flash Attention."""
+
+    def __init__(self, config: AfmoeAttentionConfig, flash_attn_version: int = 2):
+        super().__init__(config)
+        self.flash_attn_version = flash_attn_version
+        self.func = flash_attn_3_varlen_func if flash_attn_version == 3 else flash_attn_varlen_func
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: torch.Tensor | None = None,
+        cu_seqlens: torch.LongTensor | None = None,
+        max_seqlen: int | None = None,
+    ) -> tuple[torch.Tensor, None]:
+        if self.func is None:
+            raise ImportError("flash-attn is not installed but flash attention was requested.")
+
+        query_states, key_states, value_states, gate_states, input_shape = self._project_states(
+            hidden_states, position_embeddings
+        )
+
+        query_states = query_states.transpose(1, 2).contiguous()
+        key_states = key_states.transpose(1, 2).contiguous()
+        value_states = value_states.transpose(1, 2).contiguous()
+
+        q = query_states.reshape(-1, self.num_heads, self.head_dim)
+        k = key_states.reshape(-1, self.num_heads, self.head_dim)
+        v = value_states.reshape(-1, self.num_heads, self.head_dim)
+
+        if cu_seqlens is None:
+            batch_size, seq_len = input_shape
+            lengths = torch.full((batch_size,), seq_len, dtype=torch.int32, device=hidden_states.device)
+            cu_seqlens = torch.zeros((batch_size + 1,), dtype=torch.int32, device=hidden_states.device)
+            cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
+            max_seqlen = seq_len
+        elif max_seqlen is None:
+            max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())
+
+        dropout_p = self.attention_dropout if self.training else 0.0
+        attn_kwargs = {"causal": True, "dropout_p": dropout_p}
+        if self.is_local_attention and self.sliding_window is not None:
+            attn_kwargs["window_size"] = (self.sliding_window, 0)
+
+        out = self.func(
+            q,
+            k,
+            v,
+            cu_seqlens,
+            cu_seqlens,
+            max_seqlen,
+            max_seqlen,
+            **attn_kwargs,
+        )
+        if isinstance(out, tuple):
+            out = out[0]
+
+        attn_output = out.view(*input_shape, self.num_heads, self.head_dim)
+        attn_output = attn_output.view(*input_shape, -1)
+        attn_output = attn_output * torch.sigmoid(gate_states)
+        attn_output = self.o_proj(attn_output)
+        return attn_output, None
+
+
 AFMOE_ATTN_IMPL2CLASS = {
     "sdpa": AfmoeSDPAAttention,
+    "flash_attention_2": functools.partial(AfmoeFlashAttention, flash_attn_version=2),
+    "flash_attention_3": functools.partial(AfmoeFlashAttention, flash_attn_version=3),
 }
 
 
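For context (not part of the diff): the varlen kernels used above consume a flattened (total_tokens, num_heads, head_dim) layout plus cumulative sequence boundaries. A minimal CPU-only sketch of that convention, with hypothetical lengths; the actual flash_attn_varlen_func call needs a CUDA build of flash-attn and is only indicated in a comment:

import torch

# Hypothetical packed batch of three sequences with lengths 3, 5, and 2.
lengths = torch.tensor([3, 5, 2], dtype=torch.int32)

# cu_seqlens holds the start offset of every sequence in the flattened
# token dimension, plus the total token count as the final entry.
cu_seqlens = torch.zeros(lengths.numel() + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
max_seqlen = int(lengths.max())

assert cu_seqlens.tolist() == [0, 3, 8, 10]
assert max_seqlen == 5

# With q, k, v shaped (10, num_heads, head_dim), the call would look roughly like:
# flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal=True)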
@@ -195,8 +275,7 @@ def _get_afmoe_attention(config: AfmoeConfig, layer_idx: int) -> nn.Module:
         supported = list(AFMOE_ATTN_IMPL2CLASS.keys())
         raise ValueError(
             f"AFMoE attention does not support '{config._attn_implementation}'. "
-            f"Supported implementations: {supported}. "
-            f"Note: flash_attention is not supported for AFMoE due to sliding window + RoPE constraints."
+            f"Supported implementations: {supported}."
         )
 
     return AFMOE_ATTN_IMPL2CLASS[attn_impl](attn_config)
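As an aside, the functools.partial entries keep this dispatch call site uniform: every value in AFMOE_ATTN_IMPL2CLASS can be invoked as cls(attn_config) while still pinning the flash-attn version. A self-contained toy sketch of the same pattern (names here are illustrative, not from the repo):

import functools


class ToyFlashAttention:
    """Stand-in for AfmoeFlashAttention; stores the pinned kernel version."""

    def __init__(self, config, flash_attn_version: int = 2):
        self.config = config
        self.flash_attn_version = flash_attn_version


# Toy dispatch table mirroring AFMOE_ATTN_IMPL2CLASS.
TOY_IMPL2CLASS = {
    "flash_attention_2": functools.partial(ToyFlashAttention, flash_attn_version=2),
    "flash_attention_3": functools.partial(ToyFlashAttention, flash_attn_version=3),
}

attn = TOY_IMPL2CLASS["flash_attention_3"](config={"head_dim": 64})
assert attn.flash_attn_version == 3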
@@ -245,6 +324,8 @@ def forward(
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        cu_seqlens: Optional[torch.LongTensor] = None,
+        max_seqlen: Optional[int] = None,
     ) -> torch.FloatTensor:
         residual = hidden_states
 
@@ -253,6 +334,8 @@ def forward(
             hidden_states=hidden_states,
             position_embeddings=position_embeddings,
             attention_mask=attention_mask,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
         )
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = residual + hidden_states
@@ -280,7 +363,7 @@ class AfmoePreTrainedModel(PreTrainedModelPrimeRL):
         "norm",
     ]
     _supports_sdpa = True
-    _supports_flash_attn = False
+    _supports_flash_attn = True
    _supports_flex_attn = False
     _supports_attention_backend = True
     _can_compile_fullgraph = False
@@ -359,19 +442,40 @@ def forward(
             position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
         cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
 
+        if self.config._attn_implementation in ("flash_attention_2", "flash_attention_3"):
+            batch_size, seq_len = inputs_embeds.shape[:2]
+            cu_seqlens = torch.arange(
+                0,
+                (batch_size + 1) * seq_len,
+                step=seq_len,
+                dtype=torch.int32,
+                device=inputs_embeds.device,
+            )
+            max_seqlen = seq_len
+            torch._dynamo.mark_dynamic(cu_seqlens, 0)
+        else:
+            max_seqlen = None
+            cu_seqlens = None
+
         if not isinstance(causal_mask_mapping := attention_mask, dict):
-            mask_kwargs = {
-                "config": self.config,
-                "input_embeds": inputs_embeds,
-                "attention_mask": attention_mask,
-                "cache_position": cache_position,
-                "past_key_values": None,
-                "position_ids": position_ids,
-            }
-            causal_mask_mapping = {
-                "full_attention": create_causal_mask(**mask_kwargs),
-                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
-            }
+            if self.config._attn_implementation in ("flash_attention_2", "flash_attention_3"):
+                causal_mask_mapping = {
+                    "full_attention": None,
+                    "sliding_attention": None,
+                }
+            else:
+                mask_kwargs = {
+                    "config": self.config,
+                    "input_embeds": inputs_embeds,
+                    "attention_mask": attention_mask,
+                    "cache_position": cache_position,
+                    "past_key_values": None,
+                    "position_ids": position_ids,
+                }
+                causal_mask_mapping = {
+                    "full_attention": create_causal_mask(**mask_kwargs),
+                    "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+                }
 
         hidden_states = inputs_embeds
 
@@ -387,6 +491,8 @@ def forward(
                 hidden_states,
                 attention_mask=mask,
                 position_embeddings=position_embeddings,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
             )
 
         hidden_states = self.norm(hidden_states)
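One note on the model-level hunk above: for the equal-length rows built there, the torch.arange construction of cu_seqlens matches the cumsum-based fallback inside the attention module. A quick sanity check with hypothetical shapes:

import torch

batch_size, seq_len = 4, 6  # hypothetical shapes

# arange with step=seq_len gives the boundaries [0, S, 2S, ..., B*S] directly.
cu_arange = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32)

# The cumsum-based fallback builds the same boundaries from per-row lengths.
lengths = torch.full((batch_size,), seq_len, dtype=torch.int32)
cu_cumsum = torch.zeros(batch_size + 1, dtype=torch.int32)
cu_cumsum[1:] = torch.cumsum(lengths, dim=0)

assert cu_arange.tolist() == cu_cumsum.tolist() == [0, 6, 12, 18, 24]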
