1010from tests .kernels .utils import opcheck
1111from vllm import _custom_ops as ops
1212from vllm .platforms import current_platform
13- from vllm .utils import get_max_shared_memory_bytes , is_navi
13+ from vllm .utils import get_max_shared_memory_bytes
1414
1515if not current_platform .is_rocm ():
1616 from xformers import ops as xops
3737
3838# This should be sync with get_supported_head_sizes() in
3939# vllm.attention.ops.paged_attn.PagedAttention
40- HEAD_SIZES = [64 , 80 , 96 , 112 , 120 , 128 , 192 , 256 ]
40+ HEAD_SIZES = [32 , 64 , 80 , 96 , 112 , 120 , 128 , 192 , 256 ]
4141
4242BLOCK_SIZES = [16 , 32 ]
4343USE_ALIBI = [False , True ]
@@ -195,10 +195,6 @@ def test_paged_attention(
195195 # Using default kv_scale
196196 k_scale = v_scale = torch .tensor (1.0 , dtype = torch .float32 , device = device )
197197
198- # additional argument for v1/v2 pa kernel
199- num_threads = 1024 if current_platform .is_rocm () \
200- and not is_navi () else 128
201-
202198 # Call the paged attention kernel.
203199 output = torch .empty_like (query )
204200 if version == "v1" :
@@ -219,12 +215,12 @@ def test_paged_attention(
219215 v_scale ,
220216 )
221217
222- opcheck (
223- torch . ops . _C . paged_attention_v1 ,
224- ( output , query , key_cache , value_cache , num_kv_heads , scale ,
225- block_tables , seq_lens , block_size , max_seq_len , alibi_slopes ,
226- kv_cache_dtype , k_scale , v_scale , 0 , 0 , 0 , 64 , 0 , num_threads ),
227- cond = ( head_size == HEAD_SIZES [ 0 ] and block_size == BLOCK_SIZES [0 ]))
218+ opcheck (torch . ops . _C . paged_attention_v1 ,
219+ ( output , query , key_cache , value_cache , num_kv_heads , scale ,
220+ block_tables , seq_lens , block_size , max_seq_len , alibi_slopes ,
221+ kv_cache_dtype , k_scale , v_scale , 0 , 0 , 0 , 64 , 0 ) ,
222+ cond = ( head_size == HEAD_SIZES [ 0 ]
223+ and block_size == BLOCK_SIZES [0 ]))
228224
229225 elif version in ("v2" , "rocm" ):
230226 if current_platform .is_rocm () and version == "rocm" :
@@ -263,14 +259,13 @@ def test_paged_attention(
263259 v_scale ,
264260 )
265261
266- opcheck (
267- torch .ops ._C .paged_attention_v2 ,
268- (output , exp_sums , max_logits , tmp_output , query , key_cache ,
269- value_cache , num_kv_heads , scale , block_tables , seq_lens ,
270- block_size , max_seq_len , alibi_slopes , kv_cache_dtype ,
271- k_scale , v_scale , 0 , 0 , 0 , 64 , 0 , num_threads ),
272- cond = (head_size == HEAD_SIZES [0 ]
273- and block_size == BLOCK_SIZES [0 ]))
262+ opcheck (torch .ops ._C .paged_attention_v2 ,
263+ (output , exp_sums , max_logits , tmp_output , query ,
264+ key_cache , value_cache , num_kv_heads , scale , block_tables ,
265+ seq_lens , block_size , max_seq_len , alibi_slopes ,
266+ kv_cache_dtype , k_scale , v_scale , 0 , 0 , 0 , 64 , 0 ),
267+ cond = (head_size == HEAD_SIZES [0 ]
268+ and block_size == BLOCK_SIZES [0 ]))
274269
275270 else :
276271 ops .paged_attention_rocm (
0 commit comments