@@ -73,11 +73,14 @@ def ref_paged_attn(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
 @torch.inference_mode
-def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
-                                         num_heads: Tuple[int,
-                                                          int], head_size: int,
-                                         dtype: torch.dtype, block_size: int,
-                                         soft_cap: Optional[float]) -> None:
+def test_flashinfer_decode_with_paged_kv(
+    kv_lens: List[int],
+    num_heads: Tuple[int, int],
+    head_size: int,
+    dtype: torch.dtype,
+    block_size: int,
+    soft_cap: Optional[float],
+) -> None:
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
     num_seqs = len(kv_lens)
@@ -88,6 +91,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
     scale = head_size**-0.5
 
     query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
+
     key_value_cache = torch.randn(NUM_BLOCKS,
                                   2,
                                   block_size,
@@ -125,7 +129,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
     wrapper = flashinfer.\
         BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
                 use_tensor_cores=(
-                    (num_query_heads // num_kv_heads) not in (1, 2, 4, 8))
+                    (num_query_heads // num_kv_heads) > 4)
                 )
     wrapper.begin_forward(kv_indptr,
                           kv_indices,
@@ -249,3 +253,215 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
                                 soft_cap=soft_cap)
     torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
+
+
+@pytest.mark.parametrize("seq_lens", [[(1, 132), (5, 18)]])
+@pytest.mark.parametrize("num_heads", [(32, 8), (6, 1)])
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
+def test_flashinfer_prefill_with_paged_fp8_kv(
+        seq_lens: List[Tuple[int, int]], num_heads: Tuple[int, int],
+        head_size: int, dtype: torch.dtype, block_size: int,
+        soft_cap: Optional[float]) -> None:
+    torch.set_default_device("cuda")
+    torch.cuda.manual_seed_all(0)
+    num_seqs = len(seq_lens)
+    query_lens = [x[0] for x in seq_lens]
+    kv_lens = [x[1] for x in seq_lens]
+    num_query_heads = num_heads[0]
+    num_kv_heads = num_heads[1]
+    assert num_query_heads % num_kv_heads == 0
+    max_kv_len = max(kv_lens)
+    scale = head_size**-0.5
+
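+    # The paged KV cache is stored in FP8 (e4m3) for this test.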
+    kv_cache_dtype = torch.float8_e4m3fn
+
+    query = torch.randn(sum(query_lens),
+                        num_query_heads,
+                        head_size,
+                        dtype=dtype)
+    NUM_BLOCKS_FP8 = 2048
+    key_value_cache = torch.randn(NUM_BLOCKS_FP8,
+                                  2,
+                                  block_size,
+                                  num_kv_heads,
+                                  head_size,
+                                  dtype=dtype)
+    key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1)
+    key_cache /= head_size**0.5
+    value_cache /= head_size**0.5
+
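+    # Per-tensor scales: 448.0 is the largest finite value of float8_e4m3fn.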
+    k_scale = key_cache.amax().item() / 448.0
+    v_scale = value_cache.amax().item() / 448.0
+
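+    # Quantize K and V into a single paged FP8 cache; the same scales are
+    # handed to the wrapper at forward time so the kernel can rescale.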
+    kv_cache_fp8 = torch.cat([key_cache / k_scale, value_cache / v_scale],
+                             dim=1).to(kv_cache_dtype)
+
+    assert (kv_cache_fp8.shape == key_value_cache.shape)
+    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
+    block_tables = torch.randint(0,
+                                 NUM_BLOCKS_FP8,
+                                 (num_seqs, max_num_blocks_per_seq),
+                                 dtype=torch.int32)
+
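+    # Build the flashinfer paged-KV metadata (indptr, block indices and
+    # last-page lengths) for each sequence.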
+    qo_indptr = [0]
+    kv_indptr = [0]
+    kv_indices = []
+    kv_last_page_lens = []
+    for i in range(num_seqs):
+        seq_len = kv_lens[i]
+        assert seq_len > 0
+        num_blocks = (seq_len + block_size - 1) // block_size
+        kv_indices.extend(block_tables[i, :num_blocks])
+        kv_indptr.append(kv_indptr[-1] + num_blocks)
+        kv_last_page_len = seq_len % block_size
+        if kv_last_page_len == 0:
+            kv_last_page_len = block_size
+        kv_last_page_lens.append(kv_last_page_len)
+        qo_indptr.append(qo_indptr[-1] + query_lens[i])
+
+    qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32)
+    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
+    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
+    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
+
+    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
+    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+        workspace_buffer, "NHD")
+    wrapper.begin_forward(
+        qo_indptr,
+        kv_indptr,
+        kv_indices,
+        kv_last_page_lens,
+        num_query_heads,
+        num_kv_heads,
+        head_size,
+        block_size,
+    )
+
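+    # Run FP8 prefill attention; k_scale/v_scale let the kernel map the FP8
+    # cache back to real values.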
+    output = wrapper.forward(query,
+                             kv_cache_fp8,
+                             logits_soft_cap=soft_cap,
+                             k_scale=k_scale,
+                             v_scale=v_scale)
+
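+    # The reference runs on the original, unquantized cache.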
+    ref_output = ref_paged_attn(query=query,
+                                key_cache=key_cache.squeeze(1),
+                                value_cache=value_cache.squeeze(1),
+                                query_lens=query_lens,
+                                kv_lens=kv_lens,
+                                block_tables=block_tables,
+                                scale=scale,
+                                soft_cap=soft_cap)
+    del query
+    del block_tables
+    # Verify the FP8 prefill output against the reference.
+    torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
+        f"{torch.max(torch.abs(output - ref_output))}"
+
+
+@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
+@pytest.mark.parametrize("num_heads", [(32, 8), (64, 8), (6, 1)])
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
+@torch.inference_mode
+def test_flashinfer_decode_with_paged_fp8_kv(
+    kv_lens: List[int],
+    num_heads: Tuple[int, int],
+    head_size: int,
+    dtype: torch.dtype,
+    block_size: int,
+    soft_cap: Optional[float],
+) -> None:
+    # NOTE: This test does not pass for num_heads = (16, 16).
+    torch.set_default_device("cuda")
+    torch.cuda.manual_seed_all(0)
+    num_seqs = len(kv_lens)
+    num_query_heads = num_heads[0]
+    num_kv_heads = num_heads[1]
+    assert num_query_heads % num_kv_heads == 0
+    max_kv_len = max(kv_lens)
+    scale = head_size**-0.5
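+    # Use the tensor-core decode kernel when the GQA group size exceeds 4.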
+    use_tensor_cores = (num_query_heads // num_kv_heads) > 4
+    kv_cache_dtype = torch.float8_e4m3fn
+
+    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
+    NUM_BLOCKS_FP8 = 2048
+    key_value_cache = torch.randn(NUM_BLOCKS_FP8,
+                                  2,
+                                  block_size,
+                                  num_kv_heads,
+                                  head_size,
+                                  dtype=dtype)
+    key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1)
+    key_cache /= head_size**0.5
+    value_cache /= head_size**0.5
+
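+    # Same FP8 scaling scheme as in the prefill test above.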
+    k_scale = key_cache.amax().item() / 448.0
+    v_scale = value_cache.amax().item() / 448.0
+
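+    # Quantize K and V separately, then re-pack them into the
+    # (num_blocks, 2, block_size, num_kv_heads, head_size) cache layout.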
+    key_cache_fp8 = (key_cache / k_scale).to(kv_cache_dtype)
+    value_cache_fp8 = (value_cache / v_scale).to(kv_cache_dtype)
+    assert (key_cache_fp8.shape[1] == 1 and value_cache_fp8.shape[1] == 1)
+    kv_cache_fp8 = torch.cat([key_cache_fp8, value_cache_fp8], dim=1)
+
+    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
+    block_tables = torch.randint(0,
+                                 NUM_BLOCKS_FP8,
+                                 (num_seqs, max_num_blocks_per_seq),
+                                 dtype=torch.int32)
+
+    kv_indptr = [0]
+    kv_indices = []
+    kv_last_page_lens = []
+    for i in range(num_seqs):
+        seq_len = kv_lens[i]
+        assert seq_len > 0
+        num_blocks = (seq_len + block_size - 1) // block_size
+        kv_indices.extend(block_tables[i, :num_blocks])
+        kv_indptr.append(kv_indptr[-1] + num_blocks)
+        kv_last_page_len = seq_len % block_size
+        if kv_last_page_len == 0:
+            kv_last_page_len = block_size
+        kv_last_page_lens.append(kv_last_page_len)
+
+    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
+    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
+    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
+
+    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
+    wrapper = flashinfer.\
+        BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
+                                           use_tensor_cores=use_tensor_cores)
+    wrapper.begin_forward(kv_indptr,
+                          kv_indices,
+                          kv_last_page_lens,
+                          num_query_heads,
+                          num_kv_heads,
+                          head_size,
+                          block_size,
+                          "NONE",
+                          data_type=dtype)
+    output = wrapper.forward(query,
+                             kv_cache_fp8,
+                             logits_soft_cap=soft_cap,
+                             k_scale=k_scale,
+                             v_scale=v_scale)
+    key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
+    value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
+
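+    # Decode handles one query token per sequence; compare against the
+    # unquantized reference.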
+    ref_output = ref_paged_attn(query=query,
+                                key_cache=key_cache,
+                                value_cache=value_cache,
+                                query_lens=[1] * num_seqs,
+                                kv_lens=kv_lens,
+                                block_tables=block_tables,
+                                scale=scale,
+                                soft_cap=soft_cap)
+    # Temporary fix: increase the tolerance; this appears to be a flashinfer issue.
+    torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
+        f"{torch.max(torch.abs(output - ref_output))}"