From ab12ef072c36b1cfde00e213683f48e805abb2ea Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 5 Aug 2025 05:01:01 +0000 Subject: [PATCH] fix --- flashinfer/sparse.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/flashinfer/sparse.py b/flashinfer/sparse.py index 51c9d043c..8ac9fc65e 100644 --- a/flashinfer/sparse.py +++ b/flashinfer/sparse.py @@ -111,6 +111,7 @@ def __init__( self, float_workspace_buffer: torch.Tensor, backend: str = "auto", + kv_lens_buffer_size: int = 32768, ) -> None: r"""Constructs of :class:`BlockSparseAttentionWrapper`. @@ -124,6 +125,8 @@ def __init__( The implementation backend, could be ``auto``/``fa2`` or ``fa3``. Defaults to ``auto``. If set to ``auto``, the function will automatically choose the backend based on the device architecture and kernel availability. + kv_lens_buffer_size : int + The size, in int32 elements, of the kv lens buffer (presumably ``num_kv_heads * max_batch_size`` -- confirm); the same value also sizes the vector sparse indptr buffer. Defaults to 32768. """ self._float_workspace_buffer = float_workspace_buffer self.device = float_workspace_buffer.device @@ -138,11 +141,11 @@ def __init__( ) # NOTE(Zihao): assume maximum batch size is 32768 self._vector_sparse_indptr_buffer = torch.empty( - (32768,), dtype=torch.int32, device=self.device + (kv_lens_buffer_size,), dtype=torch.int32, device=self.device ) self._kv_lens_buffer = torch.empty( - (32768,), dtype=torch.int32, device=self.device + (kv_lens_buffer_size,), dtype=torch.int32, device=self.device ) self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape,