@@ -1559,8 +1559,7 @@ def test_indexer_chunked_prefill(chunk_size, seq_lens_list, chunking_type):
15591559 f"K[{ chunk .k_token_start } :{ chunk .k_token_end } ] ({ num_k } tokens)" )
15601560
15611561 indexer ._update_k_cache (k_fp8 , k_scale , metadata_chunked )
1562- topk_indices_chunked = indexer .sparse_attn_indexer (metadata_chunked ,
1563- hidden_states , q_fp8 ,
1562+ topk_indices_chunked = indexer .sparse_attn_indexer (metadata_chunked , q_fp8 ,
15641563 k_fp8 , k_scale , weights )
15651564
15661565 print (f"✓ Chunked execution completed, shape: { topk_indices_chunked .shape } " )
@@ -1592,8 +1591,8 @@ def test_indexer_chunked_prefill(chunk_size, seq_lens_list, chunking_type):
15921591
15931592 indexer ._update_k_cache (k_fp8 , k_scale , metadata_baseline )
15941593 topk_indices_baseline = indexer .sparse_attn_indexer (metadata_baseline ,
1595- hidden_states , q_fp8 ,
1596- k_fp8 , k_scale , weights )
1594+ q_fp8 , k_fp8 , k_scale ,
1595+ weights )
15971596
15981597 print (
15991598 f"✓ Non-chunked execution completed, shape: { topk_indices_baseline .shape } "
@@ -1866,7 +1865,6 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
18661865
18671866 try :
18681867 topk_indices_custom = indexer .sparse_attn_indexer (metadata_custom ,
1869- hidden_states ,
18701868 q_fp8 ,
18711869 k_fp8 ,
18721870 k_scale ,
@@ -1892,7 +1890,6 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
18921890 Indexer .prepare (metadata_fallback )
18931891 indexer ._update_k_cache (k_fp8 , k_scale , metadata_fallback )
18941892 topk_indices_fallback = indexer .sparse_attn_indexer (metadata_fallback ,
1895- hidden_states ,
18961893 q_fp8 ,
18971894 k_fp8 ,
18981895 k_scale ,
@@ -1921,7 +1918,6 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
19211918 try :
19221919 topk_indices_skip = indexer .sparse_attn_indexer (
19231920 metadata_skip ,
1924- hidden_states ,
19251921 q_fp8 ,
19261922 k_fp8 ,
19271923 k_scale ,
@@ -2035,7 +2031,6 @@ def test_indexer_prefill_chunked_custom_vs_fallback(batch_size, index_topk,
20352031
20362032 try :
20372033 topk_indices_custom = indexer .sparse_attn_indexer (metadata_custom ,
2038- hidden_states ,
20392034 q_fp8 ,
20402035 k_fp8 ,
20412036 k_scale ,
@@ -2055,7 +2050,6 @@ def test_indexer_prefill_chunked_custom_vs_fallback(batch_size, index_topk,
20552050 Indexer .prepare (metadata_fallback )
20562051 indexer ._update_k_cache (k_fp8 , k_scale , metadata_fallback )
20572052 topk_indices_fallback = indexer .sparse_attn_indexer (metadata_fallback ,
2058- hidden_states ,
20592053 q_fp8 ,
20602054 k_fp8 ,
20612055 k_scale ,
@@ -2143,7 +2137,6 @@ def test_indexer_prefill_single_pass_custom_vs_fallback(batch_size, index_topk,
21432137
21442138 try :
21452139 topk_indices_custom = indexer .sparse_attn_indexer (metadata_custom ,
2146- hidden_states ,
21472140 q_fp8 ,
21482141 k_fp8 ,
21492142 k_scale ,
@@ -2166,7 +2159,6 @@ def test_indexer_prefill_single_pass_custom_vs_fallback(batch_size, index_topk,
21662159 metadata_fallback .indexer_prefill_chunks = None
21672160
21682161 topk_indices_fallback = indexer .sparse_attn_indexer (metadata_fallback ,
2169- hidden_states ,
21702162 q_fp8 ,
21712163 k_fp8 ,
21722164 k_scale ,
@@ -2191,7 +2183,6 @@ def test_indexer_prefill_single_pass_custom_vs_fallback(batch_size, index_topk,
21912183
21922184 try :
21932185 topk_indices_skip = indexer .sparse_attn_indexer (metadata_skip ,
2194- hidden_states ,
21952186 q_fp8 ,
21962187 k_fp8 ,
21972188 k_scale ,
@@ -2302,7 +2293,6 @@ def test_indexer_topk_multi_request_with_different_cache(enable_indexer_skip):
23022293
23032294 # Test custom kernel
23042295 topk_custom = indexer .sparse_attn_indexer (metadata ,
2305- hidden_states ,
23062296 q_fp8 ,
23072297 k_fp8 ,
23082298 k_scale ,
@@ -2311,7 +2301,6 @@ def test_indexer_topk_multi_request_with_different_cache(enable_indexer_skip):
23112301
23122302 # Test fallback
23132303 topk_fallback = indexer .sparse_attn_indexer (metadata ,
2314- hidden_states ,
23152304 q_fp8 ,
23162305 k_fp8 ,
23172306 k_scale ,
@@ -2337,7 +2326,6 @@ def test_indexer_topk_multi_request_with_different_cache(enable_indexer_skip):
23372326 Indexer .prepare (metadata_skip )
23382327 indexer ._update_k_cache (k_fp8 , k_scale , metadata_skip )
23392328 topk_indices_skip = indexer .sparse_attn_indexer (metadata_skip ,
2340- hidden_states ,
23412329 q_fp8 ,
23422330 k_fp8 ,
23432331 k_scale ,
0 commit comments