@@ -69,7 +69,6 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
     paged_kv_indptr_host = paged_kv_indptr.cpu()
     paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
     seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
-    seq_len_host = seq_len_tensor.cpu()
 
     # Q,K,V are computed using GEMM.
     q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
@@ -113,7 +112,6 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
         paged_kv_last_page_len,
         paged_kv_last_page_len_host,
         seq_len_with_cache_host,
-        seq_len_host,
         # EXTRA METADATA
         batch_indices,
         positions,
@@ -184,7 +182,6 @@ def test_flashinfer_attention_op_decode(
     paged_kv_indptr_host = paged_kv_indptr.cpu()
     paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
     seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
-    seq_len_host = seq_len_tensor.cpu()
 
     # Q,K,V are computed using GEMM.
     q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
@@ -259,7 +256,6 @@ def test_flashinfer_attention_op_decode(
         paged_kv_last_page_len,
         paged_kv_last_page_len_host,
         seq_len_with_cache_host,
-        seq_len_host,
         # EXTRA METADATA
         batch_indices,
         positions,
@@ -350,7 +346,6 @@ def test_flashinfer_attention_context_and_generate(
     paged_kv_indptr_host = paged_kv_indptr.cpu()
     paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
     seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
-    seq_len_host = seq_len_tensor.cpu()
 
     # Q,K,V for prefill phase
     q_1 = torch.randn(BATCH_SIZE, PREFILL_SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
@@ -394,7 +389,6 @@ def test_flashinfer_attention_context_and_generate(
         paged_kv_last_page_len,
         paged_kv_last_page_len_host,
         seq_len_with_cache_host,
-        seq_len_host,
         # EXTRA METADATA
         batch_indices,
         positions,
@@ -453,7 +447,6 @@ def test_flashinfer_attention_context_and_generate(
     paged_kv_indptr_host = paged_kv_indptr.cpu()
     paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
     seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
-    seq_len_host = seq_len_tensor.cpu()
 
     # Q,K,V are computed using GEMM.
     q_3 = torch.randn(BATCH_SIZE, 1, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
@@ -486,7 +479,6 @@ def test_flashinfer_attention_context_and_generate(
         paged_kv_last_page_len,
         paged_kv_last_page_len_host,
         seq_len_with_cache_host,
-        seq_len_host,
         # EXTRA METADATA
         batch_indices,
         positions,
@@ -568,7 +560,6 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
     paged_kv_indptr_host = paged_kv_indptr.cpu()
     paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
     seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
-    seq_len_host = seq_len_tensor.cpu()
 
     # Q,K,V are computed using GEMM.
     q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
@@ -612,7 +603,6 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
         paged_kv_last_page_len,
         paged_kv_last_page_len_host,
         seq_len_with_cache_host,
-        seq_len_host,
         # EXTRA METADATA
         batch_indices,
         positions,
@@ -702,7 +692,6 @@ def test_flashinfer_attention_with_fp8_cache(
     paged_kv_indptr_host = paged_kv_indptr.cpu()
     paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
     seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
-    seq_len_host = seq_len_tensor.cpu()
 
     # Q,K,V are computed using GEMM, in fp16
     q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
@@ -776,7 +765,6 @@ def test_flashinfer_attention_with_fp8_cache(
         paged_kv_last_page_len,
         paged_kv_last_page_len_host,
         seq_len_with_cache_host,
-        seq_len_host,
         # EXTRA METADATA
         batch_indices,
         positions,
@@ -858,7 +846,6 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
     paged_kv_indptr_host = paged_kv_indptr.cpu()
     paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
     seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
-    seq_len_host = seq_len_tensor.cpu()
 
     # make sure planner is initialized
     workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
@@ -887,7 +874,6 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
         paged_kv_last_page_len,
         paged_kv_last_page_len_host,
         seq_len_with_cache_host,
-        seq_len_host,
         # EXTRA METADATA
         batch_indices,
         positions,
@@ -957,7 +943,6 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
     paged_kv_indptr2_host = paged_kv_indptr2.cpu()
     paged_kv_last_page_len2_host = paged_kv_last_page_len2.cpu()
     seq_len_with_cache2_host = (offsets2 + seq_len_tensor2).cpu()
-    seq_len2_host = seq_len_tensor2.cpu()
 
     # Create FlashInferAttention class before calling the custom op
     _GlobalFlashInferPlanner.reset()
@@ -985,7 +970,6 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
         paged_kv_last_page_len2,
         paged_kv_last_page_len2_host,
         seq_len_with_cache2_host,
-        seq_len2_host,
         # EXTRA METADATA
         batch_indices,
         positions,
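
Note on the removal: every test still passes seq_len_with_cache_host, which is built as (offsets + seq_len_tensor).cpu(), so the dropped seq_len_host (seq_len_tensor.cpu()) carried no independent information. A minimal sketch of how it could be recovered on the host if a caller ever needed it, assuming the tensor names used in the tests above; the helper itself is hypothetical and not part of this change:

def recover_seq_len_host(seq_len_with_cache_host, offsets):
    # Hypothetical helper: since seq_len_with_cache_host was computed as
    # (offsets + seq_len_tensor).cpu(), subtracting the host copy of
    # offsets gives back seq_len_tensor.cpu(), i.e. the removed seq_len_host.
    return seq_len_with_cache_host - offsets.cpu()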