
Commit 82acc9c

Merge leftover
1 parent d0e46dc commit 82acc9c

File tree

1 file changed: +54 -40 lines changed

vllm/attention/ops/triton_unified_attention.py

Lines changed: 54 additions & 40 deletions
@@ -51,46 +51,46 @@ def find_seq_idx(query_start_len_ptr, target_idx, num_seqs,
 
 @triton.jit
 def kernel_unified_attention_2d(
-    output_ptr,  # [num_tokens, num_query_heads, head_size]
-    query_ptr,  # [num_tokens, num_query_heads, head_size]
-    key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
-    value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
-    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
-    seq_lens_ptr,  # [num_seqs]
-    alibi_slopes_ptr,  # [num_query_heads]
-    scale,  # float32
-    k_scale,  # float32
-    v_scale,  # float32
-    out_scale,  # float32
-    softcap,  # float32
-    num_query_heads: tl.constexpr,  # int
-    num_queries_per_kv: tl.constexpr,  # int
-    block_table_stride: tl.int64,  # int
-    query_stride_0: tl.int64,  # int
-    query_stride_1: tl.int64,  # int, should be equal to head_size
-    output_stride_0: tl.int64,  # int
-    output_stride_1: tl.int64,  # int, should be equal to head_size
-    BLOCK_SIZE: tl.constexpr,  # int
-    HEAD_SIZE: tl.constexpr,  # int
-    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
-    USE_ALIBI_SLOPES: tl.constexpr,  # bool
-    USE_SOFTCAP: tl.constexpr,  # bool
-    SLIDING_WINDOW: tl.constexpr,  # int
-    stride_k_cache_0: tl.int64,  # int
-    stride_k_cache_1: tl.int64,  # int
-    stride_k_cache_2: tl.int64,  # int
-    stride_k_cache_3: tl.constexpr,  # int
-    stride_v_cache_0: tl.int64,  # int
-    stride_v_cache_1: tl.int64,  # int
-    stride_v_cache_2: tl.int64,  # int
-    stride_v_cache_3: tl.constexpr,  # int
-    query_start_len_ptr,  # [num_seqs+1]
-    BLOCK_Q: tl.constexpr,  # int
-    num_seqs: tl.int32,
-    BLOCK_M: tl.constexpr,  # int
-    USE_FP8: tl.constexpr,  # bool
-    FP8_MIN: tl.constexpr = float8_info.min,
-    FP8_MAX: tl.constexpr = float8_info.max,
+    output_ptr,  # [num_tokens, num_query_heads, head_size]
+    query_ptr,  # [num_tokens, num_query_heads, head_size]
+    key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
+    value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
+    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
+    seq_lens_ptr,  # [num_seqs]
+    alibi_slopes_ptr,  # [num_query_heads]
+    scale,  # float32
+    k_scale,  # float32
+    v_scale,  # float32
+    out_scale,  # float32
+    softcap,  # float32
+    num_query_heads: tl.constexpr,  # int
+    num_queries_per_kv: tl.constexpr,  # int
+    block_table_stride: tl.int64,  # int
+    query_stride_0: tl.int64,  # int
+    query_stride_1: tl.int64,  # int, should be equal to head_size
+    output_stride_0: tl.int64,  # int
+    output_stride_1: tl.int64,  # int, should be equal to head_size
+    BLOCK_SIZE: tl.constexpr,  # int
+    HEAD_SIZE: tl.constexpr,  # int
+    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
+    USE_ALIBI_SLOPES: tl.constexpr,  # bool
+    USE_SOFTCAP: tl.constexpr,  # bool
+    SLIDING_WINDOW: tl.constexpr,  # int
+    stride_k_cache_0: tl.int64,  # int
+    stride_k_cache_1: tl.int64,  # int
+    stride_k_cache_2: tl.int64,  # int
+    stride_k_cache_3: tl.constexpr,  # int
+    stride_v_cache_0: tl.int64,  # int
+    stride_v_cache_1: tl.int64,  # int
+    stride_v_cache_2: tl.int64,  # int
+    stride_v_cache_3: tl.constexpr,  # int
+    query_start_len_ptr,  # [num_seqs+1]
+    BLOCK_Q: tl.constexpr,  # int
+    num_seqs: tl.int32,
+    BLOCK_M: tl.constexpr,  # int
+    USE_FP8: tl.constexpr,  # bool
+    FP8_MIN: tl.constexpr = float8_info.min,
+    FP8_MAX: tl.constexpr = float8_info.max,
 ):
     q_block_global_idx = tl.program_id(0)
     kv_head_idx = tl.program_id(1)
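
The FP8_MIN/FP8_MAX constexpr defaults refer to a module-level float8_info object. As a hedged sketch of what that presumably looks like (the exact fp8 flavor is an assumption, not something this diff confirms):

# Assumed definition of the float8_info referenced by the FP8_MIN/FP8_MAX defaults.
# The e4m3 dtype is illustrative; the file may instead pick a platform-specific fp8 type.
import torch

float8_info = torch.finfo(torch.float8_e4m3fn)
print(float8_info.min, float8_info.max)  # -448.0 and 448.0 for e4m3
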
@@ -495,6 +495,7 @@ def reduce_segments(
     seq_lens_ptr,  # [num_seqs]
     num_seqs,  # int
     num_query_heads: tl.constexpr,  # int
+    out_scale,  # float32
     output_stride_0: tl.int64,  # int
     output_stride_1: tl.int64,  # int, should be equal to head_size
     block_table_stride: tl.int64,  # int
@@ -504,6 +505,9 @@ def reduce_segments(
     query_start_len_ptr,  # [num_seqs+1]
     BLOCK_Q: tl.constexpr,  # int
     NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
+    USE_FP8: tl.constexpr,  # bool
+    FP8_MIN: tl.constexpr = float8_info.min,
+    FP8_MAX: tl.constexpr = float8_info.max,
 ):
     query_token_idx = tl.program_id(0)
     query_head_idx = tl.program_id(1)
@@ -559,6 +563,10 @@ def reduce_segments(
     # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0
     acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum)
 
+    if USE_FP8:
+        acc = acc / tl.load(out_scale)
+        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
+
     # write result
     output_offset = (query_token_idx * output_stride_0 +
                      query_head_idx * output_stride_1 +
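
When USE_FP8 is set, the reduction epilogue divides the fp32 accumulator by the output scale and clamps it into the representable FP8 range before the result is written out. A minimal PyTorch-level sketch of the same transformation, assuming out_scale is a one-element tensor and the final fp8 cast happens when the value is stored into an fp8 output buffer:

# Sketch only: reproduces the kernel's FP8 epilogue outside Triton.
# `acc` stands in for the fp32 attention output and `out_scale` for the
# one-element tensor the kernel reads via tl.load(out_scale).
import torch

def fp8_epilogue(acc: torch.Tensor, out_scale: torch.Tensor) -> torch.Tensor:
    info = torch.finfo(torch.float8_e4m3fn)   # assumed fp8 flavor
    acc = acc / out_scale                     # apply the output scale
    acc = acc.clamp(info.min, info.max)       # avoid overflow on the cast
    return acc.to(torch.float8_e4m3fn)        # in-kernel, the store presumably does this cast

out = fp8_epilogue(torch.randn(4, 128), torch.tensor(0.5))
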
@@ -632,6 +640,7 @@ def unified_attention(
             scale=softmax_scale,
             k_scale=k_descale,
             v_scale=v_descale,
+            out_scale=output_scale,
             softcap=softcap,
             num_query_heads=num_query_heads,
             num_queries_per_kv=num_queries_per_kv,
@@ -658,6 +667,7 @@ def unified_attention(
             BLOCK_Q=BLOCK_Q,
             num_seqs=num_seqs,
             BLOCK_M=BLOCK_M,
+            USE_FP8=output_scale is not None,
         )
     else:
         # for initial version, NUM_SEGMENTS = 16 is chosen as a default
@@ -701,6 +711,7 @@ def unified_attention(
             scale=softmax_scale,
             k_scale=k_descale,
             v_scale=v_descale,
+            out_scale=output_scale,
             softcap=softcap,
             num_query_heads=num_query_heads,
             num_queries_per_kv=num_queries_per_kv,
@@ -726,6 +737,7 @@ def unified_attention(
             num_seqs=num_seqs,
             BLOCK_M=BLOCK_M,
             NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
+            USE_FP8=output_scale is not None,
         )
 
         reduce_segments[(q.shape[0], num_query_heads)](
@@ -736,6 +748,7 @@ def unified_attention(
             seq_lens_ptr=seqused_k,
             num_seqs=num_seqs,
             num_query_heads=num_query_heads,
+            out_scale=output_scale,
             output_stride_0=out.stride(0),
             output_stride_1=out.stride(1),
             block_table_stride=block_table.stride(0),
@@ -745,4 +758,5 @@ def unified_attention(
             query_start_len_ptr=cu_seqlens_q,
             BLOCK_Q=BLOCK_Q,
             NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
+            USE_FP8=output_scale is not None,
         )
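
At every launch site, USE_FP8 is derived as output_scale is not None, so the FP8 epilogue is compiled into the kernels only when the caller actually supplies an output scale. A small, self-contained Triton sketch of that constexpr-toggle pattern (not vLLM code; the kernel name, shapes, and launch parameters are illustrative):

# Illustrative standalone kernel: passing USE_FP8 as a tl.constexpr lets Triton
# eliminate the FP8 branch at compile time when no output scale is given.
import torch
import triton
import triton.language as tl

@triton.jit
def scale_and_clamp_kernel(x_ptr, out_ptr, scale_ptr, n_elements,
                           USE_FP8: tl.constexpr, BLOCK: tl.constexpr,
                           FP8_MIN: tl.constexpr, FP8_MAX: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    if USE_FP8:  # constexpr: dead-code-eliminated when False
        x = x / tl.load(scale_ptr)
        x = tl.clamp(x, FP8_MIN, FP8_MAX)
    tl.store(out_ptr + offs, x, mask=mask)

x = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
output_scale = torch.tensor([0.5], device="cuda")  # set to None to skip the FP8 path
info = torch.finfo(torch.float8_e4m3fn)
scale_and_clamp_kernel[(triton.cdiv(x.numel(), 256),)](
    x, out,
    output_scale if output_scale is not None else x,  # dummy pointer when unused
    x.numel(),
    USE_FP8=output_scale is not None,
    BLOCK=256,
    FP8_MIN=info.min,
    FP8_MAX=info.max,
)
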
