Commit 55cd9df

Merge remote-tracking branch 'origin/attention_fusion_v1_no_tests' into 0902_rc1
2 parents 16ccbde + 3f309e8 commit 55cd9df

7 files changed: +179 -103 lines changed

vllm/attention/ops/chunked_prefill_paged_decode.py

Lines changed: 16 additions & 2 deletions
@@ -15,6 +15,8 @@

 from .prefix_prefill import context_attention_fwd

+float8_info = torch.finfo(current_platform.fp8_dtype())
+


 @triton.jit
 def cdiv_fn(x, y):
@@ -34,6 +36,7 @@ def kernel_paged_attention_2d(
     scale,  # float32
     k_scale,  # float32
     v_scale,  # float32
+    out_scale_inv,
     num_query_heads: tl.constexpr,  # int
     num_queries_per_kv: tl.constexpr,  # int
     num_queries_per_kv_padded: tl.constexpr,  # int
@@ -60,7 +63,9 @@ def kernel_paged_attention_2d(
     filter_by_query_len: tl.constexpr,  # bool
     query_start_len_ptr,  # [num_seqs+1]
     USE_SINKS: tl.constexpr,  # bool
-):
+    USE_FP8: tl.constexpr,
+    FP8_MIN: tl.constexpr = float8_info.min,
+    FP8_MAX: tl.constexpr = float8_info.max):
     seq_idx = tl.program_id(0)
     kv_head_idx = tl.program_id(1)

@@ -204,6 +209,9 @@ def kernel_paged_attention_2d(

     # epilogue
     acc = acc / L[:, None]
+    if USE_FP8:
+        acc = acc * tl.load(out_scale_inv)
+        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)

     output_offset = (cur_batch_in_all_start_index * output_stride_0 +
                      query_head_idx * output_stride_1)
@@ -234,6 +242,7 @@ def chunked_prefill_paged_decode(
     alibi_slopes=None,
     sliding_window=None,
     sm_scale=None,
+    output_scale=None,
     # Optional tensor for sinks
     sinks=None,
 ):
@@ -266,6 +275,7 @@ def chunked_prefill_paged_decode(
             sliding_window=sliding_window,
             sm_scale=sm_scale,
             skip_decode=True,
+            fp8_out_scale=output_scale,
             sinks=sinks,
         )

@@ -316,7 +326,7 @@ def chunked_prefill_paged_decode(
         tmp_output = torch.empty(
             size=(total_num_seq, num_query_heads, max_num_partitions,
                   head_size),
-            dtype=output.dtype,
+            dtype=query.dtype,
             device=output.device,
         )
         exp_sums = torch.empty(
@@ -345,6 +355,7 @@ def chunked_prefill_paged_decode(
             kv_cache_dtype=kv_cache_dtype,
             k_scale=k_scale,
             v_scale=v_scale,
+            fp8_out_scale=output_scale,
         )
     else:
         kernel_paged_attention_2d[(
@@ -362,6 +373,8 @@ def chunked_prefill_paged_decode(
             scale=sm_scale,
             k_scale=k_scale,
             v_scale=v_scale,
+            out_scale_inv=1.0 /
+            output_scale if output_scale is not None else 1.0,
             num_query_heads=num_query_heads,
             num_queries_per_kv=num_queries_per_kv,
             num_queries_per_kv_padded=num_queries_per_kv_padded,
@@ -388,4 +401,5 @@ def chunked_prefill_paged_decode(
             filter_by_query_len=True,
             query_start_len_ptr=query_start_loc,
             USE_SINKS=sinks is not None,
+            USE_FP8=output_scale is not None,
         )
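
The epilogue added to `kernel_paged_attention_2d` is plain tensor arithmetic: when an `output_scale` is supplied, the attention accumulator is multiplied by the precomputed reciprocal of that scale and clamped to the representable FP8 range before it is stored, and `USE_FP8` is simply `output_scale is not None` at launch time. Below is a minimal PyTorch sketch of that epilogue, with `torch.float8_e4m3fn` standing in for whatever `current_platform.fp8_dtype()` resolves to on the target platform.

```python
import torch

FP8_DTYPE = torch.float8_e4m3fn          # stand-in for current_platform.fp8_dtype()
FP8_INFO = torch.finfo(FP8_DTYPE)

def fp8_output_epilogue(acc, output_scale=None):
    """Host-side emulation of the USE_FP8 branch of the Triton epilogue."""
    if output_scale is None:
        return acc                        # USE_FP8=False: store the accumulator as-is
    out_scale_inv = 1.0 / output_scale    # computed once at launch, passed to the kernel
    acc = acc * out_scale_inv             # acc = acc * tl.load(out_scale_inv)
    acc = acc.clamp(FP8_INFO.min, FP8_INFO.max)  # acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
    return acc.to(FP8_DTYPE)              # the store then lands in an FP8 output buffer

acc = torch.randn(4, 128)                 # pretend softmax(QK^T)V accumulator
print(fp8_output_epilogue(acc, torch.tensor(0.05)).dtype)  # torch.float8_e4m3fn
```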

vllm/attention/ops/prefix_prefill.py

Lines changed: 12 additions & 1 deletion
@@ -15,6 +15,7 @@

 # To check compatibility
 IS_TURING = current_platform.get_device_capability() == (7, 5)
+float8_info = torch.finfo(current_platform.fp8_dtype())


 # Here's an example autotuner config for this kernel. This config does provide
@@ -43,6 +44,7 @@ def _fwd_kernel(Q,
                 sm_scale,
                 k_scale,
                 v_scale,
+                out_scale_inv,
                 B_Start_Loc,
                 B_Seqlen,
                 x: tl.constexpr,
@@ -82,8 +84,11 @@ def _fwd_kernel(Q,
                 num_unroll_request: tl.constexpr,
                 SKIP_DECODE: tl.constexpr,
                 USE_SINKS: tl.constexpr,
+                USE_FP8: tl.constexpr,
                 MAX_Q_LEN: tl.constexpr = 0,
-                MAX_CTX_LEN: tl.constexpr = 0):
+                MAX_CTX_LEN: tl.constexpr = 0,
+                FP8_MIN: tl.constexpr = float8_info.min,
+                FP8_MAX: tl.constexpr = float8_info.max):

     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -284,6 +289,9 @@ def _fwd_kernel(Q,
     off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
              cur_head * stride_oh + offs_d[None, :] * stride_od)
     out_ptrs = Out + off_o
+    if USE_FP8:
+        acc = acc * tl.load(out_scale_inv)
+        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
     tl.store(out_ptrs,
              acc,
              mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len))
@@ -743,6 +751,7 @@ def context_attention_fwd(q,
                           sliding_window=None,
                           sm_scale=None,
                           skip_decode=False,
+                          fp8_out_scale=None,
                           sinks=None):

     q_dtype_is_f32 = q.dtype is torch.float32
@@ -870,6 +879,7 @@ def context_attention_fwd(q,
         sm_scale,
         k_scale,
         v_scale,
+        1.0 / fp8_out_scale if fp8_out_scale is not None else 1.0,
         b_start_loc,
         b_seq_len,
         k_cache.shape[4],
@@ -905,6 +915,7 @@ def context_attention_fwd(q,
         BLOCK_DMODEL_PADDED=Lk_padded,
         SLIDING_WINDOW=sliding_window,
         SKIP_DECODE=skip_decode,
+        USE_FP8=fp8_out_scale is not None,
         BLOCK_M=128,
         BLOCK_N=64,
         num_unroll_cache=4,
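
`_fwd_kernel` guards the same scale-and-clamp sequence behind `USE_FP8: tl.constexpr`, so the epilogue is only compiled into the kernel when `fp8_out_scale` is given; the plain float `1.0` passed for `out_scale_inv` in the `None` case is never dereferenced, because a branch on a false constexpr is dropped when the kernel is specialized. The toy kernel below shows that pattern in isolation; it assumes a CUDA device and a recent Triton, and writes to a float16 buffer where vLLM's fused path would write to an FP8 tensor.

```python
import torch
import triton
import triton.language as tl

float8_info = torch.finfo(torch.float8_e4m3fn)  # stand-in for current_platform.fp8_dtype()

@triton.jit
def scale_clamp_kernel(x_ptr, out_ptr, out_scale_inv, n,
                       USE_FP8: tl.constexpr,
                       BLOCK: tl.constexpr,
                       FP8_MIN: tl.constexpr = float8_info.min,
                       FP8_MAX: tl.constexpr = float8_info.max):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    acc = tl.load(x_ptr + offs, mask=mask)
    if USE_FP8:
        # Only traced and compiled when the kernel is specialized with USE_FP8=True.
        acc = acc * tl.load(out_scale_inv)
        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
    tl.store(out_ptr + offs, acc, mask=mask)

x = torch.randn(1024, device="cuda")
out = torch.empty_like(x, dtype=torch.float16)       # fp16 here; vLLM stores into FP8
output_scale = torch.full((1, ), 0.05, device="cuda")
grid = (triton.cdiv(x.numel(), 256), )
scale_clamp_kernel[grid](x, out, 1.0 / output_scale, x.numel(),
                         USE_FP8=True, BLOCK=256)
```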

vllm/attention/ops/triton_unified_attention.py

Lines changed: 81 additions & 58 deletions
@@ -10,9 +10,11 @@
 import torch

 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton

 logger = init_logger(__name__)
+float8_info = torch.finfo(current_platform.fp8_dtype())


 @triton.jit
@@ -48,47 +50,51 @@ def find_seq_idx(query_start_len_ptr, target_idx, num_seqs,

 @triton.jit
 def kernel_unified_attention_2d(
-    output_ptr,  # [num_tokens, num_query_heads, head_size]
-    query_ptr,  # [num_tokens, num_query_heads, head_size]
-    key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
-    value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
-    sink_ptr,  # [num_query_heads]
-    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
-    seq_lens_ptr,  # [num_seqs]
-    alibi_slopes_ptr,  # [num_query_heads]
-    qq_bias_ptr,  # [num_query_tokens, num_query_tokens]
-    scale,  # float32
-    k_scale,  # float32
-    v_scale,  # float32
-    softcap,  # float32
-    num_query_heads: tl.constexpr,  # int
-    num_queries_per_kv: tl.constexpr,  # int
-    block_table_stride: tl.int64,  # int
-    query_stride_0: tl.int64,  # int
-    query_stride_1: tl.int64,  # int, should be equal to head_size
-    output_stride_0: tl.int64,  # int
-    output_stride_1: tl.int64,  # int, should be equal to head_size
-    qq_bias_stride_0: tl.int64,  # int
-    BLOCK_SIZE: tl.constexpr,  # int
-    HEAD_SIZE: tl.constexpr,  # int
-    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
-    USE_ALIBI_SLOPES: tl.constexpr,  # bool
-    USE_QQ_BIAS: tl.constexpr,  # bool
-    USE_SOFTCAP: tl.constexpr,  # bool
-    USE_SINKS: tl.constexpr,  # bool
-    SLIDING_WINDOW: tl.constexpr,  # int
-    stride_k_cache_0: tl.int64,  # int
-    stride_k_cache_1: tl.int64,  # int
-    stride_k_cache_2: tl.int64,  # int
-    stride_k_cache_3: tl.constexpr,  # int
-    stride_v_cache_0: tl.int64,  # int
-    stride_v_cache_1: tl.int64,  # int
-    stride_v_cache_2: tl.int64,  # int
-    stride_v_cache_3: tl.constexpr,  # int
-    query_start_len_ptr,  # [num_seqs+1]
-    BLOCK_Q: tl.constexpr,  # int
-    num_seqs: tl.int32,
-    BLOCK_M: tl.constexpr,  # int
+    output_ptr,  # [num_tokens, num_query_heads, head_size]
+    query_ptr,  # [num_tokens, num_query_heads, head_size]
+    key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
+    value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
+    sink_ptr,  # [num_query_heads]
+    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
+    seq_lens_ptr,  # [num_seqs]
+    alibi_slopes_ptr,  # [num_query_heads]
+    qq_bias_ptr,  # [num_query_tokens, num_query_tokens]
+    scale,  # float32
+    k_scale,  # float32
+    v_scale,  # float32
+    out_scale,  # float32
+    softcap,  # float32
+    num_query_heads: tl.constexpr,  # int
+    num_queries_per_kv: tl.constexpr,  # int
+    block_table_stride: tl.int64,  # int
+    query_stride_0: tl.int64,  # int
+    query_stride_1: tl.int64,  # int, should be equal to head_size
+    output_stride_0: tl.int64,  # int
+    output_stride_1: tl.int64,  # int, should be equal to head_size
+    qq_bias_stride_0: tl.int64,  # int
+    BLOCK_SIZE: tl.constexpr,  # int
+    HEAD_SIZE: tl.constexpr,  # int
+    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
+    USE_ALIBI_SLOPES: tl.constexpr,  # bool
+    USE_QQ_BIAS: tl.constexpr,  # bool
+    USE_SOFTCAP: tl.constexpr,  # bool
+    USE_SINKS: tl.constexpr,  # bool
+    SLIDING_WINDOW: tl.constexpr,  # int
+    stride_k_cache_0: tl.int64,  # int
+    stride_k_cache_1: tl.int64,  # int
+    stride_k_cache_2: tl.int64,  # int
+    stride_k_cache_3: tl.constexpr,  # int
+    stride_v_cache_0: tl.int64,  # int
+    stride_v_cache_1: tl.int64,  # int
+    stride_v_cache_2: tl.int64,  # int
+    stride_v_cache_3: tl.constexpr,  # int
+    query_start_len_ptr,  # [num_seqs+1]
+    BLOCK_Q: tl.constexpr,  # int
+    num_seqs: tl.int32,
+    BLOCK_M: tl.constexpr,  # int
+    USE_FP8: tl.constexpr,  # bool
+    FP8_MIN: tl.constexpr = float8_info.min,
+    FP8_MAX: tl.constexpr = float8_info.max,
 ):
     q_block_global_idx = tl.program_id(0)
     kv_head_idx = tl.program_id(1)
@@ -281,6 +287,9 @@ def kernel_unified_attention_2d(

     # epilogue
     acc = acc / L[:, None]
+    if USE_FP8:
+        acc = acc * tl.load(out_scale)
+        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)

     output_offset = (query_offset_0[:, None] * output_stride_0 +
                      query_offset_1[:, None] * output_stride_1 +
@@ -552,23 +561,27 @@ def kernel_unified_attention_3d(

 @triton.jit
 def reduce_segments(
-    output_ptr,  # [num_tokens, num_query_heads, head_size]
-    segm_output_ptr,
-    #[num_tokens, num_query_heads, max_num_segments, head_size]
-    segm_max_ptr,  # [num_tokens, num_query_heads, max_num_segments]
-    segm_expsum_ptr,  # [num_tokens, num_query_heads, max_num_segments]
-    seq_lens_ptr,  # [num_seqs]
-    num_seqs,  # int
-    num_query_heads: tl.constexpr,  # int
-    output_stride_0: tl.int64,  # int
-    output_stride_1: tl.int64,  # int, should be equal to head_size
-    block_table_stride: tl.int64,  # int
-    BLOCK_SIZE: tl.constexpr,  # int
-    HEAD_SIZE: tl.constexpr,  # int, must be power of 2
-    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
-    query_start_len_ptr,  # [num_seqs+1]
-    BLOCK_Q: tl.constexpr,  # int
-    NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
+    output_ptr,  # [num_tokens, num_query_heads, head_size]
+    segm_output_ptr,
+    #[num_tokens, num_query_heads, max_num_segments, head_size]
+    segm_max_ptr,  # [num_tokens, num_query_heads, max_num_segments]
+    segm_expsum_ptr,  # [num_tokens, num_query_heads, max_num_segments]
+    seq_lens_ptr,  # [num_seqs]
+    num_seqs,  # int
+    num_query_heads: tl.constexpr,  # int
+    out_scale_inv,  # float32
+    output_stride_0: tl.int64,  # int
+    output_stride_1: tl.int64,  # int, should be equal to head_size
+    block_table_stride: tl.int64,  # int
+    BLOCK_SIZE: tl.constexpr,  # int
+    HEAD_SIZE: tl.constexpr,  # int, must be power of 2
+    HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
+    query_start_len_ptr,  # [num_seqs+1]
+    BLOCK_Q: tl.constexpr,  # int
+    NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
+    USE_FP8: tl.constexpr,  # bool
+    FP8_MIN: tl.constexpr = float8_info.min,
+    FP8_MAX: tl.constexpr = float8_info.max,
 ):
     query_token_idx = tl.program_id(0)
     query_head_idx = tl.program_id(1)
@@ -624,6 +637,10 @@ def reduce_segments(
     # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0
     acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum)

+    if USE_FP8:
+        acc = acc * tl.load(out_scale_inv)
+        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
+
     # write result
     output_offset = (query_token_idx * output_stride_0 +
                      query_head_idx * output_stride_1 +
@@ -649,6 +666,7 @@ def unified_attention(
     k_descale,
     v_descale,
     alibi_slopes=None,
+    output_scale=None,
     qq_bias=None,
     # Optional tensor for sinks
     sinks=None,
@@ -706,6 +724,7 @@ def unified_attention(
             scale=softmax_scale,
             k_scale=k_descale,
             v_scale=v_descale,
+            out_scale=1 / output_scale if output_scale is not None else 1.0,
             softcap=softcap,
             num_query_heads=num_query_heads,
             num_queries_per_kv=num_queries_per_kv,
@@ -735,6 +754,7 @@ def unified_attention(
             BLOCK_Q=BLOCK_Q,
             num_seqs=num_seqs,
             BLOCK_M=BLOCK_M,
+            USE_FP8=output_scale is not None,
         )
     else:
         # for initial version, NUM_SEGMENTS = 16 is chosen as a default
@@ -818,6 +838,8 @@ def unified_attention(
             seq_lens_ptr=seqused_k,
             num_seqs=num_seqs,
             num_query_heads=num_query_heads,
+            out_scale_inv=1 /
+            output_scale if output_scale is not None else 1.0,
             output_stride_0=out.stride(0),
             output_stride_1=out.stride(1),
             block_table_stride=block_table.stride(0),
@@ -827,4 +849,5 @@ def unified_attention(
             query_start_len_ptr=cu_seqlens_q,
             BLOCK_Q=BLOCK_Q,
             NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
+            USE_FP8=output_scale is not None,
         )
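
In the unified attention path the reciprocal shows up twice: the single-pass `kernel_unified_attention_2d` scales and clamps in its own epilogue, while the split-K path leaves `kernel_unified_attention_3d` untouched and applies the scale once in `reduce_segments`, after the per-segment partial results have been merged. In both cases the host passes `1 / output_scale` so the kernel multiplies instead of divides; the two forms agree up to the rounding of the reciprocal, as this small check (again using `torch.float8_e4m3fn` as a stand-in FP8 dtype) illustrates.

```python
import torch

x = torch.randn(8) * 50                       # pretend attention outputs
output_scale = torch.tensor(0.05)             # hypothetical FP8 output scale
f = torch.finfo(torch.float8_e4m3fn)          # stand-in for current_platform.fp8_dtype()

quant_div = torch.clamp(x / output_scale, f.min, f.max)          # textbook form
quant_mul = torch.clamp(x * (1.0 / output_scale), f.min, f.max)  # kernel form
assert torch.allclose(quant_div, quant_mul)
print(quant_mul.to(torch.float8_e4m3fn))
```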

vllm/compilation/backends.py

Lines changed: 2 additions & 1 deletion
@@ -454,11 +454,12 @@ def configure_post_pass(self):
         inductor_config = config.inductor_compile_config
         PASS_KEY = "post_grad_custom_post_pass"
         if PASS_KEY in inductor_config:
-            # Config should automatically wrap all inductor passes
             if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
+                # PassManager already added to config, make sure it's correct
                 assert (inductor_config[PASS_KEY].uuid() ==
                         self.post_grad_pass_manager.uuid())
             else:
+                # Config should automatically wrap all inductor passes
                 assert isinstance(inductor_config[PASS_KEY], InductorPass)
                 self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
         inductor_config[PASS_KEY] = self.post_grad_pass_manager
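
`configure_post_pass` wires custom inductor passes into vLLM's pass manager: a `PostGradPassManager` already present in the inductor config must be this compiler's own manager (checked by uuid), a bare `InductorPass` is folded into the manager, and the manager is then installed under the `post_grad_custom_post_pass` key. A rough standalone sketch of that flow, using hypothetical stand-in classes rather than vLLM's real ones:

```python
class InductorPass:
    """Hypothetical stand-in for a single custom inductor pass."""
    def uuid(self):
        return id(self)

class PostGradPassManager:
    """Hypothetical stand-in for vLLM's post-grad pass manager."""
    def __init__(self):
        self.passes = []
    def uuid(self):
        return id(self)
    def add(self, inductor_pass):
        self.passes.append(inductor_pass)

PASS_KEY = "post_grad_custom_post_pass"

def configure_post_pass(post_grad_pass_manager, inductor_config):
    if PASS_KEY in inductor_config:
        if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
            # PassManager already added to config, make sure it's correct
            assert inductor_config[PASS_KEY].uuid() == post_grad_pass_manager.uuid()
        else:
            # Config should automatically wrap all inductor passes
            assert isinstance(inductor_config[PASS_KEY], InductorPass)
            post_grad_pass_manager.add(inductor_config[PASS_KEY])
    inductor_config[PASS_KEY] = post_grad_pass_manager

manager = PostGradPassManager()
config = {PASS_KEY: InductorPass()}           # a bare custom pass sitting in the config
configure_post_pass(manager, config)
assert config[PASS_KEY] is manager and len(manager.passes) == 1
```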
