
Commit d628399

update unit test
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent: b85376f

File tree

1 file changed: 0 additions, 16 deletions


tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py

Lines changed: 0 additions & 16 deletions
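Every hunk below applies the same two-line change: the host-side copy seq_len_host = seq_len_tensor.cpu() is deleted, and the seq_len_host argument is dropped from each attention-op call, so the op's host metadata evidently no longer includes the raw per-sequence lengths. A minimal, self-contained sketch of the metadata pattern the tests keep after this commit (the tensor values are invented placeholders for illustration, not taken from the test file):

import torch

# Invented placeholder metadata for two sequences (illustration only).
seq_len_tensor = torch.tensor([4, 7])          # new tokens per sequence
offsets = torch.tensor([0, 16])                # tokens already in the KV cache
paged_kv_indptr = torch.tensor([0, 1, 3], dtype=torch.int32)
paged_kv_last_page_len = torch.tensor([4, 7], dtype=torch.int32)

# Host-side copies, mirroring the pattern the tests keep after this commit.
paged_kv_indptr_host = paged_kv_indptr.cpu()
paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
# seq_len_host = seq_len_tensor.cpu()  # <- the line this commit deletes everywhere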
@@ -69,7 +69,6 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
     paged_kv_indptr_host = paged_kv_indptr.cpu()
     paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
     seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
-    seq_len_host = seq_len_tensor.cpu()
 
     # Q,K,V are computed using GEMM.
     q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
@@ -113,7 +112,6 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
         paged_kv_last_page_len,
         paged_kv_last_page_len_host,
         seq_len_with_cache_host,
-        seq_len_host,
         # EXTRA METADATA
         batch_indices,
         positions,
@@ -184,7 +182,6 @@ def test_flashinfer_attention_op_decode(
     paged_kv_indptr_host = paged_kv_indptr.cpu()
     paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
     seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
-    seq_len_host = seq_len_tensor.cpu()
 
     # Q,K,V are computed using GEMM.
     q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
@@ -259,7 +256,6 @@ def test_flashinfer_attention_op_decode(
         paged_kv_last_page_len,
         paged_kv_last_page_len_host,
         seq_len_with_cache_host,
-        seq_len_host,
         # EXTRA METADATA
         batch_indices,
         positions,
@@ -350,7 +346,6 @@ def test_flashinfer_attention_context_and_generate(
     paged_kv_indptr_host = paged_kv_indptr.cpu()
     paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
     seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
-    seq_len_host = seq_len_tensor.cpu()
 
     # Q,K,V for prefill phase
     q_1 = torch.randn(BATCH_SIZE, PREFILL_SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
@@ -394,7 +389,6 @@ def test_flashinfer_attention_context_and_generate(
         paged_kv_last_page_len,
         paged_kv_last_page_len_host,
         seq_len_with_cache_host,
-        seq_len_host,
         # EXTRA METADATA
         batch_indices,
         positions,
@@ -453,7 +447,6 @@ def test_flashinfer_attention_context_and_generate(
     paged_kv_indptr_host = paged_kv_indptr.cpu()
     paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
     seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
-    seq_len_host = seq_len_tensor.cpu()
 
     # Q,K,V are computed using GEMM.
     q_3 = torch.randn(BATCH_SIZE, 1, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
@@ -486,7 +479,6 @@ def test_flashinfer_attention_context_and_generate(
         paged_kv_last_page_len,
         paged_kv_last_page_len_host,
         seq_len_with_cache_host,
-        seq_len_host,
         # EXTRA METADATA
         batch_indices,
         positions,
@@ -568,7 +560,6 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
     paged_kv_indptr_host = paged_kv_indptr.cpu()
     paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
     seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
-    seq_len_host = seq_len_tensor.cpu()
 
     # Q,K,V are computed using GEMM.
     q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
@@ -612,7 +603,6 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
         paged_kv_last_page_len,
         paged_kv_last_page_len_host,
         seq_len_with_cache_host,
-        seq_len_host,
         # EXTRA METADATA
         batch_indices,
         positions,
@@ -702,7 +692,6 @@ def test_flashinfer_attention_with_fp8_cache(
     paged_kv_indptr_host = paged_kv_indptr.cpu()
     paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
     seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
-    seq_len_host = seq_len_tensor.cpu()
 
     # Q,K,V are computed using GEMM, in fp16
     q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device)
@@ -776,7 +765,6 @@ def test_flashinfer_attention_with_fp8_cache(
         paged_kv_last_page_len,
         paged_kv_last_page_len_host,
         seq_len_with_cache_host,
-        seq_len_host,
         # EXTRA METADATA
         batch_indices,
         positions,
@@ -858,7 +846,6 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
     paged_kv_indptr_host = paged_kv_indptr.cpu()
     paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
     seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()
-    seq_len_host = seq_len_tensor.cpu()
 
     # make sure planner is initialized
     workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
@@ -887,7 +874,6 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
         paged_kv_last_page_len,
         paged_kv_last_page_len_host,
         seq_len_with_cache_host,
-        seq_len_host,
         # EXTRA METADATA
         batch_indices,
         positions,
@@ -957,7 +943,6 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
     paged_kv_indptr2_host = paged_kv_indptr2.cpu()
     paged_kv_last_page_len2_host = paged_kv_last_page_len2.cpu()
     seq_len_with_cache2_host = (offsets2 + seq_len_tensor2).cpu()
-    seq_len2_host = seq_len_tensor2.cpu()
 
     # Create FlashInferAttention class before calling the custom op
     _GlobalFlashInferPlanner.reset()
@@ -985,7 +970,6 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
         paged_kv_last_page_len2,
         paged_kv_last_page_len2_host,
         seq_len_with_cache2_host,
-        seq_len2_host,
         # EXTRA METADATA
         batch_indices,
         positions,
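At each call site the surrounding arguments keep their order; only the seq_len_host entry between seq_len_with_cache_host and the extra metadata disappears. A hedged skeleton of one updated call, continuing the placeholder sketch above; attention_op is a stand-in for the custom op under test, whose real name and full argument list are not visible in these hunks:

import torch

# Placeholder metadata continuing the sketch above (illustration only).
paged_kv_last_page_len = torch.tensor([4, 7], dtype=torch.int32)
paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
seq_len_with_cache_host = torch.tensor([4, 23])
batch_indices = torch.tensor([0, 1], dtype=torch.int32)
positions = torch.tensor([0, 16], dtype=torch.int32)

def attention_op(paged_kv_last_page_len, paged_kv_last_page_len_host,
                 seq_len_with_cache_host, batch_indices, positions):
    # Stand-in only: the real custom op consumes this metadata to plan paged
    # attention, and its actual signature has many more arguments (q, k, v,
    # cache pages, indptr tensors, ...) than shown here.
    pass

attention_op(
    paged_kv_last_page_len,
    paged_kv_last_page_len_host,
    seq_len_with_cache_host,
    # seq_len_host,  <- no longer passed after this commit
    # EXTRA METADATA
    batch_indices,
    positions,
)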
