
Commit 7bf92f9

Merge remote-tracking branch 'upstream/main' into upstream_merge_2025_05_29
2 parents: bee14ca + c290340

15 files changed: +70 additions, −179 deletions


csrc/attention/paged_attention_v1.cu

Lines changed: 9 additions & 28 deletions

@@ -48,7 +48,7 @@
 // TODO(woosuk): Tune NUM_THREADS.
 template <typename T, typename CACHE_T, int BLOCK_SIZE,
           vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
-          int NUM_THREADS>
+          int NUM_THREADS = 128>
 void paged_attention_v1_launcher(
     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
@@ -133,38 +133,19 @@ void paged_attention_v1_launcher(
   }
 }
 
-#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE, \
-                         NUM_THREADS) \
+#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
   paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
-                              IS_BLOCK_SPARSE, NUM_THREADS>( \
+                              IS_BLOCK_SPARSE>( \
       out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
       seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
       blocksparse_local_blocks, blocksparse_vert_stride, \
       blocksparse_block_size, blocksparse_head_sliding_step);
 
-#define CALL_V1_LAUNCHER_W_NUM_THREADS(T, CACHE_T, BLOCK_SIZE, \
-                                       IS_FP8_KV_CACHE, IS_BLOCK_SPARSE) \
-  switch (num_threads) { \
-    case 128: \
-      CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \
-                       IS_BLOCK_SPARSE, 128); \
-      break; \
-    case 1024: \
-      CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \
-                       IS_BLOCK_SPARSE, 1024); \
-      break; \
-    default: \
-      TORCH_CHECK(false, "Unsupported num threads: ", num_threads); \
-      break; \
-  }
-
-#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
-  if (is_block_sparse) { \
-    CALL_V1_LAUNCHER_W_NUM_THREADS(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \
-                                   true); \
-  } else { \
-    CALL_V1_LAUNCHER_W_NUM_THREADS(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \
-                                   false); \
+#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
+  if (is_block_sparse) { \
+    CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
+  } else { \
+    CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
   }
 
 // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
@@ -202,7 +183,7 @@ void paged_attention_v1(
     torch::Tensor& v_scale, const int64_t tp_rank,
     const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
-    const int64_t blocksparse_head_sliding_step, const int64_t num_threads) {
+    const int64_t blocksparse_head_sliding_step) {
   const bool is_block_sparse = (blocksparse_vert_stride > 1);
 
   DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,

csrc/attention/paged_attention_v2.cu

Lines changed: 9 additions & 28 deletions

@@ -48,7 +48,7 @@
 
 template <typename T, typename CACHE_T, int BLOCK_SIZE,
           vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
-          int NUM_THREADS, int PARTITION_SIZE = 512>
+          int NUM_THREADS = 128, int PARTITION_SIZE = 512>
 void paged_attention_v2_launcher(
     torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
@@ -139,39 +139,20 @@ void paged_attention_v2_launcher(
   }
 }
 
-#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE, \
-                         NUM_THREADS, PARTITION_SIZE) \
+#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
   paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
-                              IS_BLOCK_SPARSE, NUM_THREADS, PARTITION_SIZE>( \
+                              IS_BLOCK_SPARSE>( \
       out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
       num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
      k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
      blocksparse_vert_stride, blocksparse_block_size, \
      blocksparse_head_sliding_step);
 
-#define CALL_V2_LAUNCHER_W_NUM_THREADS(T, CACHE_T, BLOCK_SIZE, \
-                                       IS_FP8_KV_CACHE, IS_BLOCK_SPARSE) \
-  switch (num_threads) { \
-    case 128: \
-      CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \
-                       IS_BLOCK_SPARSE, 128, 512); \
-      break; \
-    case 1024: \
-      CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \
-                       IS_BLOCK_SPARSE, 1024, 1024); \
-      break; \
-    default: \
-      TORCH_CHECK(false, "Unsupported num threads: ", num_threads); \
-      break; \
-  }
-
-#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
-  if (is_block_sparse) { \
-    CALL_V2_LAUNCHER_W_NUM_THREADS(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \
-                                   true); \
-  } else { \
-    CALL_V2_LAUNCHER_W_NUM_THREADS(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \
-                                   false); \
+#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
+  if (is_block_sparse) { \
+    CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
+  } else { \
+    CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
   }
 
 // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
@@ -213,7 +194,7 @@ void paged_attention_v2(
     torch::Tensor& v_scale, const int64_t tp_rank,
     const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
-    const int64_t blocksparse_head_sliding_step, const int64_t num_threads) {
+    const int64_t blocksparse_head_sliding_step) {
   const bool is_block_sparse = (blocksparse_vert_stride > 1);
   DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
                              CALL_V2_LAUNCHER_BLOCK_SIZE)

csrc/ops.h

Lines changed: 2 additions & 2 deletions

@@ -38,7 +38,7 @@ void paged_attention_v1(
     torch::Tensor& v_scale, const int64_t tp_rank,
     const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
-    const int64_t blocksparse_head_sliding_step, const int64_t num_threads);
+    const int64_t blocksparse_head_sliding_step);
 
 void paged_attention_v2(
     torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
@@ -50,7 +50,7 @@ void paged_attention_v2(
     torch::Tensor& v_scale, const int64_t tp_rank,
     const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
-    const int64_t blocksparse_head_sliding_step, const int64_t num_threads);
+    const int64_t blocksparse_head_sliding_step);
 
 #ifndef USE_ROCM
 void merge_attn_states(torch::Tensor& output,

csrc/torch_bindings.cpp

Lines changed: 2 additions & 4 deletions

@@ -47,8 +47,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
       " int tp_rank, int blocksparse_local_blocks,"
       " int blocksparse_vert_stride, int blocksparse_block_size,"
-      " int blocksparse_head_sliding_step,"
-      " int num_threads) -> ()");
+      " int blocksparse_head_sliding_step) -> ()");
   ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
 
   // PagedAttention V2.
@@ -62,8 +61,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
      " int tp_rank, int blocksparse_local_blocks,"
      " int blocksparse_vert_stride, int blocksparse_block_size,"
-      " int blocksparse_head_sliding_step,"
-      " int num_threads) -> ()");
+      " int blocksparse_head_sliding_step) -> ()");
   ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
 
 #ifndef USE_ROCM
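
With num_threads dropped from the registered schema, Python callers now pass only the arguments shown above. Below is a minimal sketch of the updated calling convention, following the argument order used by the opcheck tuples in tests/kernels/attention/test_attention.py further down; the wrapper name and the placeholder block-sparse values are illustrative, not part of this commit.

import torch

def call_paged_attention_v1(output, query, key_cache, value_cache,
                            num_kv_heads, scale, block_tables, seq_lens,
                            block_size, max_seq_len, alibi_slopes,
                            kv_cache_dtype, k_scale, v_scale):
    # Sketch only: assumes vLLM's _C extension is already loaded so that
    # torch.ops._C.paged_attention_v1 is registered. The trailing
    # num_threads int is gone; NUM_THREADS is now a compile-time template
    # default (128) in the CUDA launcher.
    torch.ops._C.paged_attention_v1(
        output, query, key_cache, value_cache, num_kv_heads, scale,
        block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
        kv_cache_dtype, k_scale, v_scale,
        0,    # tp_rank
        0,    # blocksparse_local_blocks
        0,    # blocksparse_vert_stride
        64,   # blocksparse_block_size
        0)    # blocksparse_head_sliding_step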

tests/entrypoints/llm/test_init.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

tests/kernels/attention/test_attention.py

Lines changed: 15 additions & 20 deletions

@@ -10,7 +10,7 @@
 from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
-from vllm.utils import get_max_shared_memory_bytes, is_navi
+from vllm.utils import get_max_shared_memory_bytes
 
 if not current_platform.is_rocm():
     from xformers import ops as xops
@@ -37,7 +37,7 @@
 
 # This should be sync with get_supported_head_sizes() in
 # vllm.attention.ops.paged_attn.PagedAttention
-HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
+HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]
 
 BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
@@ -195,10 +195,6 @@ def test_paged_attention(
     # Using default kv_scale
     k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
 
-    # additional argument for v1/v2 pa kernel
-    num_threads = 1024 if current_platform.is_rocm() \
-        and not is_navi() else 128
-
     # Call the paged attention kernel.
     output = torch.empty_like(query)
     if version == "v1":
@@ -219,12 +215,12 @@
             v_scale,
         )
 
-        opcheck(
-            torch.ops._C.paged_attention_v1,
-            (output, query, key_cache, value_cache, num_kv_heads, scale,
-             block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
-             kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0, num_threads),
-            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]))
+        opcheck(torch.ops._C.paged_attention_v1,
+                (output, query, key_cache, value_cache, num_kv_heads, scale,
+                 block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
+                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
+                cond=(head_size == HEAD_SIZES[0]
+                      and block_size == BLOCK_SIZES[0]))
 
     elif version in ("v2", "rocm"):
         if current_platform.is_rocm() and version == "rocm":
@@ -263,14 +259,13 @@ def test_paged_attention(
             v_scale,
         )
 
-        opcheck(
-            torch.ops._C.paged_attention_v2,
-            (output, exp_sums, max_logits, tmp_output, query, key_cache,
-             value_cache, num_kv_heads, scale, block_tables, seq_lens,
-             block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
-             k_scale, v_scale, 0, 0, 0, 64, 0, num_threads),
-            cond=(head_size == HEAD_SIZES[0]
-                  and block_size == BLOCK_SIZES[0]))
+        opcheck(torch.ops._C.paged_attention_v2,
+                (output, exp_sums, max_logits, tmp_output, query,
+                 key_cache, value_cache, num_kv_heads, scale, block_tables,
+                 seq_lens, block_size, max_seq_len, alibi_slopes,
+                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
+                cond=(head_size == HEAD_SIZES[0]
+                      and block_size == BLOCK_SIZES[0]))
 
     else:
         ops.paged_attention_rocm(

tests/kernels/attention/test_blocksparse_attention.py

Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@
 # There may not be enough gpu memory due to large NUM_BLOCKS.
 # Reduce NUM_BLOCKS when it happens.
 NUM_BLOCKS = 4321  # Arbitrary values for testing
-PARTITION_SIZE = 512 if not current_platform.is_rocm() else 1024
+PARTITION_SIZE = 512
 DTYPES = [torch.half, torch.bfloat16]
 NUM_GEN_SEQS = [3]  # Arbitrary values for testing
 NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing

vllm/_custom_ops.py

Lines changed: 2 additions & 6 deletions

@@ -64,9 +64,7 @@ def paged_attention_v1(
        seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
        k_scale, v_scale, tp_rank, blocksparse_local_blocks,
        blocksparse_vert_stride, blocksparse_block_size,
-        blocksparse_head_sliding_step,
-        num_threads = 1024 if current_platform.is_rocm() \
-            and not is_navi() else 128)
+        blocksparse_head_sliding_step)
 
 
 def paged_attention_v2(
@@ -98,9 +96,7 @@ def paged_attention_v2(
        num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
        alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
        blocksparse_local_blocks, blocksparse_vert_stride,
-        blocksparse_block_size, blocksparse_head_sliding_step,
-        num_threads = 1024 if current_platform.is_rocm() \
-            and not is_navi() else 128)
+        blocksparse_block_size, blocksparse_head_sliding_step)
 
 
 def paged_attention_rocm(

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 5 additions & 4 deletions

@@ -766,12 +766,13 @@ def forward(
                     query.dtype,
                     seq_lens,
                     make_attn_mask=causal_mask)  # type: ignore
+                use_fp8_scales = (layer._q_scale and layer._k_scale
+                                  and layer._v_scale and layer._prob_scale
+                                  and envs.VLLM_USE_ROCM_FP8_FLASH_ATTN)
                 full_scales = (
                     layer._q_scale.item(), layer._k_scale.item(),
-                    layer._v_scale.item(), layer._prob_scale.item()) if (
-                        layer._out_scale and layer._q_scale
-                        and layer._prob_scale
-                        and envs.VLLM_USE_ROCM_FP8_FLASH_ATTN) else None
+                    layer._v_scale.item(),
+                    layer._prob_scale.item()) if use_fp8_scales else None
                 self.triton_attn_func(
                     query,
                     key,

vllm/attention/ops/paged_attn.py

Lines changed: 1 addition & 3 deletions

@@ -6,15 +6,13 @@
 import torch
 
 from vllm import _custom_ops as ops
-from vllm.platforms import current_platform
 from vllm.triton_utils import HAS_TRITON
-from vllm.utils import is_navi
 
 if HAS_TRITON:
     from vllm.attention.ops.prefix_prefill import context_attention_fwd
 
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
-_PARTITION_SIZE = 512 if not current_platform.is_rocm() or is_navi() else 1024
+_PARTITION_SIZE = 512
 
 
 @dataclass
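
With the ROCm/Navi special case removed, _PARTITION_SIZE is a single constant again. As a rough illustration of what it controls (the helper below is hypothetical and not part of this diff), the v2 kernel's intermediate buffers are typically sized by the ceiling-division partition count the constant implies:

_PARTITION_SIZE = 512  # mirrors PARTITION_SIZE in paged_attention_v2_launcher

def num_v2_partitions(max_seq_len: int) -> int:
    # Hypothetical helper: split the longest sequence into fixed-size
    # partitions, rounding up, e.g. 4096 tokens -> 8, 4097 tokens -> 9.
    return (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE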
