 # FlashAttention forward only supports head dimension at most 128
 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
-HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256
-              ] if not is_hip() else [64, 80, 96, 112, 128]
+HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]

 BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
@@ -114,7 +113,8 @@ def ref_single_query_cached_kv_attention(
         output[i].copy_(out, non_blocking=True)


-@pytest.mark.parametrize("version", ["v1", "v2"])
+@pytest.mark.parametrize(
+    "version", ["v1", "v2"] if not is_hip() else ["v1", "v2", "rocm"])
 @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
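Note: the hunk above folds the ROCm kernel into the existing parametrized test by extending the version matrix only on HIP builds. A minimal, self-contained sketch of the same collection-time gating pattern; the `is_hip` helper here is an assumption (checking `torch.version.hip`), not the project's actual implementation:

```python
import pytest
import torch


def is_hip() -> bool:
    # Assumption: a HIP/ROCm build of PyTorch reports a non-None hip version.
    return torch.version.hip is not None


# The "rocm" case only enters the test matrix on ROCm builds; CUDA runs
# keep the original v1/v2 matrix.
@pytest.mark.parametrize(
    "version", ["v1", "v2"] if not is_hip() else ["v1", "v2", "rocm"])
def test_version_matrix(version: str) -> None:
    assert version in ("v1", "v2", "rocm")
```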
@@ -137,7 +137,8 @@ def test_paged_attention(
     seed: int,
     device: str,
 ) -> None:
-    if kv_cache_dtype == "fp8" and head_size % 16:
+    if ((kv_cache_dtype == "fp8" and head_size % 16)
+            or (version == "rocm" and head_size not in (64, 128))):
         pytest.skip()
     random.seed(seed)
     torch.random.manual_seed(seed)
@@ -208,7 +209,9 @@ def test_paged_attention(
                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                 cond=(head_size == HEAD_SIZES[0]))

-    elif version == "v2":
+    elif version in ("v2", "rocm"):
+        if is_hip():
+            PARTITION_SIZE = 1024 if version == "v2" else 512
         num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
         assert PARTITION_SIZE % block_size == 0
         num_seqs, num_heads, head_size = output.shape
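Note: the split-KV ("v2"-style) path divides each sequence into fixed-size partitions and allocates per-partition scratch buffers, and the kernel requires the partition size to divide evenly into cache blocks. A small illustration of that arithmetic, with made-up sizes:

```python
# Illustrative numbers only; 512 matches the ROCm partition size in this diff.
PARTITION_SIZE = 512
block_size = 16
max_seq_len = 1300

# Ceiling division: a 1300-token sequence needs 3 partitions of 512 tokens.
num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
assert PARTITION_SIZE % block_size == 0
assert num_partitions == 3
```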
@@ -221,32 +224,62 @@ def test_paged_attention(
             dtype=torch.float32,
         )
         max_logits = torch.empty_like(exp_sums)
-        ops.paged_attention_v2(
-            output,
-            exp_sums,
-            max_logits,
-            tmp_output,
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            scale,
-            block_tables,
-            seq_lens,
-            block_size,
-            max_seq_len,
-            alibi_slopes,
-            kv_cache_dtype,
-            k_scale,
-            v_scale,
-        )

-        opcheck(torch.ops._C.paged_attention_v2,
-                (output, exp_sums, max_logits, tmp_output, query, key_cache,
-                 value_cache, num_kv_heads, scale, block_tables, seq_lens,
-                 block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
-                 k_scale, v_scale, 0, 0, 0, 64, 0),
-                cond=(head_size == HEAD_SIZES[0]))
+        if version == "v2":
+            ops.paged_attention_v2(
+                output,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                scale,
+                block_tables,
+                seq_lens,
+                block_size,
+                max_seq_len,
+                alibi_slopes,
+                kv_cache_dtype,
+                k_scale,
+                v_scale,
+            )
+
+            opcheck(torch.ops._C.paged_attention_v2,
+                    (output, exp_sums, max_logits, tmp_output, query,
+                     key_cache, value_cache, num_kv_heads, scale, block_tables,
+                     seq_lens, block_size, max_seq_len, alibi_slopes,
+                     kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
+                    cond=(head_size == HEAD_SIZES[0]))
+
+        else:
+            ops.paged_attention_rocm(
+                output,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                scale,
+                block_tables,
+                seq_lens,
+                block_size,
+                max_seq_len,
+                alibi_slopes,
+                kv_cache_dtype,
+                k_scale,
+                v_scale,
+            )
+
+            opcheck(torch.ops._rocm_C.paged_attention,
+                    (output, exp_sums, max_logits, tmp_output, query,
+                     key_cache, value_cache, num_kv_heads, scale, block_tables,
+                     seq_lens, block_size, max_seq_len, alibi_slopes,
+                     kv_cache_dtype, k_scale, v_scale),
+                    cond=(head_size == 64))

     else:
         raise AssertionError(f"Unknown version: {version}")
@@ -330,173 +363,15 @@ def ref_multi_query_kv_attention(
     return torch.cat(ref_outputs, dim=0)


-@pytest.mark.parametrize("version", ["rocm"])
-@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
-@pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("head_size", [64, 128])  # only test 64 128
-@pytest.mark.parametrize("use_alibi", USE_ALIBI)
-@pytest.mark.parametrize("block_size", BLOCK_SIZES)
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("kv_cache_dtype", ["auto"])
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.skipif(not is_hip(), reason="only for rocm")
-def test_paged_attention_rocm(
-    kv_cache_factory,
-    version: str,
-    num_seqs: int,
-    num_heads: Tuple[int, int],
-    head_size: int,
-    use_alibi: bool,
-    block_size: int,
-    dtype: torch.dtype,
-    kv_cache_dtype: str,
-    seed: int,
-    device: str,
-) -> None:
-    random.seed(seed)
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
-    torch.set_default_device(device)
-    scale = float(1.0 / (head_size**0.5))
-    num_query_heads, num_kv_heads = num_heads
-    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
-    query.uniform_(-scale, scale)
-
-    assert num_query_heads % num_kv_heads == 0
-    num_queries_per_kv = num_query_heads // num_kv_heads
-    alibi_slopes = None
-    if use_alibi:
-        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
-
-    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
-    context_lens[-1] = MAX_SEQ_LEN
-    #context_lens = [8192 for _ in range(num_seqs)]
-    max_context_len = max(context_lens)
-    context_lens = torch.tensor(context_lens, dtype=torch.int)
-    #print('>>> ctx lens', context_lens)
-
-    # Create the block tables.
-    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
-    block_tables = []
-    for _ in range(num_seqs):
-        block_table = [
-            random.randint(0, NUM_BLOCKS - 1)
-            for _ in range(max_num_blocks_per_seq)
-        ]
-        block_tables.append(block_table)
-    block_tables = torch.tensor(block_tables, dtype=torch.int)
-
-    # Create the KV caches.
-    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
-                                                num_kv_heads, head_size,
-                                                kv_cache_dtype, dtype, seed,
-                                                device)
-    key_cache, value_cache = key_caches[0], value_caches[0]
-
-    # TODO(charlifu) enable fp8 kv cache
-    # Using default kv_scale
-    # kv_scale = 1.0
-
-    # Call the paged attention kernel.
-    output = torch.empty_like(query)
-    PARTITION_SIZE_ROCM = 256
-    num_partitions = ((max_context_len + PARTITION_SIZE_ROCM - 1) //
-                      PARTITION_SIZE_ROCM)
-    assert PARTITION_SIZE_ROCM % block_size == 0
-    num_seqs, num_heads, head_size = output.shape
-    tmp_output = torch.empty(
-        size=(num_seqs, num_heads, num_partitions, head_size),
-        dtype=output.dtype,
-    )
-    exp_sums = torch.empty(
-        size=(num_seqs, num_heads, num_partitions),
-        dtype=torch.float32,
-    )
-    max_logits = torch.empty_like(exp_sums)
-    if version == "rocm":
-        ops.paged_attention_rocm(
-            output,
-            exp_sums,
-            max_logits,
-            tmp_output,
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            scale,
-            block_tables,
-            context_lens,
-            block_size,
-            max_context_len,
-            alibi_slopes,
-            kv_cache_dtype,
-        )
-    else:
-        raise AssertionError(f"Unknown version: {version}")
-
-    # Run the reference implementation.
-    if kv_cache_dtype == "fp8":
-        # Convert cache data back to dtype.
-        x = 16 // torch.tensor([], dtype=dtype).element_size()
-        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
-                           block_size, x)
-        dequantized_key_cache = torch.empty(size=key_cache_shape,
-                                            dtype=dtype,
-                                            device=device)
-        ops.convert_fp8(key_cache, dequantized_key_cache)
-        key_cache = dequantized_key_cache
-
-        value_cache_shape = value_cache.shape
-        dequantized_value_cache = torch.empty(size=value_cache_shape,
-                                              dtype=dtype,
-                                              device=device)
-        ops.convert_fp8(value_cache, dequantized_value_cache)
-        value_cache = dequantized_value_cache
-
-    ref_output = torch.empty_like(query)
-    ref_single_query_cached_kv_attention(
-        ref_output,
-        query,
-        num_queries_per_kv,
-        key_cache,
-        value_cache,
-        block_tables,
-        context_lens,
-        scale,
-        alibi_slopes,
-    )
-
-    # NOTE(woosuk): Due to the kernel-level differences in the two
-    # implementations, there is a small numerical difference in the two
-    # outputs. Thus, we use a relaxed tolerance for the test.
-    atol = get_default_atol(output) if is_hip() else 1e-3
-    rtol = get_default_rtol(output) if is_hip() else 1e-5
-
-    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
-    # so we use a relaxed tolerance for the test.
-    atol, rtol = 1e-4, 1e-5
-    if dtype == torch.bfloat16:
-        atol, rtol = 2e-4, 1e-5
-    if use_alibi:
-        if dtype == torch.half:
-            atol, rtol = 5e-4, 1e-5
-        if dtype == torch.bfloat16:
-            atol, rtol = 1e-3, 1e-5
-    if kv_cache_dtype == "fp8":
-        atol, rtol = 1e-2, 1e-5
-    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
-
-
 # TODO(woosuk): Add tests for USE_ALIBI=True.
 @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.skipif(is_hip(), reason="skip for rocm")
+@pytest.mark.skipif(is_hip(),
+                    reason="Xformers backend is not supported on ROCm.")
 @torch.inference_mode()
 def test_multi_query_kv_attention(
     num_seqs: int,