 import torch
 
 from lmdeploy.pytorch.distributed import get_tp_world_rank
+from lmdeploy.utils import get_logger
 
 from ..attention import AttentionBuilder, AttentionImpl, AttentionMetadata
 
+logger = get_logger('lmdeploy')
+
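+# Probe FlashAttention-3 availability once at import time; the Triton kernels
+# remain the fallback when it is unavailable.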
+use_fa3 = False
+try:
+    # Currently flash-attention only provides FA3 kernels for sm90a with CUDA >= 12.3.
+    # Compare parsed version numbers rather than strings so e.g. '12.10' is handled correctly.
+    cuda_version = tuple(int(v) for v in torch.version.cuda.split('.')[:2])
+    if torch.cuda.get_device_capability()[0] == 9 and cuda_version >= (12, 3):
+        import flash_attn_interface  # noqa: F401
+        assert torch.ops.flash_attn_3 is not None
+        use_fa3 = True
+except Exception:
+    logger.warning('For higher performance, please install FlashAttention-3 '
+                   'https://github.com/Dao-AILab/flash-attention')
+
 
 @dataclass
 class TritonAttentionMetadata(AttentionMetadata):
@@ -25,6 +39,8 @@ class TritonAttentionMetadata(AttentionMetadata):
     # flash mla
     tile_scheduler_metadata: torch.Tensor = None
     num_splits: torch.Tensor = None
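+    # flash attention 3: cumulative q/k sequence lengths for the varlen kernels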
+    cu_seqlens_q: torch.Tensor = None
+    cu_seqlens_k: torch.Tensor = None
 
 
 def _cdiv(a, b):
@@ -89,7 +105,6 @@ def forward(
         inplace: bool = True,
     ) -> torch.Tensor:
         """forward."""
-
         block_offsets = attn_metadata.block_offsets
         q_start_loc = attn_metadata.q_start_loc
         fill_q_start_loc = q_start_loc
@@ -129,7 +144,6 @@ def forward(
         q_shape = query.shape
         o_shape = q_shape[:-1] + (self.v_head_size, )
         attn_output = query.new_empty(o_shape)
-
         is_decoding = attn_metadata.is_decoding
         if not self.alibi:
             if is_decoding:
@@ -286,7 +300,6 @@ def forward(
 
         q_shape = query.shape
         o_shape = q_shape[:-1] + (self.v_head_size, )
-        attn_output = query.new_empty(o_shape)
 
         is_decoding = attn_metadata.is_decoding
         if is_decoding:
@@ -302,7 +315,6 @@ def forward(
                 tile_scheduler_metadata=attn_metadata.tile_scheduler_metadata,
                 num_splits=attn_metadata.num_splits,
                 causal=True)
-
         else:
             BLOCK_BS = k_cache.size(1)
             # pad one more block to avoid invalid kv visit
@@ -313,26 +325,179 @@ def forward(
                 kv_seqlens,
                 block_offsets,
                 start_loc=kv_start_loc,
-                out_size=out_size,
+                out_size=kv_flatten_size if use_fa3 else out_size,
                 out_dtype=query.dtype,
                 k_scales_zeros=k_scales_zeros,
                 v_scales_zeros=v_scales_zeros,
                 quant_policy=quant_policy,
+                flatten_kv_layout='shd' if use_fa3 else 'hsd',
             )
-            self.flash_attention_fwd(
+            if use_fa3:
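+                # Split q and the flattened kv into rope and nope (latent) parts;
+                # the FA3 varlen kernel takes the rope halves as q/k and the latent halves as qv/v.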
+                q_rope = query[:, :, self.v_head_size:]
+                q_nope = query[:, :, :self.v_head_size]
+                k_rope = flatten_k.view(kv_flatten_size, self.num_kv_heads, -1)[:, :, self.v_head_size:]
+                c_kv = flatten_k.view(kv_flatten_size, self.num_kv_heads, -1)[:, :, :self.v_head_size]
+                from flash_attn_interface import flash_attn_varlen_func
+                attn_output, _ = flash_attn_varlen_func(
+                    q=q_rope,
+                    k=k_rope,
+                    v=c_kv,
+                    qv=q_nope,
+                    cu_seqlens_q=attn_metadata.cu_seqlens_q,
+                    cu_seqlens_k=attn_metadata.cu_seqlens_k,
+                    max_seqlen_q=max_q_seqlen,
+                    max_seqlen_k=kv_flatten_size,
+                    softmax_scale=self.scale,
+                    causal=self.causal,
+                    window_size=(-1, -1) if self.sliding_window is None else self.sliding_window,
+                    softcap=-1.0 if self.logit_softcapping is None else self.logit_softcapping,
+                )
+            else:
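+                # FA3 unavailable: fall back to the Triton flash-attention prefill kernel.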
+                attn_output = query.new_empty(o_shape)
+                self.flash_attention_fwd(
+                    query,
+                    flatten_k,
+                    flatten_v,
+                    attn_output,
+                    q_start_loc=q_start_loc,
+                    q_seqlens=q_seqlens,
+                    kv_start_loc=kv_start_loc,
+                    kv_seqlens=kv_seqlens,
+                    max_seqlen=max_q_seqlen,
+                    window_size=self.sliding_window,
+                    sm_scale=self.scale,
+                    logit_softcapping=self.logit_softcapping,
+                    causal=self.causal,
+                )
+        return attn_output
+
+
+class FA3Impl(TritonAttentionImpl):
+    """FlashAttention-3 (FA3) attention implementation."""
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float = None,
+        num_kv_heads: int = None,
+        v_head_size: int = None,
+        alibi: bool = False,
+        sliding_window: int = None,
+        logit_softcapping: float = None,
+        causal: bool = True,
+        **kwargs,
+    ):
+        assert alibi is False, 'alibi not supported for FA3'
+        super().__init__(
+            num_heads=num_heads,
+            head_size=head_size,
+            scale=scale,
+            num_kv_heads=num_kv_heads,
+            v_head_size=v_head_size,
+            alibi=alibi,
+            sliding_window=sliding_window,
+            logit_softcapping=logit_softcapping,
+            causal=causal,
+            **kwargs,
+        )
+        from flash_attn_interface import flash_attn_varlen_func
+        self.flash_attn_varlen_func_v3 = flash_attn_varlen_func
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        attn_metadata: TritonAttentionMetadata,
+        k_scales_zeros: torch.Tensor = None,
+        v_scales_zeros: torch.Tensor = None,
+        inplace: bool = True,
+    ) -> torch.Tensor:
+        """forward."""
+        block_offsets = attn_metadata.block_offsets
+        q_start_loc = attn_metadata.q_start_loc
+        fill_q_start_loc = q_start_loc
+        q_seqlens = attn_metadata.q_seqlens
+        fill_seqlens = q_seqlens
+        kv_start_loc = attn_metadata.kv_start_loc
+        kv_seqlens = attn_metadata.kv_seqlens
+        kv_flatten_size = attn_metadata.kv_flatten_size
+        quant_policy = attn_metadata.quant_policy
+        if attn_metadata.is_decoding:
+            max_q_seqlen = 1
+        else:
+            max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))
+        fill_max_q_seqlen = max_q_seqlen
+        if attn_metadata.fill_seqlens is not None:
+            fill_seqlens = attn_metadata.fill_seqlens
+            fill_max_q_seqlen = key.numel() // (key.size(-1) * key.size(-2))
+            fill_q_start_loc = fill_seqlens.cumsum(0) - fill_seqlens
+        is_decoding = attn_metadata.is_decoding
+        # fill kv cache
+        if key is not None and value is not None:
+            self.fill_kv_cache(
+                key,
+                value,
+                k_cache,
+                v_cache,
+                fill_q_start_loc,
+                fill_seqlens,
+                kv_seq_length=kv_seqlens,
+                max_q_seq_length=fill_max_q_seqlen,
+                block_offsets=block_offsets,
+                k_scales_zeros=k_scales_zeros,
+                v_scales_zeros=v_scales_zeros,
+                quant_policy=quant_policy,
+            )
+
+        q_shape = query.shape
+        o_shape = q_shape[:-1] + (self.v_head_size, )
+        attn_output = query.new_empty(o_shape)
+
+        if is_decoding:
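+            # Decoding: reuse the Triton paged-attention kernel directly on the paged KV cache.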
+            self.paged_attention_fwd(
                 query,
-                flatten_k,
-                flatten_v,
+                k_cache,
+                v_cache,
                 attn_output,
-                q_start_loc=q_start_loc,
-                q_seqlens=q_seqlens,
-                kv_start_loc=kv_start_loc,
+                block_offsets,
                 kv_seqlens=kv_seqlens,
-                max_seqlen=max_q_seqlen,
+                k_scales_zeros=k_scales_zeros,
+                v_scales_zeros=v_scales_zeros,
+                quant_policy=quant_policy,
                 window_size=self.sliding_window,
                 sm_scale=self.scale,
                 logit_softcapping=self.logit_softcapping,
+            )
+        else:
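+            # Prefill: flatten the paged KV cache into a contiguous (seq, head, dim)
+            # layout and run the FA3 varlen kernel over it.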
+            flatten_k, flatten_v = self.flatten_kv_cache(
+                k_cache,
+                v_cache,
+                kv_seqlens,
+                block_offsets,
+                start_loc=kv_start_loc,
+                out_size=kv_flatten_size,
+                out_dtype=query.dtype,
+                k_scales_zeros=k_scales_zeros,
+                v_scales_zeros=v_scales_zeros,
+                quant_policy=quant_policy,
+                flatten_kv_layout='shd',
+            )
+            attn_output, _ = self.flash_attn_varlen_func_v3(
+                q=query,
+                k=flatten_k,
+                v=flatten_v,
+                cu_seqlens_q=attn_metadata.cu_seqlens_q,
+                cu_seqlens_k=attn_metadata.cu_seqlens_k,
+                max_seqlen_q=max_q_seqlen,
+                max_seqlen_k=kv_flatten_size,
+                softmax_scale=self.scale,
                 causal=self.causal,
+                window_size=(-1, -1) if self.sliding_window is None else self.sliding_window,
+                softcap=-1.0 if self.logit_softcapping is None else self.logit_softcapping,
             )
         return attn_output
 
@@ -366,13 +531,25 @@ def build(
                                 logical_softcapping=logical_softcapping,
                                 causal=causal,
                                 **kwargs)
-        return TritonAttentionImpl(num_heads,
-                                   head_size,
-                                   scale=scale,
-                                   num_kv_heads=num_kv_heads,
-                                   v_head_size=v_head_size,
-                                   alibi=alibi,
-                                   sliding_window=sliding_window,
-                                   logical_softcapping=logical_softcapping,
-                                   causal=causal,
-                                   **kwargs)
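+        # Prefer FA3 when it was detected at import time; FA3 does not support alibi.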
+        elif use_fa3 and not alibi:
+            return FA3Impl(num_heads,
+                           head_size,
+                           scale=scale,
+                           num_kv_heads=num_kv_heads,
+                           v_head_size=v_head_size,
+                           alibi=alibi,
+                           sliding_window=sliding_window,
+                           logical_softcapping=logical_softcapping,
+                           causal=causal,
+                           **kwargs)
+        else:
+            return TritonAttentionImpl(num_heads,
+                                       head_size,
+                                       scale=scale,
+                                       num_kv_heads=num_kv_heads,
+                                       v_head_size=v_head_size,
+                                       alibi=alibi,
+                                       sliding_window=sliding_window,
+                                       logical_softcapping=logical_softcapping,
+                                       causal=causal,
+                                       **kwargs)