138 changes: 122 additions & 16 deletions src/prime_rl/trainer/models/afmoe/modeling_afmoe.py
@@ -1,3 +1,4 @@
import functools
from dataclasses import dataclass
from typing import Optional, Union

@@ -36,6 +37,16 @@
convert_tt_to_hf_moe,
)

try:
from flash_attn import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None

try:
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
except ImportError:
flash_attn_3_varlen_func = None


@dataclass
class AfmoeAttentionConfig:
@@ -135,6 +146,8 @@ def forward(
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: torch.Tensor | None = None,
cu_seqlens: torch.LongTensor | None = None,
max_seqlen: int | None = None,
) -> tuple[torch.Tensor, None]:
query_states, key_states, value_states, gate_states, input_shape = self._project_states(
hidden_states, position_embeddings
@@ -154,8 +167,75 @@
return self._finalize_output(attn_output, gate_states, input_shape)


class AfmoeFlashAttention(AfmoeAttentionBase):
"""AFMoE attention using Flash Attention."""

def __init__(self, config: AfmoeAttentionConfig, flash_attn_version: int = 2):
super().__init__(config)
self.flash_attn_version = flash_attn_version
self.func = flash_attn_3_varlen_func if flash_attn_version == 3 else flash_attn_varlen_func

def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: torch.Tensor | None = None,
cu_seqlens: torch.LongTensor | None = None,
max_seqlen: int | None = None,
) -> tuple[torch.Tensor, None]:
if self.func is None:
raise ImportError("flash-attn is not installed but flash attention was requested.")

query_states, key_states, value_states, gate_states, input_shape = self._project_states(
hidden_states, position_embeddings
)

query_states = query_states.transpose(1, 2).contiguous()
key_states = key_states.transpose(1, 2).contiguous()
value_states = value_states.transpose(1, 2).contiguous()

q = query_states.reshape(-1, self.num_heads, self.head_dim)
k = key_states.reshape(-1, self.num_heads, self.head_dim)
v = value_states.reshape(-1, self.num_heads, self.head_dim)

if cu_seqlens is None:
batch_size, seq_len = input_shape
lengths = torch.full((batch_size,), seq_len, dtype=torch.int32, device=hidden_states.device)
cu_seqlens = torch.zeros((batch_size + 1,), dtype=torch.int32, device=hidden_states.device)
cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
max_seqlen = seq_len
elif max_seqlen is None:
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())

attn_kwargs = {"causal": True}
if self.is_local_attention and self.sliding_window is not None:
attn_kwargs["window_size"] = (self.sliding_window, 0)

out = self.func(
q,
k,
v,
cu_seqlens,
cu_seqlens,
max_seqlen,
max_seqlen,
**attn_kwargs,
)

Flash Attention missing dropout_p parameter

Medium Severity

The AfmoeFlashAttention class does not pass dropout_p to the flash attention function, while AfmoeSDPAAttention properly handles attention_dropout with dropout_p = self.attention_dropout if self.training else 0.0. If the model config has non-zero attention_dropout, SDPA will apply dropout during training but Flash Attention will not, causing inconsistent training behavior between the two attention implementations.
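A minimal sketch of one possible fix, shown only to illustrate the comment above and not part of this PR: forward dropout_p on the FA2 path, assuming the shared base class exposes the same attention_dropout attribute the SDPA branch reads (whether the FA3 interface accepts dropout is not established by this diff, so it is left untouched here).

attn_kwargs = {"causal": True}
if self.is_local_attention and self.sliding_window is not None:
    attn_kwargs["window_size"] = (self.sliding_window, 0)
if self.flash_attn_version == 2:
    # hypothetical: mirror the SDPA branch; FA2's flash_attn_varlen_func accepts dropout_p
    attn_kwargs["dropout_p"] = self.attention_dropout if self.training else 0.0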


if isinstance(out, tuple):
out = out[0]

attn_output = out.view(*input_shape, self.num_heads, self.head_dim)
attn_output = attn_output.view(*input_shape, -1)
attn_output = attn_output * torch.sigmoid(gate_states)
attn_output = self.o_proj(attn_output)
return attn_output, None


AFMOE_ATTN_IMPL2CLASS = {
"sdpa": AfmoeSDPAAttention,
"flash_attention_2": functools.partial(AfmoeFlashAttention, flash_attn_version=2),
"flash_attention_3": functools.partial(AfmoeFlashAttention, flash_attn_version=3),
}


@@ -194,9 +274,7 @@ def _get_afmoe_attention(config: AfmoeConfig, layer_idx: int) -> nn.Module:
if attn_impl not in AFMOE_ATTN_IMPL2CLASS:
supported = list(AFMOE_ATTN_IMPL2CLASS.keys())
raise ValueError(
f"AFMoE attention does not support '{config._attn_implementation}'. "
f"Supported implementations: {supported}. "
f"Note: flash_attention is not supported for AFMoE due to sliding window + RoPE constraints."
f"AFMoE attention does not support '{config._attn_implementation}'. Supported implementations: {supported}."
)

return AFMOE_ATTN_IMPL2CLASS[attn_impl](attn_config)
@@ -245,6 +323,8 @@ def forward(
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
max_seqlen: Optional[int] = None,
) -> torch.FloatTensor:
residual = hidden_states

@@ -253,6 +333,8 @@
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states
@@ -280,7 +362,7 @@ class AfmoePreTrainedModel(PreTrainedModelPrimeRL):
"norm",
]
_supports_sdpa = True
_supports_flash_attn = False
_supports_flash_attn = True
_supports_flex_attn = False
_supports_attention_backend = True
_can_compile_fullgraph = False
@@ -359,19 +441,41 @@ def forward(
position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)

if self.config._attn_implementation in ("flash_attention_2", "flash_attention_3"):
flat_position_ids = position_ids.view(-1)
seqlens = torch.cat(
[
flat_position_ids[0:1],
flat_position_ids[:-1][(flat_position_ids == 0)[1:]] + 1,
flat_position_ids[-1:] + 1,
]
)
max_seqlen = seqlens.max().item()
cu_seqlens = seqlens.cumsum(dim=0, dtype=torch.int32)
torch._dynamo.mark_dynamic(cu_seqlens, 0)

cu_seqlens computation ignores actual batch size

High Severity

When position_ids is None and batch_size > 1, the cu_seqlens computation produces incorrect results. The default position_ids has shape [1, seq_len], so position_ids.view(-1) yields only seq_len elements. The resulting cu_seqlens (e.g., [0, seq_len]) describes a single sequence, but hidden_states contains batch_size * seq_len tokens. Flash attention then processes only the first seq_len tokens and silently ignores the rest of the batch. The fallback in AfmoeFlashAttention.forward() would handle this correctly, but it is bypassed because cu_seqlens is already provided.
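As a sketch of one possible direction (not part of this PR), the boundaries could be derived from the full [batch_size, seq_len] grid so every batch row contributes its own sequence starts; the names below mirror the surrounding forward() and are assumptions, not the final patch.

batch_size, seq_len = inputs_embeds.shape[:2]
# broadcast a default [1, seq_len] position_ids to the real batch; a no-op for packed [batch_size, seq_len] inputs
pos = position_ids.expand(batch_size, seq_len).reshape(-1)
starts = (pos == 0).nonzero(as_tuple=False).squeeze(-1).to(torch.int32)
total = torch.tensor([batch_size * seq_len], dtype=torch.int32, device=pos.device)
cu_seqlens = torch.cat([starts, total])
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())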


else:
max_seqlen = None
cu_seqlens = None

if not isinstance(causal_mask_mapping := attention_mask, dict):
mask_kwargs = {
"config": self.config,
"input_embeds": inputs_embeds,
"attention_mask": attention_mask,
"cache_position": cache_position,
"past_key_values": None,
"position_ids": position_ids,
}
causal_mask_mapping = {
"full_attention": create_causal_mask(**mask_kwargs),
"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
}
if self.config._attn_implementation in ("flash_attention_2", "flash_attention_3"):
causal_mask_mapping = {
"full_attention": None,
"sliding_attention": None,
}
else:
mask_kwargs = {
"config": self.config,
"input_embeds": inputs_embeds,
"attention_mask": attention_mask,
"cache_position": cache_position,
"past_key_values": None,
"position_ids": position_ids,
}
causal_mask_mapping = {
"full_attention": create_causal_mask(**mask_kwargs),
"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
}

hidden_states = inputs_embeds

@@ -387,6 +491,8 @@
hidden_states,
attention_mask=mask,
position_embeddings=position_embeddings,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)

hidden_states = self.norm(hidden_states)