Commit 434ac76

[cpu][ci] Add CPU Attention Tests for Neon Backend (vllm-project#30347)
Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
1 parent ed7af31

1 file changed: 63 additions & 10 deletions

tests/kernels/attention/test_cpu_attn.py

@@ -7,7 +7,8 @@
 import pytest
 import torch

-from vllm.platforms import current_platform
+from vllm.platforms import CpuArchEnum, current_platform
+from vllm.v1.attention.backends.cpu_attn import _get_attn_isa

 if not current_platform.is_cpu():
     pytest.skip("skipping CPU-only tests", allow_module_level=True)
@@ -36,6 +37,21 @@
 ]


+def get_attn_isa(
+    block_size: int | None = None,
+    dtype: torch.dtype | None = None,
+):
+    if block_size and dtype:
+        return _get_attn_isa(dtype, block_size)
+    else:
+        if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
+            return "neon"
+        elif torch._C._cpu._is_amx_tile_supported():
+            return "amx"
+        else:
+            return "vec"
+
+
 # rand number generation takes too much time, cache rand tensors
 @functools.lru_cache(maxsize=128, typed=False)
 def tensor_cache(
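Reviewer note: a minimal sketch of how the new helper resolves, restating only the logic in the hunk above. It assumes vLLM's CPU build is installed and that it runs with get_attn_isa in scope (e.g., inside this test module); the block size and dtype in the explicit path are illustrative values, not ones prescribed by the tests.

import torch

# Heuristic path: with no arguments the helper probes the host CPU.
# Arm hosts resolve to "neon", AMX-capable x86 hosts to "amx",
# and everything else to the generic "vec" kernels.
host_isa = get_attn_isa()

# Explicit path: given both block_size and dtype, the helper defers to
# the backend's own _get_attn_isa(dtype, block_size) selection.
# 128 / bfloat16 are illustrative only.
cfg_isa = get_attn_isa(block_size=128, dtype=torch.bfloat16)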
@@ -452,6 +468,49 @@ def test_varlen_with_paged_kv_normal_vec16(
     )


+@pytest.mark.parametrize("seq_lens", SEQ_LENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", [96, 128])
+@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
+@pytest.mark.parametrize("dtype", QTYPES)
+@pytest.mark.parametrize("soft_cap", [None])
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("use_alibi", [False])
+@pytest.mark.parametrize("use_sink", [False])
+@pytest.mark.parametrize("isa", ["neon"])
+@pytest.mark.skipif(
+    current_platform.get_cpu_architecture() != CpuArchEnum.ARM,
+    reason="Not an Arm CPU.",
+)
+def test_varlen_with_paged_kv_normal_neon(
+    seq_lens: list[tuple[int, int]],
+    num_heads: tuple[int, int],
+    head_size: int,
+    sliding_window: int | None,
+    dtype: torch.dtype,
+    block_size: int,
+    soft_cap: float | None,
+    num_blocks: int,
+    use_alibi: bool,
+    use_sink: bool,
+    isa: str,
+) -> None:
+    varlen_with_paged_kv(
+        seq_lens=seq_lens,
+        num_heads=num_heads,
+        head_size=head_size,
+        sliding_window=sliding_window,
+        dtype=dtype,
+        block_size=block_size,
+        soft_cap=soft_cap,
+        num_blocks=num_blocks,
+        use_alibi=use_alibi,
+        use_sink=use_sink,
+        isa=isa,
+    )
+
+
 @pytest.mark.parametrize("seq_lens", SEQ_LENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", [96])
@@ -462,9 +521,7 @@ def test_varlen_with_paged_kv_normal_vec16(
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("use_alibi", [False])
 @pytest.mark.parametrize("use_sink", [False])
-@pytest.mark.parametrize(
-    "isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
-)
+@pytest.mark.parametrize("isa", [get_attn_isa()])
 def test_varlen_with_paged_kv_softcap(
     seq_lens: list[tuple[int, int]],
     num_heads: tuple[int, int],
@@ -503,9 +560,7 @@ def test_varlen_with_paged_kv_softcap(
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("use_alibi", [True])
 @pytest.mark.parametrize("use_sink", [False])
-@pytest.mark.parametrize(
-    "isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
-)
+@pytest.mark.parametrize("isa", [get_attn_isa()])
 def test_varlen_with_paged_kv_alibi(
     seq_lens: list[tuple[int, int]],
     num_heads: tuple[int, int],
@@ -544,9 +599,7 @@ def test_varlen_with_paged_kv_alibi(
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("use_alibi", [False])
 @pytest.mark.parametrize("use_sink", [True])
-@pytest.mark.parametrize(
-    "isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
-)
+@pytest.mark.parametrize("isa", [get_attn_isa()])
 def test_varlen_with_paged_kv_sink(
     seq_lens: list[tuple[int, int]],
     num_heads: tuple[int, int],
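Net effect of the three hunks above: the softcap, alibi, and sink suites no longer hard-code ["amx"] versus ["vec"] at collection time; they parametrize on get_attn_isa(), which by the helper's logic also resolves to "neon" on Arm hosts.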
