@@ -51,46 +51,46 @@ def find_seq_idx(query_start_len_ptr, target_idx, num_seqs,

@triton.jit
def kernel_unified_attention_2d(
-    output_ptr,  # [num_tokens, num_query_heads, head_size]
-    query_ptr,  # [num_tokens, num_query_heads, head_size]
-    key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
-    value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
-    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
-    seq_lens_ptr,  # [num_seqs]
-    alibi_slopes_ptr,  # [num_query_heads]
-    scale,  # float32
-    k_scale,  # float32
-    v_scale,  # float32
-    out_scale,  # float32
-    softcap,  # float32
-    num_query_heads: tl.constexpr,  # int
-    num_queries_per_kv: tl.constexpr,  # int
-    block_table_stride: tl.int64,  # int
-    query_stride_0: tl.int64,  # int
-    query_stride_1: tl.int64,  # int, should be equal to head_size
-    output_stride_0: tl.int64,  # int
-    output_stride_1: tl.int64,  # int, should be equal to head_size
-    BLOCK_SIZE: tl.constexpr,  # int
-    HEAD_SIZE: tl.constexpr,  # int
-    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
-    USE_ALIBI_SLOPES: tl.constexpr,  # bool
-    USE_SOFTCAP: tl.constexpr,  # bool
-    SLIDING_WINDOW: tl.constexpr,  # int
-    stride_k_cache_0: tl.int64,  # int
-    stride_k_cache_1: tl.int64,  # int
-    stride_k_cache_2: tl.int64,  # int
-    stride_k_cache_3: tl.constexpr,  # int
-    stride_v_cache_0: tl.int64,  # int
-    stride_v_cache_1: tl.int64,  # int
-    stride_v_cache_2: tl.int64,  # int
-    stride_v_cache_3: tl.constexpr,  # int
-    query_start_len_ptr,  # [num_seqs+1]
-    BLOCK_Q: tl.constexpr,  # int
-    num_seqs: tl.int32,
-    BLOCK_M: tl.constexpr,  # int
-    USE_FP8: tl.constexpr,  # bool
-    FP8_MIN: tl.constexpr = float8_info.min,
-    FP8_MAX: tl.constexpr = float8_info.max,
+    output_ptr,  # [num_tokens, num_query_heads, head_size]
+    query_ptr,  # [num_tokens, num_query_heads, head_size]
+    key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
+    value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
+    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
+    seq_lens_ptr,  # [num_seqs]
+    alibi_slopes_ptr,  # [num_query_heads]
+    scale,  # float32
+    k_scale,  # float32
+    v_scale,  # float32
+    out_scale,  # float32
+    softcap,  # float32
+    num_query_heads: tl.constexpr,  # int
+    num_queries_per_kv: tl.constexpr,  # int
+    block_table_stride: tl.int64,  # int
+    query_stride_0: tl.int64,  # int
+    query_stride_1: tl.int64,  # int, should be equal to head_size
+    output_stride_0: tl.int64,  # int
+    output_stride_1: tl.int64,  # int, should be equal to head_size
+    BLOCK_SIZE: tl.constexpr,  # int
+    HEAD_SIZE: tl.constexpr,  # int
+    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
+    USE_ALIBI_SLOPES: tl.constexpr,  # bool
+    USE_SOFTCAP: tl.constexpr,  # bool
+    SLIDING_WINDOW: tl.constexpr,  # int
+    stride_k_cache_0: tl.int64,  # int
+    stride_k_cache_1: tl.int64,  # int
+    stride_k_cache_2: tl.int64,  # int
+    stride_k_cache_3: tl.constexpr,  # int
+    stride_v_cache_0: tl.int64,  # int
+    stride_v_cache_1: tl.int64,  # int
+    stride_v_cache_2: tl.int64,  # int
+    stride_v_cache_3: tl.constexpr,  # int
+    query_start_len_ptr,  # [num_seqs+1]
+    BLOCK_Q: tl.constexpr,  # int
+    num_seqs: tl.int32,
+    BLOCK_M: tl.constexpr,  # int
+    USE_FP8: tl.constexpr,  # bool
+    FP8_MIN: tl.constexpr = float8_info.min,
+    FP8_MAX: tl.constexpr = float8_info.max,
):
    q_block_global_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)
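
A note on the FP8_MIN/FP8_MAX constexpr defaults above: the `float8_info` object
they reference is defined outside this hunk. A minimal sketch of a plausible
module-level definition, assuming the FP8 storage dtype is float8_e4m3fn (an
assumption, not confirmed by this diff):

    import torch

    # Hypothetical: query the FP8 representable range once at module scope so
    # it can be baked into the kernels as constexpr defaults.
    float8_info = torch.finfo(torch.float8_e4m3fn)
    print(float8_info.min, float8_info.max)  # -448.0 448.0 for float8_e4m3fn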
@@ -495,6 +495,7 @@ def reduce_segments(
    seq_lens_ptr,  # [num_seqs]
    num_seqs,  # int
    num_query_heads: tl.constexpr,  # int
+    out_scale,  # float32
    output_stride_0: tl.int64,  # int
    output_stride_1: tl.int64,  # int, should be equal to head_size
    block_table_stride: tl.int64,  # int
@@ -504,6 +505,9 @@ def reduce_segments(
    query_start_len_ptr,  # [num_seqs+1]
    BLOCK_Q: tl.constexpr,  # int
    NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
+    USE_FP8: tl.constexpr,  # bool
+    FP8_MIN: tl.constexpr = float8_info.min,
+    FP8_MAX: tl.constexpr = float8_info.max,
):
    query_token_idx = tl.program_id(0)
    query_head_idx = tl.program_id(1)
@@ -559,6 +563,10 @@ def reduce_segments(
    # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0
    acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum)

+    if USE_FP8:
+        acc = acc / tl.load(out_scale)
+        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
+
    # write result
    output_offset = (query_token_idx * output_stride_0 +
                     query_head_idx * output_stride_1 +
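
The new USE_FP8 branch rescales the reduced accumulator into FP8 range before it
is stored: dividing by `out_scale` maps the fp32 values into FP8 units, and the
clamp guards against overflow when the scale underestimates the true maximum. A
standalone PyTorch equivalent of the same math, as a sketch (the FP8 dtype and
function name are assumptions):

    import torch

    def quantize_attn_output(acc: torch.Tensor, out_scale: torch.Tensor) -> torch.Tensor:
        # Mirror of the kernel branch: divide by the output scale, clamp to
        # the FP8 representable range, then narrow to the FP8 storage dtype.
        info = torch.finfo(torch.float8_e4m3fn)  # assumed dtype
        return (acc / out_scale).clamp(info.min, info.max).to(torch.float8_e4m3fn)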
@@ -632,6 +640,7 @@ def unified_attention(
            scale=softmax_scale,
            k_scale=k_descale,
            v_scale=v_descale,
+            out_scale=output_scale,
            softcap=softcap,
            num_query_heads=num_query_heads,
            num_queries_per_kv=num_queries_per_kv,
@@ -658,6 +667,7 @@ def unified_attention(
            BLOCK_Q=BLOCK_Q,
            num_seqs=num_seqs,
            BLOCK_M=BLOCK_M,
+            USE_FP8=output_scale is not None,
        )
    else:
        # for initial version, NUM_SEGMENTS = 16 is chosen as a default
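
Host-side, USE_FP8 is derived as `output_scale is not None` and passed as a
tl.constexpr, so Triton compiles a separate specialization per value and the
quantization branch is eliminated entirely from the non-FP8 variant; passing
`out_scale=None` in that case is harmless because the pointer is never
dereferenced. A toy kernel illustrating the same constexpr-specialization
pattern (names and shapes are illustrative only, not from this diff):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def scale_kernel(x_ptr, scale_ptr, n, USE_SCALE: tl.constexpr, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        x = tl.load(x_ptr + offs, mask=mask)
        if USE_SCALE:  # constexpr: the branch is compiled out when False
            x = x / tl.load(scale_ptr)
        tl.store(x_ptr + offs, x, mask=mask)

    x = torch.randn(1024, device="cuda")
    scale = torch.tensor([2.0], device="cuda")
    grid = (triton.cdiv(1024, 256),)
    scale_kernel[grid](x, scale, 1024, USE_SCALE=True, BLOCK=256)
    scale_kernel[grid](x, None, 1024, USE_SCALE=False, BLOCK=256)  # scale_ptr unused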
@@ -701,6 +711,7 @@ def unified_attention(
            scale=softmax_scale,
            k_scale=k_descale,
            v_scale=v_descale,
+            out_scale=output_scale,
            softcap=softcap,
            num_query_heads=num_query_heads,
            num_queries_per_kv=num_queries_per_kv,
@@ -726,6 +737,7 @@ def unified_attention(
            num_seqs=num_seqs,
            BLOCK_M=BLOCK_M,
            NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
+            USE_FP8=output_scale is not None,
        )

        reduce_segments[(q.shape[0], num_query_heads)](
@@ -736,6 +748,7 @@ def unified_attention(
            seq_lens_ptr=seqused_k,
            num_seqs=num_seqs,
            num_query_heads=num_query_heads,
+            out_scale=output_scale,
            output_stride_0=out.stride(0),
            output_stride_1=out.stride(1),
            block_table_stride=block_table.stride(0),
@@ -745,4 +758,5 @@ def unified_attention(
            query_start_len_ptr=cu_seqlens_q,
            BLOCK_Q=BLOCK_Q,
            NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
+            USE_FP8=output_scale is not None,
        )
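
In the segmented path the quantization happens in reduce_segments, after the
per-segment partial results are combined, so each output element is scaled and
clamped exactly once in both code paths. Note also that the kernels read the
scale with `tl.load(out_scale)`, which implies `output_scale` arrives as a
one-element device tensor rather than a Python float (unlike `softmax_scale`,
which is passed by value). A hedged caller-side sketch (values illustrative
only):

    import torch

    # Hypothetical: a scalar output scale stored as a 1-element GPU tensor so
    # the kernel can tl.load() it; passing None instead disables FP8 output.
    output_scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")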