+import functools
from dataclasses import dataclass
from typing import Optional, Union

    convert_tt_to_hf_moe,
)

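+# flash-attn 2 and flash-attn 3 are optional dependencies: import whichever is
+# available and fall back to None so the flash path can fail with a clear error.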
+try:
+    from flash_attn import flash_attn_varlen_func
+except ImportError:
+    flash_attn_varlen_func = None
+
+try:
+    from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+except ImportError:
+    flash_attn_3_varlen_func = None
+

@dataclass
class AfmoeAttentionConfig:
@@ -135,6 +146,8 @@ def forward(
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None = None,
+        cu_seqlens: torch.LongTensor | None = None,
+        max_seqlen: int | None = None,
    ) -> tuple[torch.Tensor, None]:
        query_states, key_states, value_states, gate_states, input_shape = self._project_states(
            hidden_states, position_embeddings
@@ -154,8 +167,75 @@ def forward(
        return self._finalize_output(attn_output, gate_states, input_shape)


+class AfmoeFlashAttention(AfmoeAttentionBase):
+    """AFMoE attention using Flash Attention."""
+
+    def __init__(self, config: AfmoeAttentionConfig, flash_attn_version: int = 2):
+        super().__init__(config)
+        self.flash_attn_version = flash_attn_version
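+        # Pick the varlen kernel once at construction: the flash-attn 3 interface when
+        # requested, otherwise flash-attn 2 (either may be None if not installed).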
+        self.func = flash_attn_3_varlen_func if flash_attn_version == 3 else flash_attn_varlen_func
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: torch.Tensor | None = None,
+        cu_seqlens: torch.LongTensor | None = None,
+        max_seqlen: int | None = None,
+    ) -> tuple[torch.Tensor, None]:
+        if self.func is None:
+            raise ImportError("flash-attn is not installed but flash attention was requested.")
+
+        query_states, key_states, value_states, gate_states, input_shape = self._project_states(
+            hidden_states, position_embeddings
+        )
+
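+        # _project_states returns (batch, num_heads, seq_len, head_dim); move the head
+        # dimension after the sequence dimension so tokens can be packed contiguously.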
+        query_states = query_states.transpose(1, 2).contiguous()
+        key_states = key_states.transpose(1, 2).contiguous()
+        value_states = value_states.transpose(1, 2).contiguous()
+
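+        # Flatten batch and sequence into one token dimension: the varlen kernels expect
+        # packed (total_tokens, num_heads, head_dim) tensors plus cumulative offsets.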
+        q = query_states.reshape(-1, self.num_heads, self.head_dim)
+        k = key_states.reshape(-1, self.num_heads, self.head_dim)
+        v = value_states.reshape(-1, self.num_heads, self.head_dim)
+
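+        # No packed-sequence metadata was provided: treat every row as a full-length
+        # sequence and build cumulative offsets [0, seq_len, 2*seq_len, ...] on the fly.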
+        if cu_seqlens is None:
+            batch_size, seq_len = input_shape
+            lengths = torch.full((batch_size,), seq_len, dtype=torch.int32, device=hidden_states.device)
+            cu_seqlens = torch.zeros((batch_size + 1,), dtype=torch.int32, device=hidden_states.device)
+            cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
+            max_seqlen = seq_len
+        elif max_seqlen is None:
+            max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())
+
+        dropout_p = self.attention_dropout if self.training else 0.0
+        attn_kwargs = {"causal": True, "dropout_p": dropout_p}
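+        # flash-attn's window_size is (left, right); (sliding_window, 0) limits the causal
+        # look-back to the local window on sliding-window layers.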
+        if self.is_local_attention and self.sliding_window is not None:
+            attn_kwargs["window_size"] = (self.sliding_window, 0)
+
+        out = self.func(
+            q,
+            k,
+            v,
+            cu_seqlens,
+            cu_seqlens,
+            max_seqlen,
+            max_seqlen,
+            **attn_kwargs,
+        )
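+        # Some flash-attn versions (notably the FA3 interface) also return the softmax LSE.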
+        if isinstance(out, tuple):
+            out = out[0]
+
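+        # Restore (batch, seq, hidden), then apply the sigmoid attention gate and output
+        # projection, matching the SDPA path's _finalize_output.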
+        attn_output = out.view(*input_shape, self.num_heads, self.head_dim)
+        attn_output = attn_output.view(*input_shape, -1)
+        attn_output = attn_output * torch.sigmoid(gate_states)
+        attn_output = self.o_proj(attn_output)
+        return attn_output, None
+
+
AFMOE_ATTN_IMPL2CLASS = {
    "sdpa": AfmoeSDPAAttention,
237+ "flash_attention_2" : functools .partial (AfmoeFlashAttention , flash_attn_version = 2 ),
238+ "flash_attention_3" : functools .partial (AfmoeFlashAttention , flash_attn_version = 3 ),
159239}
160240
161241
@@ -195,8 +275,7 @@ def _get_afmoe_attention(config: AfmoeConfig, layer_idx: int) -> nn.Module:
        supported = list(AFMOE_ATTN_IMPL2CLASS.keys())
        raise ValueError(
            f"AFMoE attention does not support '{config._attn_implementation}'. "
-            f"Supported implementations: {supported}. "
-            f"Note: flash_attention is not supported for AFMoE due to sliding window + RoPE constraints."
+            f"Supported implementations: {supported}."
        )

    return AFMOE_ATTN_IMPL2CLASS[attn_impl](attn_config)
@@ -245,6 +324,8 @@ def forward(
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        cu_seqlens: Optional[torch.LongTensor] = None,
+        max_seqlen: Optional[int] = None,
    ) -> torch.FloatTensor:
        residual = hidden_states

@@ -253,6 +334,8 @@ def forward(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states
@@ -280,7 +363,7 @@ class AfmoePreTrainedModel(PreTrainedModelPrimeRL):
280363 "norm" ,
281364 ]
282365 _supports_sdpa = True
283- _supports_flash_attn = False
366+ _supports_flash_attn = True
284367 _supports_flex_attn = False
285368 _supports_attention_backend = True
286369 _can_compile_fullgraph = False
@@ -359,19 +442,40 @@ def forward(
        position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
        cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)

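+        # The flash path consumes packed sequences: build uniform cu_seqlens up front,
+        # assuming every row in the batch uses the full sequence length.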
+        if self.config._attn_implementation in ("flash_attention_2", "flash_attention_3"):
+            batch_size, seq_len = inputs_embeds.shape[:2]
+            cu_seqlens = torch.arange(
+                0,
+                (batch_size + 1) * seq_len,
+                step=seq_len,
+                dtype=torch.int32,
+                device=inputs_embeds.device,
+            )
+            max_seqlen = seq_len
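+            # Mark the batch-dependent dimension as dynamic so torch.compile does not
+            # specialize on (and recompile for) every new batch size.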
+            torch._dynamo.mark_dynamic(cu_seqlens, 0)
+        else:
+            max_seqlen = None
+            cu_seqlens = None
+
        if not isinstance(causal_mask_mapping := attention_mask, dict):
-            mask_kwargs = {
-                "config": self.config,
-                "input_embeds": inputs_embeds,
-                "attention_mask": attention_mask,
-                "cache_position": cache_position,
-                "past_key_values": None,
-                "position_ids": position_ids,
-            }
-            causal_mask_mapping = {
-                "full_attention": create_causal_mask(**mask_kwargs),
-                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
-            }
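+            # With flash attention, causality and the sliding window are enforced by the
+            # kernel arguments, so no dense masks are built.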
+            if self.config._attn_implementation in ("flash_attention_2", "flash_attention_3"):
+                causal_mask_mapping = {
+                    "full_attention": None,
+                    "sliding_attention": None,
+                }
+            else:
+                mask_kwargs = {
+                    "config": self.config,
+                    "input_embeds": inputs_embeds,
+                    "attention_mask": attention_mask,
+                    "cache_position": cache_position,
+                    "past_key_values": None,
+                    "position_ids": position_ids,
+                }
+                causal_mask_mapping = {
+                    "full_attention": create_causal_mask(**mask_kwargs),
+                    "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+                }

        hidden_states = inputs_embeds

@@ -387,6 +491,8 @@ def forward(
                hidden_states,
                attention_mask=mask,
                position_embeddings=position_embeddings,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
            )

        hidden_states = self.norm(hidden_states)