This repository was archived by the owner on Sep 4, 2025. It is now read-only.

Commit a67b65b

Fix PA custom and PA v2 tests and partition sizes (#196)
* Update custom PA kernel with support for fp8 kv cache dtype; change custom PA partition size to 512 to prefer throughput scenarios at the cost of latency
* Fix lint
* Fix BF16 with FP8 KV cache (scaled conversion incorrectly done in fp16)
* Fix custom PA tests
* Merge branch 'main' of git@github.com:ROCm/vllm.git into mawong/fix_custom_pa_tests
* Fix partition sizes for PA v2 and PA custom
* Fix linting
* Fix a few names and variable scopes
* Rename custom to rocm as per suggestion

---------

Co-authored-by: Shomy Sanyal <[email protected]>
1 parent d21cf99 commit a67b65b
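
Note on the new partition sizes: the constants below are read off the diffs that follow, but the helper itself is only an illustrative sketch; the function name `select_partition_size` and its arguments are hypothetical and not part of vLLM.

```python
# Hypothetical helper summarizing the partition-size policy after this commit.
# The constants mirror the diffs below; the function is not repository code.
def select_partition_size(version: str, on_rocm: bool) -> int:
    """Partition size used by the paged-attention kernel variants."""
    if on_rocm and version == "v2":
        return 1024  # PA v2 on ROCm switches to larger partitions
    return 512       # ROCm custom ("rocm") kernel and the benchmark default
```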

File tree

7 files changed: +81 -444 lines changed

benchmarks/kernels/benchmark_paged_attention.py

Lines changed: 4 additions & 4 deletions
@@ -6,10 +6,10 @@
 
 from vllm import _custom_ops as ops
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
-                        create_kv_caches_with_random)
+                        create_kv_caches_with_random, is_hip)
 
 NUM_BLOCKS = 1024 * 1024
-PARTITION_SIZE = 256
+PARTITION_SIZE = 512
 
 
 @torch.inference_mode()
@@ -80,9 +80,9 @@ def main(
     # Prepare for the paged attention kernel.
     output = torch.empty_like(query)
     if version == "v2":
-        if not args.custom_paged_attn:
+        if is_hip() and not args.custom_paged_attn:
             global PARTITION_SIZE
-            PARTITION_SIZE = 512
+            PARTITION_SIZE = 1024
         num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
         tmp_output = torch.empty(
             size=(num_seqs, num_query_heads, num_partitions, head_size),
tests/kernels/test_attention.py

Lines changed: 65 additions & 190 deletions
@@ -31,8 +31,7 @@
 
 # FlashAttention forward only supports head dimension at most 128
 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
-HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256
-              ] if not is_hip() else [64, 80, 96, 112, 128]
+HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
 
 BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
@@ -114,7 +113,8 @@ def ref_single_query_cached_kv_attention(
         output[i].copy_(out, non_blocking=True)
 
 
-@pytest.mark.parametrize("version", ["v1", "v2"])
+@pytest.mark.parametrize(
+    "version", ["v1", "v2"] if not is_hip() else ["v1", "v2", "rocm"])
 @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -137,7 +137,8 @@ def test_paged_attention(
     seed: int,
     device: str,
 ) -> None:
-    if kv_cache_dtype == "fp8" and head_size % 16:
+    if ((kv_cache_dtype == "fp8" and head_size % 16)
+            or (version == "rocm" and head_size not in (64, 128))):
         pytest.skip()
     random.seed(seed)
     torch.random.manual_seed(seed)
@@ -208,7 +209,9 @@ def test_paged_attention(
                  kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                 cond=(head_size == HEAD_SIZES[0]))
 
-    elif version == "v2":
+    elif version in ("v2", "rocm"):
+        if is_hip():
+            PARTITION_SIZE = 1024 if version == "v2" else 512
         num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
         assert PARTITION_SIZE % block_size == 0
         num_seqs, num_heads, head_size = output.shape
@@ -221,32 +224,62 @@ def test_paged_attention(
             dtype=torch.float32,
         )
         max_logits = torch.empty_like(exp_sums)
-        ops.paged_attention_v2(
-            output,
-            exp_sums,
-            max_logits,
-            tmp_output,
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            scale,
-            block_tables,
-            seq_lens,
-            block_size,
-            max_seq_len,
-            alibi_slopes,
-            kv_cache_dtype,
-            k_scale,
-            v_scale,
-        )
 
-        opcheck(torch.ops._C.paged_attention_v2,
-                (output, exp_sums, max_logits, tmp_output, query, key_cache,
-                 value_cache, num_kv_heads, scale, block_tables, seq_lens,
-                 block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
-                 k_scale, v_scale, 0, 0, 0, 64, 0),
-                cond=(head_size == HEAD_SIZES[0]))
+        if version == "v2":
+            ops.paged_attention_v2(
+                output,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                scale,
+                block_tables,
+                seq_lens,
+                block_size,
+                max_seq_len,
+                alibi_slopes,
+                kv_cache_dtype,
+                k_scale,
+                v_scale,
+            )
+
+            opcheck(torch.ops._C.paged_attention_v2,
+                    (output, exp_sums, max_logits, tmp_output, query,
+                     key_cache, value_cache, num_kv_heads, scale, block_tables,
+                     seq_lens, block_size, max_seq_len, alibi_slopes,
+                     kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
+                    cond=(head_size == HEAD_SIZES[0]))
+
+        else:
+            ops.paged_attention_rocm(
+                output,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                scale,
+                block_tables,
+                seq_lens,
+                block_size,
+                max_seq_len,
+                alibi_slopes,
+                kv_cache_dtype,
+                k_scale,
+                v_scale,
+            )
+
+            opcheck(torch.ops._rocm_C.paged_attention,
+                    (output, exp_sums, max_logits, tmp_output, query,
+                     key_cache, value_cache, num_kv_heads, scale, block_tables,
+                     seq_lens, block_size, max_seq_len, alibi_slopes,
+                     kv_cache_dtype, k_scale, v_scale),
+                    cond=(head_size == 64))
 
     else:
         raise AssertionError(f"Unknown version: {version}")
@@ -330,173 +363,15 @@ def ref_multi_query_kv_attention(
     return torch.cat(ref_outputs, dim=0)
 
 
-@pytest.mark.parametrize("version", ["rocm"])
-@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
-@pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("head_size", [64, 128])  # only test 64 128
-@pytest.mark.parametrize("use_alibi", USE_ALIBI)
-@pytest.mark.parametrize("block_size", BLOCK_SIZES)
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("kv_cache_dtype", ["auto"])
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.skipif(not is_hip(), reason="only for rocm")
-def test_paged_attention_rocm(
-    kv_cache_factory,
-    version: str,
-    num_seqs: int,
-    num_heads: Tuple[int, int],
-    head_size: int,
-    use_alibi: bool,
-    block_size: int,
-    dtype: torch.dtype,
-    kv_cache_dtype: str,
-    seed: int,
-    device: str,
-) -> None:
-    random.seed(seed)
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
-    torch.set_default_device(device)
-    scale = float(1.0 / (head_size**0.5))
-    num_query_heads, num_kv_heads = num_heads
-    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
-    query.uniform_(-scale, scale)
-
-    assert num_query_heads % num_kv_heads == 0
-    num_queries_per_kv = num_query_heads // num_kv_heads
-    alibi_slopes = None
-    if use_alibi:
-        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
-
-    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
-    context_lens[-1] = MAX_SEQ_LEN
-    #context_lens = [8192 for _ in range(num_seqs)]
-    max_context_len = max(context_lens)
-    context_lens = torch.tensor(context_lens, dtype=torch.int)
-    #print('>>> ctx lens', context_lens)
-
-    # Create the block tables.
-    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
-    block_tables = []
-    for _ in range(num_seqs):
-        block_table = [
-            random.randint(0, NUM_BLOCKS - 1)
-            for _ in range(max_num_blocks_per_seq)
-        ]
-        block_tables.append(block_table)
-    block_tables = torch.tensor(block_tables, dtype=torch.int)
-
-    # Create the KV caches.
-    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
-                                                num_kv_heads, head_size,
-                                                kv_cache_dtype, dtype, seed,
-                                                device)
-    key_cache, value_cache = key_caches[0], value_caches[0]
-
-    # TODO(charlifu) enable fp8 kv cache
-    # Using default kv_scale
-    # kv_scale = 1.0
-
-    # Call the paged attention kernel.
-    output = torch.empty_like(query)
-    PARTITION_SIZE_ROCM = 256
-    num_partitions = ((max_context_len + PARTITION_SIZE_ROCM - 1) //
-                      PARTITION_SIZE_ROCM)
-    assert PARTITION_SIZE_ROCM % block_size == 0
-    num_seqs, num_heads, head_size = output.shape
-    tmp_output = torch.empty(
-        size=(num_seqs, num_heads, num_partitions, head_size),
-        dtype=output.dtype,
-    )
-    exp_sums = torch.empty(
-        size=(num_seqs, num_heads, num_partitions),
-        dtype=torch.float32,
-    )
-    max_logits = torch.empty_like(exp_sums)
-    if version == "rocm":
-        ops.paged_attention_rocm(
-            output,
-            exp_sums,
-            max_logits,
-            tmp_output,
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            scale,
-            block_tables,
-            context_lens,
-            block_size,
-            max_context_len,
-            alibi_slopes,
-            kv_cache_dtype,
-        )
-    else:
-        raise AssertionError(f"Unknown version: {version}")
-
-    # Run the reference implementation.
-    if kv_cache_dtype == "fp8":
-        # Convert cache data back to dtype.
-        x = 16 // torch.tensor([], dtype=dtype).element_size()
-        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
-                           block_size, x)
-        dequantized_key_cache = torch.empty(size=key_cache_shape,
-                                            dtype=dtype,
-                                            device=device)
-        ops.convert_fp8(key_cache, dequantized_key_cache)
-        key_cache = dequantized_key_cache
-
-        value_cache_shape = value_cache.shape
-        dequantized_value_cache = torch.empty(size=value_cache_shape,
-                                              dtype=dtype,
-                                              device=device)
-        ops.convert_fp8(value_cache, dequantized_value_cache)
-        value_cache = dequantized_value_cache
-
-    ref_output = torch.empty_like(query)
-    ref_single_query_cached_kv_attention(
-        ref_output,
-        query,
-        num_queries_per_kv,
-        key_cache,
-        value_cache,
-        block_tables,
-        context_lens,
-        scale,
-        alibi_slopes,
-    )
-
-    # NOTE(woosuk): Due to the kernel-level differences in the two
-    # implementations, there is a small numerical difference in the two
-    # outputs. Thus, we use a relaxed tolerance for the test.
-    atol = get_default_atol(output) if is_hip() else 1e-3
-    rtol = get_default_rtol(output) if is_hip() else 1e-5
-
-    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
-    # so we use a relaxed tolerance for the test.
-    atol, rtol = 1e-4, 1e-5
-    if dtype == torch.bfloat16:
-        atol, rtol = 2e-4, 1e-5
-    if use_alibi:
-        if dtype == torch.half:
-            atol, rtol = 5e-4, 1e-5
-        if dtype == torch.bfloat16:
-            atol, rtol = 1e-3, 1e-5
-    if kv_cache_dtype == "fp8":
-        atol, rtol = 1e-2, 1e-5
-    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
-
-
 # TODO(woosuk): Add tests for USE_ALIBI=True.
 @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.skipif(is_hip(), reason="skip for rocm")
+@pytest.mark.skipif(is_hip(),
+                    reason="Xformers backend is not supported on ROCm.")
 @torch.inference_mode()
 def test_multi_query_kv_attention(
     num_seqs: int,
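
With the standalone test_paged_attention_rocm removed, the ROCm kernel is now exercised through the unified test_paged_attention via the "rocm" version parameter. The sketch below is a hypothetical way to run just that slice; it assumes pytest is installed, a ROCm build of vLLM, and that the -k filter matches the "rocm" parametrization id.

```python
# Hypothetical runner: execute only the "rocm" parametrizations of the
# unified paged-attention test. Not part of the repository.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(pytest.main([
        "tests/kernels/test_attention.py::test_paged_attention",
        "-k", "rocm",  # selects test ids whose version parameter is "rocm"
        "-q",
    ]))
```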
