@@ -10,9 +10,11 @@
 import torch
 
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 
 logger = init_logger(__name__)
+float8_info = torch.finfo(current_platform.fp8_dtype())
 
 
 @triton.jit
@@ -48,47 +50,51 @@ def find_seq_idx(query_start_len_ptr, target_idx, num_seqs,
 
 @triton.jit
 def kernel_unified_attention_2d(
-    output_ptr,  # [num_tokens, num_query_heads, head_size]
-    query_ptr,  # [num_tokens, num_query_heads, head_size]
-    key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
-    value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
-    sink_ptr,  # [num_query_heads]
-    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
-    seq_lens_ptr,  # [num_seqs]
-    alibi_slopes_ptr,  # [num_query_heads]
-    qq_bias_ptr,  # [num_query_tokens, num_query_tokens]
-    scale,  # float32
-    k_scale,  # float32
-    v_scale,  # float32
-    softcap,  # float32
-    num_query_heads: tl.constexpr,  # int
-    num_queries_per_kv: tl.constexpr,  # int
-    block_table_stride: tl.int64,  # int
-    query_stride_0: tl.int64,  # int
-    query_stride_1: tl.int64,  # int, should be equal to head_size
-    output_stride_0: tl.int64,  # int
-    output_stride_1: tl.int64,  # int, should be equal to head_size
-    qq_bias_stride_0: tl.int64,  # int
-    BLOCK_SIZE: tl.constexpr,  # int
-    HEAD_SIZE: tl.constexpr,  # int
-    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
-    USE_ALIBI_SLOPES: tl.constexpr,  # bool
-    USE_QQ_BIAS: tl.constexpr,  # bool
-    USE_SOFTCAP: tl.constexpr,  # bool
-    USE_SINKS: tl.constexpr,  # bool
-    SLIDING_WINDOW: tl.constexpr,  # int
-    stride_k_cache_0: tl.int64,  # int
-    stride_k_cache_1: tl.int64,  # int
-    stride_k_cache_2: tl.int64,  # int
-    stride_k_cache_3: tl.constexpr,  # int
-    stride_v_cache_0: tl.int64,  # int
-    stride_v_cache_1: tl.int64,  # int
-    stride_v_cache_2: tl.int64,  # int
-    stride_v_cache_3: tl.constexpr,  # int
-    query_start_len_ptr,  # [num_seqs+1]
-    BLOCK_Q: tl.constexpr,  # int
-    num_seqs: tl.int32,
-    BLOCK_M: tl.constexpr,  # int
+    output_ptr,  # [num_tokens, num_query_heads, head_size]
+    query_ptr,  # [num_tokens, num_query_heads, head_size]
+    key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
+    value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
+    sink_ptr,  # [num_query_heads]
+    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
+    seq_lens_ptr,  # [num_seqs]
+    alibi_slopes_ptr,  # [num_query_heads]
+    qq_bias_ptr,  # [num_query_tokens, num_query_tokens]
+    scale,  # float32
+    k_scale,  # float32
+    v_scale,  # float32
+    out_scale,  # float32
+    softcap,  # float32
+    num_query_heads: tl.constexpr,  # int
+    num_queries_per_kv: tl.constexpr,  # int
+    block_table_stride: tl.int64,  # int
+    query_stride_0: tl.int64,  # int
+    query_stride_1: tl.int64,  # int, should be equal to head_size
+    output_stride_0: tl.int64,  # int
+    output_stride_1: tl.int64,  # int, should be equal to head_size
+    qq_bias_stride_0: tl.int64,  # int
+    BLOCK_SIZE: tl.constexpr,  # int
+    HEAD_SIZE: tl.constexpr,  # int
+    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
+    USE_ALIBI_SLOPES: tl.constexpr,  # bool
+    USE_QQ_BIAS: tl.constexpr,  # bool
+    USE_SOFTCAP: tl.constexpr,  # bool
+    USE_SINKS: tl.constexpr,  # bool
+    SLIDING_WINDOW: tl.constexpr,  # int
+    stride_k_cache_0: tl.int64,  # int
+    stride_k_cache_1: tl.int64,  # int
+    stride_k_cache_2: tl.int64,  # int
+    stride_k_cache_3: tl.constexpr,  # int
+    stride_v_cache_0: tl.int64,  # int
+    stride_v_cache_1: tl.int64,  # int
+    stride_v_cache_2: tl.int64,  # int
+    stride_v_cache_3: tl.constexpr,  # int
+    query_start_len_ptr,  # [num_seqs+1]
+    BLOCK_Q: tl.constexpr,  # int
+    num_seqs: tl.int32,
+    BLOCK_M: tl.constexpr,  # int
+    USE_FP8: tl.constexpr,  # bool
+    FP8_MIN: tl.constexpr = float8_info.min,
+    FP8_MAX: tl.constexpr = float8_info.max,
 ):
     q_block_global_idx = tl.program_id(0)
     kv_head_idx = tl.program_id(1)
@@ -281,6 +287,9 @@ def kernel_unified_attention_2d(
 
     # epilogue
     acc = acc / L[:, None]
+    if USE_FP8:
+        acc = acc * tl.load(out_scale)
+        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
 
     output_offset = (query_offset_0[:, None] * output_stride_0 +
                      query_offset_1[:, None] * output_stride_1 +
@@ -552,23 +561,27 @@ def kernel_unified_attention_3d(
 
 @triton.jit
 def reduce_segments(
-    output_ptr,  # [num_tokens, num_query_heads, head_size]
-    segm_output_ptr,
-    # [num_tokens, num_query_heads, max_num_segments, head_size]
-    segm_max_ptr,  # [num_tokens, num_query_heads, max_num_segments]
-    segm_expsum_ptr,  # [num_tokens, num_query_heads, max_num_segments]
-    seq_lens_ptr,  # [num_seqs]
-    num_seqs,  # int
-    num_query_heads: tl.constexpr,  # int
-    output_stride_0: tl.int64,  # int
-    output_stride_1: tl.int64,  # int, should be equal to head_size
-    block_table_stride: tl.int64,  # int
-    BLOCK_SIZE: tl.constexpr,  # int
-    HEAD_SIZE: tl.constexpr,  # int, must be power of 2
-    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
-    query_start_len_ptr,  # [num_seqs+1]
-    BLOCK_Q: tl.constexpr,  # int
-    NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
+    output_ptr,  # [num_tokens, num_query_heads, head_size]
+    segm_output_ptr,
+    # [num_tokens, num_query_heads, max_num_segments, head_size]
+    segm_max_ptr,  # [num_tokens, num_query_heads, max_num_segments]
+    segm_expsum_ptr,  # [num_tokens, num_query_heads, max_num_segments]
+    seq_lens_ptr,  # [num_seqs]
+    num_seqs,  # int
+    num_query_heads: tl.constexpr,  # int
+    out_scale_inv,  # float32
+    output_stride_0: tl.int64,  # int
+    output_stride_1: tl.int64,  # int, should be equal to head_size
+    block_table_stride: tl.int64,  # int
+    BLOCK_SIZE: tl.constexpr,  # int
+    HEAD_SIZE: tl.constexpr,  # int, must be power of 2
+    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
+    query_start_len_ptr,  # [num_seqs+1]
+    BLOCK_Q: tl.constexpr,  # int
+    NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
+    USE_FP8: tl.constexpr,  # bool
+    FP8_MIN: tl.constexpr = float8_info.min,
+    FP8_MAX: tl.constexpr = float8_info.max,
 ):
     query_token_idx = tl.program_id(0)
     query_head_idx = tl.program_id(1)
@@ -624,6 +637,10 @@ def reduce_segments(
     # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0
     acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum)
 
+    if USE_FP8:
+        acc = acc * tl.load(out_scale_inv)
+        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
+
     # write result
     output_offset = (query_token_idx * output_stride_0 +
                      query_head_idx * output_stride_1 +
@@ -649,6 +666,7 @@ def unified_attention(
     k_descale,
     v_descale,
     alibi_slopes=None,
+    output_scale=None,
     qq_bias=None,
     # Optional tensor for sinks
     sinks=None,
@@ -706,6 +724,7 @@ def unified_attention(
             scale=softmax_scale,
             k_scale=k_descale,
             v_scale=v_descale,
+            out_scale=1 / output_scale if output_scale is not None else 1.0,
             softcap=softcap,
             num_query_heads=num_query_heads,
             num_queries_per_kv=num_queries_per_kv,
@@ -735,6 +754,7 @@ def unified_attention(
             BLOCK_Q=BLOCK_Q,
             num_seqs=num_seqs,
             BLOCK_M=BLOCK_M,
+            USE_FP8=output_scale is not None,
         )
     else:
         # for initial version, NUM_SEGMENTS = 16 is chosen as a default
@@ -818,6 +838,8 @@ def unified_attention(
             seq_lens_ptr=seqused_k,
             num_seqs=num_seqs,
             num_query_heads=num_query_heads,
+            out_scale_inv=1 /
+            output_scale if output_scale is not None else 1.0,
             output_stride_0=out.stride(0),
             output_stride_1=out.stride(1),
             block_table_stride=block_table.stride(0),
@@ -827,4 +849,5 @@ def unified_attention(
             query_start_len_ptr=cu_seqlens_q,
             BLOCK_Q=BLOCK_Q,
             NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
+            USE_FP8=output_scale is not None,
         )
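
Taken together, the diff threads an optional output_scale through unified_attention: the launcher passes the reciprocal of output_scale into the kernels (out_scale / out_scale_inv), and when USE_FP8 is set each kernel's epilogue multiplies the normalized accumulator by that reciprocal and clamps it to the finite range of the platform's fp8 dtype before the store. Below is a minimal PyTorch sketch of what that epilogue computes, not the Triton kernel itself; torch.float8_e4m3fn is an assumption standing in for current_platform.fp8_dtype(), and the final cast to fp8 is only implied here, since the actual dtype of the output buffer is not shown in these hunks.

    # Minimal sketch of the FP8 output epilogue added above (illustration only).
    from typing import Optional

    import torch

    FP8_DTYPE = torch.float8_e4m3fn       # assumed stand-in for current_platform.fp8_dtype()
    float8_info = torch.finfo(FP8_DTYPE)  # provides the .min / .max clamp bounds


    def fp8_output_epilogue(acc: torch.Tensor,
                            output_scale: Optional[torch.Tensor]) -> torch.Tensor:
        """Mirror of the `if USE_FP8:` block: scale, clamp, narrow to fp8."""
        if output_scale is None:
            return acc                          # USE_FP8=False path: output stays unquantized
        out_scale_inv = 1.0 / output_scale      # the launcher passes this reciprocal into the kernel
        acc = acc * out_scale_inv               # acc = acc * tl.load(out_scale)
        acc = acc.clamp(float8_info.min, float8_info.max)  # tl.clamp(acc, FP8_MIN, FP8_MAX)
        return acc.to(FP8_DTYPE)                # assumed fp8 output buffer dtype


    # Usage example: values outside the e4m3 range are scaled, then clamped to +/-448.
    acc = torch.tensor([[0.5, -900.0, 1200.0]], dtype=torch.float32)
    print(fp8_output_epilogue(acc, torch.tensor(2.0)))

Passing the reciprocal from the host rather than the scale itself means the per-element work inside the kernel stays a multiply instead of a divide, presumably the reason both the 2d kernel and reduce_segments take out_scale / out_scale_inv in that form.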