@@ -77,6 +77,7 @@ def ref_paged_attn(
7777@pytest .mark .parametrize ("block_size" , BLOCK_SIZES )
7878@pytest .mark .parametrize ("dtype" , DTYPES )
7979@pytest .mark .parametrize ("soft_cap" , [None , 30.0 , 50.0 ])
80+ @pytest .mark .parametrize ("sliding_window" , [None , 64 ])
8081@torch .inference_mode
8182def test_flashinfer_decode_with_paged_kv (
8283 kv_lens : list [int ],
@@ -85,6 +86,7 @@ def test_flashinfer_decode_with_paged_kv(
8586 dtype : torch .dtype ,
8687 block_size : int ,
8788 soft_cap : Optional [float ],
89+ sliding_window : Optional [int ],
8890) -> None :
8991 torch .set_default_device ("cuda" )
9092 current_platform .seed_everything (0 )
@@ -136,17 +138,20 @@ def test_flashinfer_decode_with_paged_kv(
136138 use_tensor_cores = (
137139 (num_query_heads // num_kv_heads ) > 4 )
138140 )
139- wrapper .plan (kv_indptr ,
140- kv_indices ,
141- kv_last_page_lens ,
142- num_query_heads ,
143- num_kv_heads ,
144- head_size ,
145- block_size ,
146- "NONE" ,
147- q_data_type = dtype ,
148- kv_data_type = dtype ,
149- logits_soft_cap = soft_cap )
141+ wrapper .plan (
142+ kv_indptr ,
143+ kv_indices ,
144+ kv_last_page_lens ,
145+ num_query_heads ,
146+ num_kv_heads ,
147+ head_size ,
148+ block_size ,
149+ "NONE" ,
150+ window_left = sliding_window - 1 if sliding_window is not None else - 1 ,
151+ q_data_type = dtype ,
152+ kv_data_type = dtype ,
153+ logits_soft_cap = soft_cap ,
154+ )
150155
151156 output = wrapper .run (query , key_value_cache )
152157
@@ -157,7 +162,8 @@ def test_flashinfer_decode_with_paged_kv(
157162 kv_lens = kv_lens ,
158163 block_tables = block_tables ,
159164 scale = scale ,
160- soft_cap = soft_cap )
165+ soft_cap = soft_cap ,
166+ sliding_window = sliding_window )
161167 torch .testing .assert_close (output , ref_output , atol = 1e-2 , rtol = 1e-2 ), \
162168 f"{ torch .max (torch .abs (output - ref_output ))} "
163169
@@ -168,12 +174,17 @@ def test_flashinfer_decode_with_paged_kv(
168174@pytest .mark .parametrize ("block_size" , BLOCK_SIZES )
169175@pytest .mark .parametrize ("dtype" , DTYPES )
170176@pytest .mark .parametrize ("soft_cap" , [None , 30.0 , 50.0 ])
177+ @pytest .mark .parametrize ("sliding_window" , [None , 64 ])
171178@torch .inference_mode
172- def test_flashinfer_prefill_with_paged_kv (seq_lens : list [tuple [int , int ]],
173- num_heads : tuple [int , int ],
174- head_size : int , dtype : torch .dtype ,
175- block_size : int ,
176- soft_cap : Optional [float ]) -> None :
179+ def test_flashinfer_prefill_with_paged_kv (
180+ seq_lens : list [tuple [int , int ]],
181+ num_heads : tuple [int , int ],
182+ head_size : int ,
183+ dtype : torch .dtype ,
184+ block_size : int ,
185+ soft_cap : Optional [float ],
186+ sliding_window : Optional [int ],
187+ ) -> None :
177188 torch .set_default_device ("cuda" )
178189 current_platform .seed_everything (0 )
179190 num_seqs = len (seq_lens )
@@ -242,6 +253,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
242253 num_kv_heads ,
243254 head_size ,
244255 block_size ,
256+ window_left = sliding_window - 1 if sliding_window is not None else - 1 ,
245257 q_data_type = dtype ,
246258 kv_data_type = dtype ,
247259 logits_soft_cap = soft_cap ,
@@ -259,7 +271,8 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
259271 kv_lens = kv_lens ,
260272 block_tables = block_tables ,
261273 scale = scale ,
262- soft_cap = soft_cap )
274+ soft_cap = soft_cap ,
275+ sliding_window = sliding_window )
263276 torch .testing .assert_close (output , ref_output , atol = 5e-2 , rtol = 1e-2 ), \
264277 f"{ torch .max (torch .abs (output - ref_output ))} "
265278
0 commit comments