
import pytest
import torch
+from functools import partial
from jit_utils import gen_decode_attention_modules, gen_prefill_attention_modules

import flashinfer
@@ -185,6 +186,157 @@ def test_batch_decode_with_paged_kv_cache(
    torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3)


+@pytest.mark.parametrize("batch_size", [12, 17, 128])
+@pytest.mark.parametrize("kv_len", [54, 97, 512, 2048, 16384])
+@pytest.mark.parametrize("page_size", [1, 8, 16])
+@pytest.mark.parametrize("num_kv_heads", [4])
+@pytest.mark.parametrize("num_qo_heads", [4, 32])
+@pytest.mark.parametrize("head_dim", [128, 256])
+@pytest.mark.parametrize("kv_layout", ["NHD"])
+@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"])
+@pytest.mark.parametrize("logits_soft_cap", [0.0])
+@pytest.mark.parametrize("return_lse", [True])
+@pytest.mark.parametrize("q_dtype", [torch.float16])
+@pytest.mark.parametrize("kv_dtype", [torch.float16, torch.float8_e4m3fn])
+@pytest.mark.parametrize("contiguous_kv", [True])
+def test_batch_decode_with_paged_kv_cache_with_fast_plan(
+    batch_size,
+    kv_len,
+    page_size,
+    num_kv_heads,
+    num_qo_heads,
+    head_dim,
+    kv_layout,
+    pos_encoding_mode,
+    logits_soft_cap,
+    return_lse,
+    q_dtype,
+    kv_dtype,
+    contiguous_kv,
+):
+    q = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=q_dtype)
+    num_pages_per_seq = (kv_len + page_size - 1) // page_size
+    total_num_pages = num_pages_per_seq * batch_size
+
+    if kv_layout == "HND":
+        kv_shape = [total_num_pages, 2, num_kv_heads, page_size, head_dim]
+    else:
+        kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim]
+    if not contiguous_kv:
+        tmp = [kv_shape[0]]
+        for v in kv_shape[1:]:
+            tmp.append(2)
+            tmp.append(v)
+        kv_shape = tmp
+        kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0")
+        kv_data = kv_data_fp32.to(kv_dtype)
+        kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :]
+        kv_data_fp32 = kv_data_fp32[:, 1, :, 1, :, 1, :, 1, :]
+        # actual data is stored in non-contiguous memory
+        assert (
+            kv_data.stride(-4)
+            != kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1]
+        )
+    else:
+        kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0")
+        kv_data = kv_data_fp32.to(kv_dtype)
+    kv_indptr = (
+        torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32)
+        * num_pages_per_seq
+    )
+    kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32)
+    kv_last_page_len = torch.full(
+        (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0"
+    )
+
+    workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
+    wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
+        workspace_buffer, kv_layout
+    )
+    wrapper.plan(
+        kv_indptr,
+        kv_indices,
+        kv_last_page_len,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim,
+        page_size,
+        logits_soft_cap=logits_soft_cap,
+        pos_encoding_mode=pos_encoding_mode,
+        data_type=kv_dtype,
+        q_data_type=q_dtype,
+    )
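+    # Rebind the wrapper's plan method to flashinfer.fast_decode_plan, then re-plan
+    # with the same arguments (plus non_blocking=True) before running decode.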
+    wrapper.plan = partial(flashinfer.fast_decode_plan, wrapper)
+    wrapper.plan(
+        kv_indptr,
+        kv_indices,
+        kv_last_page_len,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim,
+        page_size,
+        logits_soft_cap=logits_soft_cap,
+        pos_encoding_mode=pos_encoding_mode,
+        data_type=kv_dtype,
+        q_data_type=q_dtype,
+        non_blocking=True,
+    )
+    if return_lse:
+        o, _ = wrapper.run(q, kv_data, return_lse=True)
+    else:
+        o = wrapper.run(q, kv_data)
+
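+    # Check each request's output against the single-request decode reference.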
+    for i in range(batch_size):
+        perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3]
+        perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2]
+        qi = q[i]
+        ki = torch.cat(
+            [
+                kv_data_fp32[kv_indptr[i] : kv_indptr[i + 1] - 1, 0]
+                .permute(*perm_dims)
+                .reshape(-1, num_kv_heads, head_dim),
+                (
+                    kv_data_fp32[kv_indptr[i + 1] - 1, 0, :, : kv_last_page_len[i]]
+                    if kv_layout == "HND"
+                    else kv_data_fp32[kv_indptr[i + 1] - 1, 0, : kv_last_page_len[i], :]
+                )
+                .permute(*perm_dims_last)
+                .reshape(-1, num_kv_heads, head_dim),
+            ],
+            dim=0,
+        ).to(kv_dtype)
+        vi = torch.cat(
+            [
+                kv_data_fp32[kv_indptr[i] : kv_indptr[i + 1] - 1, 1]
+                .permute(*perm_dims)
+                .reshape(-1, num_kv_heads, head_dim),
+                (
+                    kv_data_fp32[kv_indptr[i + 1] - 1, 1, :, : kv_last_page_len[i]]
+                    if kv_layout == "HND"
+                    else kv_data_fp32[kv_indptr[i + 1] - 1, 1, : kv_last_page_len[i], :]
+                )
+                .permute(*perm_dims_last)
+                .reshape(-1, num_kv_heads, head_dim),
+            ],
+            dim=0,
+        ).to(kv_dtype)
+        o_ref_i = flashinfer.decode.single_decode_with_kv_cache(
+            qi,
+            ki,
+            vi,
+            pos_encoding_mode=pos_encoding_mode,
+            logits_soft_cap=logits_soft_cap,
+        )
+        torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3)
+
+    # test user-allocated output
+    o_buffer = torch.empty_like(o)
+    wrapper.run(q, kv_data, out=o_buffer)
+    torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3)
+
+
@pytest.mark.parametrize("batch_size", [12, 17, 128])
@pytest.mark.parametrize("kv_len", [54, 97, 512, 2048, 16384])
@pytest.mark.parametrize("page_size", [1, 8, 16])