Skip to content

Commit abcf8cb

Browse files
committed
fix test.
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
1 parent 2a5d5e9 commit abcf8cb

File tree

1 file changed

+3
-15
lines changed

1 file changed

+3
-15
lines changed

tests/unittest/_torch/attention/sparse/test_dsa_indexer.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1559,8 +1559,7 @@ def test_indexer_chunked_prefill(chunk_size, seq_lens_list, chunking_type):
15591559
f"K[{chunk.k_token_start}:{chunk.k_token_end}] ({num_k} tokens)")
15601560

15611561
indexer._update_k_cache(k_fp8, k_scale, metadata_chunked)
1562-
topk_indices_chunked = indexer.sparse_attn_indexer(metadata_chunked,
1563-
hidden_states, q_fp8,
1562+
topk_indices_chunked = indexer.sparse_attn_indexer(metadata_chunked, q_fp8,
15641563
k_fp8, k_scale, weights)
15651564

15661565
print(f"✓ Chunked execution completed, shape: {topk_indices_chunked.shape}")
@@ -1592,8 +1591,8 @@ def test_indexer_chunked_prefill(chunk_size, seq_lens_list, chunking_type):
15921591

15931592
indexer._update_k_cache(k_fp8, k_scale, metadata_baseline)
15941593
topk_indices_baseline = indexer.sparse_attn_indexer(metadata_baseline,
1595-
hidden_states, q_fp8,
1596-
k_fp8, k_scale, weights)
1594+
q_fp8, k_fp8, k_scale,
1595+
weights)
15971596

15981597
print(
15991598
f"✓ Non-chunked execution completed, shape: {topk_indices_baseline.shape}"
@@ -1866,7 +1865,6 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
18661865

18671866
try:
18681867
topk_indices_custom = indexer.sparse_attn_indexer(metadata_custom,
1869-
hidden_states,
18701868
q_fp8,
18711869
k_fp8,
18721870
k_scale,
@@ -1892,7 +1890,6 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
18921890
Indexer.prepare(metadata_fallback)
18931891
indexer._update_k_cache(k_fp8, k_scale, metadata_fallback)
18941892
topk_indices_fallback = indexer.sparse_attn_indexer(metadata_fallback,
1895-
hidden_states,
18961893
q_fp8,
18971894
k_fp8,
18981895
k_scale,
@@ -1921,7 +1918,6 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
19211918
try:
19221919
topk_indices_skip = indexer.sparse_attn_indexer(
19231920
metadata_skip,
1924-
hidden_states,
19251921
q_fp8,
19261922
k_fp8,
19271923
k_scale,
@@ -2035,7 +2031,6 @@ def test_indexer_prefill_chunked_custom_vs_fallback(batch_size, index_topk,
20352031

20362032
try:
20372033
topk_indices_custom = indexer.sparse_attn_indexer(metadata_custom,
2038-
hidden_states,
20392034
q_fp8,
20402035
k_fp8,
20412036
k_scale,
@@ -2055,7 +2050,6 @@ def test_indexer_prefill_chunked_custom_vs_fallback(batch_size, index_topk,
20552050
Indexer.prepare(metadata_fallback)
20562051
indexer._update_k_cache(k_fp8, k_scale, metadata_fallback)
20572052
topk_indices_fallback = indexer.sparse_attn_indexer(metadata_fallback,
2058-
hidden_states,
20592053
q_fp8,
20602054
k_fp8,
20612055
k_scale,
@@ -2143,7 +2137,6 @@ def test_indexer_prefill_single_pass_custom_vs_fallback(batch_size, index_topk,
21432137

21442138
try:
21452139
topk_indices_custom = indexer.sparse_attn_indexer(metadata_custom,
2146-
hidden_states,
21472140
q_fp8,
21482141
k_fp8,
21492142
k_scale,
@@ -2166,7 +2159,6 @@ def test_indexer_prefill_single_pass_custom_vs_fallback(batch_size, index_topk,
21662159
metadata_fallback.indexer_prefill_chunks = None
21672160

21682161
topk_indices_fallback = indexer.sparse_attn_indexer(metadata_fallback,
2169-
hidden_states,
21702162
q_fp8,
21712163
k_fp8,
21722164
k_scale,
@@ -2191,7 +2183,6 @@ def test_indexer_prefill_single_pass_custom_vs_fallback(batch_size, index_topk,
21912183

21922184
try:
21932185
topk_indices_skip = indexer.sparse_attn_indexer(metadata_skip,
2194-
hidden_states,
21952186
q_fp8,
21962187
k_fp8,
21972188
k_scale,
@@ -2302,7 +2293,6 @@ def test_indexer_topk_multi_request_with_different_cache(enable_indexer_skip):
23022293

23032294
# Test custom kernel
23042295
topk_custom = indexer.sparse_attn_indexer(metadata,
2305-
hidden_states,
23062296
q_fp8,
23072297
k_fp8,
23082298
k_scale,
@@ -2311,7 +2301,6 @@ def test_indexer_topk_multi_request_with_different_cache(enable_indexer_skip):
23112301

23122302
# Test fallback
23132303
topk_fallback = indexer.sparse_attn_indexer(metadata,
2314-
hidden_states,
23152304
q_fp8,
23162305
k_fp8,
23172306
k_scale,
@@ -2337,7 +2326,6 @@ def test_indexer_topk_multi_request_with_different_cache(enable_indexer_skip):
23372326
Indexer.prepare(metadata_skip)
23382327
indexer._update_k_cache(k_fp8, k_scale, metadata_skip)
23392328
topk_indices_skip = indexer.sparse_attn_indexer(metadata_skip,
2340-
hidden_states,
23412329
q_fp8,
23422330
k_fp8,
23432331
k_scale,

0 commit comments

Comments
 (0)