@@ -407,8 +407,9 @@ def __init__(
             causal=causal,
             **kwargs,
         )
-        from lmdeploy.pytorch.third_party.flash_attn_interface import flash_attn_varlen_func
+        from lmdeploy.pytorch.third_party.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
         self.flash_attn_varlen_func_v3 = flash_attn_varlen_func
+        self.flash_attn_with_kvcache_v3 = flash_attn_with_kvcache

     def forward(
         self,
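The hunk above pulls in `flash_attn_with_kvcache` alongside `flash_attn_varlen_func` and caches both as attributes, so `forward` can dispatch to either kernel. A minimal sketch of the pattern, assuming the import is deferred to `__init__` to keep flash-attn v3 an optional dependency (the class shell and error handling here are illustrative, not lmdeploy's):

```python
class FA3ImplSketch:
    """Illustrative shell; only the import/binding pattern mirrors the diff."""

    def __init__(self):
        try:
            from lmdeploy.pytorch.third_party.flash_attn_interface import (
                flash_attn_varlen_func, flash_attn_with_kvcache)
        except ImportError as exc:  # assumed failure mode, for illustration
            raise RuntimeError('FA3Impl requires flash-attn v3') from exc
        # Cache both entry points: varlen for the flattened path, kvcache for
        # prefill directly against the paged block tables.
        self.flash_attn_varlen_func_v3 = flash_attn_varlen_func
        self.flash_attn_with_kvcache_v3 = flash_attn_with_kvcache
```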
@@ -460,11 +461,10 @@ def forward(
             quant_policy=quant_policy,
         )

-        q_shape = query.shape
-        o_shape = q_shape[:-1] + (self.v_head_size, )
-        attn_output = query.new_empty(o_shape)
-
         if is_decoding:
+            q_shape = query.shape
+            o_shape = q_shape[:-1] + (self.v_head_size, )
+            attn_output = query.new_empty(o_shape)
             self.paged_attention_fwd(
                 query,
                 k_cache,
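Moving the output allocation under `is_decoding` fits the new control flow: the FA3 prefill branch added below returns its own tensor, so only the paged-attention decode kernel needs a preallocated buffer. A self-contained sketch of the shape computation, with an assumed `(num_tokens, num_heads, head_dim)` layout and illustrative sizes:

```python
import torch

# Assumed layout: (num_tokens, num_heads, qk_head_size). Values may use a
# different head size, so the output swaps in v_head_size as the last dim.
query = torch.empty(32, 8, 128)
v_head_size = 64

o_shape = query.shape[:-1] + (v_head_size, )
attn_output = query.new_empty(o_shape)  # same dtype/device as query
assert attn_output.shape == (32, 8, 64)
```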
@@ -480,6 +480,24 @@ def forward(
                 logit_softcapping=self.logit_softcapping,
             )
         else:
+            sliding_window = (-1, -1) if self.sliding_window is None else self.sliding_window
+            if isinstance(sliding_window, int):
+                sliding_window = (sliding_window, sliding_window)
+            attn_output = self.flash_attn_with_kvcache_v3(
+                query,
+                k_cache,
+                v_cache,
+                cache_seqlens=attn_metadata.kv_seqlens.to(torch.int32),
+                cu_seqlens_q=attn_metadata.cu_seqlens_q,
+                cu_seqlens_k_new=attn_metadata.cu_seqlens_k,
+                max_seqlen_q=max_q_seqlen,
+                page_table=block_offsets,
+                softmax_scale=self.scale,
+                causal=self.causal,
+                window_size=sliding_window,
+                softcap=-1.0 if self.logit_softcapping is None else self.logit_softcapping,
+            )
+            return attn_output
             flatten_k, flatten_v = self.flatten_kv_cache(
                 k_cache,
                 v_cache,
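The new prefill branch normalizes two kwargs before handing off to the FA3 kernel: `window_size` must be a `(left, right)` pair with `-1` meaning unbounded, and `softcap` gets a negative sentinel when logit softcapping is disabled, mirroring the diff. A standalone sketch of that normalization (`normalize_fa3_kwargs` is a hypothetical helper, not part of lmdeploy):

```python
from typing import Optional, Tuple, Union


def normalize_fa3_kwargs(
    sliding_window: Union[int, Tuple[int, int], None],
    logit_softcapping: Optional[float],
) -> Tuple[Tuple[int, int], float]:
    # (-1, -1) means "no sliding window limit" on either side.
    if sliding_window is None:
        window = (-1, -1)
    elif isinstance(sliding_window, int):
        # A single int is treated as a symmetric window.
        window = (sliding_window, sliding_window)
    else:
        window = sliding_window
    # A negative value stands in for "softcapping disabled".
    softcap = -1.0 if logit_softcapping is None else logit_softcapping
    return window, softcap


assert normalize_fa3_kwargs(None, None) == ((-1, -1), -1.0)
assert normalize_fa3_kwargs(4096, 50.0) == ((4096, 4096), 50.0)
```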
@@ -527,6 +545,7 @@ def build(
         logical_softcapping: float = None,
         causal: bool = True,
         use_flash_mla: bool = False,
+        use_flash_attn3: bool = False,
         learnable_sink: bool = False,
         **kwargs,
     ) -> TritonAttentionImpl:
@@ -542,7 +561,7 @@ def build(
                 logical_softcapping=logical_softcapping,
                 causal=causal,
                 **kwargs)
-        elif use_fa3 and not alibi and not learnable_sink:
+        elif use_flash_attn3 and not alibi and not learnable_sink:
             return FA3Impl(num_heads,
                            head_size,
                            scale=scale,
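Taken together, the two `build` hunks add the `use_flash_attn3` parameter and switch the gate from the old `use_fa3` name to it, keeping FA3 behind the ALiBi/learnable-sink exclusions. A hedged sketch of the resulting dispatch order (class names replaced by strings; `use_flash_mla` taking priority is inferred from the surrounding `elif`):

```python
def select_backend(use_flash_mla: bool = False,
                   use_flash_attn3: bool = False,
                   alibi: bool = False,
                   learnable_sink: bool = False) -> str:
    if use_flash_mla:
        return 'FlashMLAImpl'
    elif use_flash_attn3 and not alibi and not learnable_sink:
        # FA3 is only eligible when neither ALiBi nor learnable sinks are set.
        return 'FA3Impl'
    return 'TritonAttentionImpl'


assert select_backend(use_flash_attn3=True) == 'FA3Impl'
assert select_backend(use_flash_attn3=True, alibi=True) == 'TritonAttentionImpl'
```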