77import pytest
88import torch
99
10- from vllm .platforms import current_platform
10+ from vllm .platforms import CpuArchEnum , current_platform
11+ from vllm .v1 .attention .backends .cpu_attn import _get_attn_isa
1112
# Skip the whole module at collection time unless the current platform is a
# CPU platform; allow_module_level=True is what permits calling pytest.skip
# outside of a test function.
if not current_platform.is_cpu():
    pytest.skip("skipping CPU-only tests", allow_module_level=True)
3637]
3738
3839
def get_attn_isa(
    block_size: int | None = None,
    dtype: torch.dtype | None = None,
):
    """Pick the attention ISA name used to parametrize a test.

    When both ``block_size`` and ``dtype`` are supplied, the decision is
    delegated to ``_get_attn_isa(dtype, block_size)``. Otherwise (including
    when only one of the two is given) a host-based default is returned:
    ``"neon"`` on Arm, ``"amx"`` when torch reports AMX tile support, and
    ``"vec"`` in every other case.

    Args:
        block_size: Block size forwarded to ``_get_attn_isa``.
        dtype: Dtype forwarded to ``_get_attn_isa``.

    Returns:
        The result of ``_get_attn_isa``, or one of the default strings above.
    """
    # Compare against None explicitly so that a provided-but-falsy argument
    # is forwarded to the backend helper instead of being silently replaced
    # by the host default.
    if block_size is not None and dtype is not None:
        return _get_attn_isa(dtype, block_size)

    if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
        return "neon"
    if torch._C._cpu._is_amx_tile_supported():
        return "amx"
    return "vec"
53+
54+
# Random number generation is slow, so cache the random tensors.
4056@functools .lru_cache (maxsize = 128 , typed = False )
4157def tensor_cache (
@@ -452,6 +468,49 @@ def test_varlen_with_paged_kv_normal_vec16(
452468 )
453469
454470
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", [96, 128])
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("dtype", QTYPES)
@pytest.mark.parametrize("soft_cap", [None])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("use_sink", [False])
@pytest.mark.parametrize("isa", ["neon"])
@pytest.mark.skipif(
    current_platform.get_cpu_architecture() != CpuArchEnum.ARM,
    reason="Not an Arm CPU.",
)
def test_varlen_with_paged_kv_normal_neon(
    seq_lens: list[tuple[int, int]],
    num_heads: tuple[int, int],
    head_size: int,
    sliding_window: int | None,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: float | None,
    num_blocks: int,
    use_alibi: bool,
    use_sink: bool,
    isa: str,
) -> None:
    """Run ``varlen_with_paged_kv`` with the "neon" ISA.

    Skipped unless the reported CPU architecture is Arm. Every parametrized
    argument is forwarded unchanged, by keyword, to ``varlen_with_paged_kv``.
    """
    # Collect the parametrized values under the keyword names the shared
    # helper is called with, then forward them in one go.
    case = {
        "seq_lens": seq_lens,
        "num_heads": num_heads,
        "head_size": head_size,
        "sliding_window": sliding_window,
        "dtype": dtype,
        "block_size": block_size,
        "soft_cap": soft_cap,
        "num_blocks": num_blocks,
        "use_alibi": use_alibi,
        "use_sink": use_sink,
        "isa": isa,
    }
    varlen_with_paged_kv(**case)
512+
513+
455514@pytest .mark .parametrize ("seq_lens" , SEQ_LENS )
456515@pytest .mark .parametrize ("num_heads" , NUM_HEADS )
457516@pytest .mark .parametrize ("head_size" , [96 ])
@@ -462,9 +521,7 @@ def test_varlen_with_paged_kv_normal_vec16(
462521@pytest .mark .parametrize ("num_blocks" , NUM_BLOCKS )
463522@pytest .mark .parametrize ("use_alibi" , [False ])
464523@pytest .mark .parametrize ("use_sink" , [False ])
465- @pytest .mark .parametrize (
466- "isa" , ["amx" ] if torch ._C ._cpu ._is_amx_tile_supported () else ["vec" ]
467- )
524+ @pytest .mark .parametrize ("isa" , [get_attn_isa ()])
468525def test_varlen_with_paged_kv_softcap (
469526 seq_lens : list [tuple [int , int ]],
470527 num_heads : tuple [int , int ],
@@ -503,9 +560,7 @@ def test_varlen_with_paged_kv_softcap(
503560@pytest .mark .parametrize ("num_blocks" , NUM_BLOCKS )
504561@pytest .mark .parametrize ("use_alibi" , [True ])
505562@pytest .mark .parametrize ("use_sink" , [False ])
506- @pytest .mark .parametrize (
507- "isa" , ["amx" ] if torch ._C ._cpu ._is_amx_tile_supported () else ["vec" ]
508- )
563+ @pytest .mark .parametrize ("isa" , [get_attn_isa ()])
509564def test_varlen_with_paged_kv_alibi (
510565 seq_lens : list [tuple [int , int ]],
511566 num_heads : tuple [int , int ],
@@ -544,9 +599,7 @@ def test_varlen_with_paged_kv_alibi(
544599@pytest .mark .parametrize ("num_blocks" , NUM_BLOCKS )
545600@pytest .mark .parametrize ("use_alibi" , [False ])
546601@pytest .mark .parametrize ("use_sink" , [True ])
547- @pytest .mark .parametrize (
548- "isa" , ["amx" ] if torch ._C ._cpu ._is_amx_tile_supported () else ["vec" ]
549- )
602+ @pytest .mark .parametrize ("isa" , [get_attn_isa ()])
550603def test_varlen_with_paged_kv_sink (
551604 seq_lens : list [tuple [int , int ]],
552605 num_heads : tuple [int , int ],
0 commit comments