
Commit ba3f324

bugfix: fix integer overflow in FA2 customized_mask & add buffer overflow warning. (#1290)
## 📌 Description

1. Per discussion with @haochengxi and @Radioheading, this PR moves the `plan` computation in `VariableBlockSparseAttentionWrapper` to the GPU side, to avoid expensive (hundreds of ms) host operations.
2. This PR also enlarges the default internal `_vector_sparse_indices_buffer` from 4M to 128M `int32` entries to accommodate video DiT use cases.
3. This PR fixes the **INT overflow** during offset calculation in the attention map, which causes errors in the `customized_mask` mode of the FA2 prefill template. E.g., with `kv_len = 128K`, the offset of the last element of the attention map is roughly `128K * 128K ≈ 1.6e10`, which is larger than `INT32_MAX`.

## 🔍 Related Issues

This PR should solve #1271.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
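For point 3, a quick arithmetic check (not part of the commit) of the offset that overflows; `128K` is taken as `128 * 1024` here:

```python
# Worked example of the overflow described above: with qo_len = kv_len = 128K,
# the flattened mask offset of the last element no longer fits in 32 bits.
qo_len = kv_len = 128 * 1024
qo_idx, kv_idx = qo_len - 1, kv_len - 1

offset = qo_idx * kv_len + kv_idx            # exact value (Python ints are arbitrary precision)
print(f"offset      = {offset:,}")           # 17,179,869,183  (~1.7e10)
print(f"INT32_MAX   = {2**31 - 1:,}")        # 2,147,483,647
print(f"UINT32_MAX  = {2**32 - 1:,}")        # 4,294,967,295
print(f"uint32 wrap = {offset % 2**32:,}")   # the wrong value a 32-bit offset would silently wrap to
```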
1 parent e353f11 commit ba3f324

5 files changed: +94 -30 lines changed

flashinfer/__init__.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -117,4 +117,7 @@
 from .sampling import top_p_renorm_probs as top_p_renorm_probs
 from .sampling import top_p_sampling_from_probs as top_p_sampling_from_probs
 from .sparse import BlockSparseAttentionWrapper as BlockSparseAttentionWrapper
+from .sparse import (
+    VariableBlockSparseAttentionWrapper as VariableBlockSparseAttentionWrapper,
+)
 from .utils import next_positive_power_of_2 as next_positive_power_of_2
```

flashinfer/sparse.py

Lines changed: 59 additions & 19 deletions
```diff
@@ -131,9 +131,10 @@ def __init__(
             (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
         )
         if backend in ["fa3", "auto"]:
-            # NOTE(Zihao): assume maximum accumulate kv length is 4M
+            # NOTE(Zihao): assume maximum accumulate kv length is 128M
+            # NOTE(Yilong): 128M is required by video DiT models
             self._vector_sparse_indices_buffer = torch.empty(
-                (4 * 1024 * 1024,), dtype=torch.int32, device=self.device
+                (128 * 1024 * 1024,), dtype=torch.int32, device=self.device
             )
             # NOTE(Zihao): assume maximum batch size is 32768
             self._vector_sparse_indptr_buffer = torch.empty(
@@ -164,7 +165,11 @@ def __init__(
         self._backend = backend

     def reset_workspace_buffer(
-        self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor
+        self,
+        float_workspace_buffer: torch.Tensor,
+        int_workspace_buffer: torch.Tensor,
+        vector_sparse_indices_buffer: Optional[torch.Tensor] = None,
+        vector_sparse_indptr_buffer: Optional[torch.Tensor] = None,
     ) -> None:
         r"""Reset the workspace buffer.

@@ -186,6 +191,12 @@ def reset_workspace_buffer(
             pin_memory=True,
         )

+        # Enable user-defined size
+        if vector_sparse_indices_buffer is not None:
+            self._vector_sparse_indices_buffer = vector_sparse_indices_buffer
+        if vector_sparse_indptr_buffer is not None:
+            self._vector_sparse_indptr_buffer = vector_sparse_indptr_buffer
+
     def plan(
         self,
         indptr: torch.Tensor,
@@ -589,6 +600,14 @@ def run(

         if self._use_tensor_cores:
             if self._backend == "fa3":
+                if (
+                    self._vector_sparse_indices_buffer.numel()
+                    <= self._paged_kv_indices_buf.numel() * self.C
+                ):
+                    raise ValueError(
+                        "_vector_sparse_indices_buffer is not large enough. Please increase the size."
+                    )
+
                 sparse_indices = block_sparse_indices_to_vector_sparse_offsets(
                     self._paged_kv_indices_buf,
                     self._paged_kv_indptr_buf,
@@ -725,11 +744,9 @@ def __init__(
             (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
         )
         if backend in ["fa3", "auto"]:
-            # NOTE(Zihao): assume maximum accumulate kv length is 4M
             self._vector_sparse_indices_buffer = torch.empty(
-                (4 * 1024 * 1024,), dtype=torch.int32, device=self.device
+                (128 * 1024 * 1024,), dtype=torch.int32, device=self.device
             )
-            # NOTE(Zihao): assume maximum batch size is 32768
             self._vector_sparse_indptr_buffer = torch.empty(
                 (32768,), dtype=torch.int32, device=self.device
             )
@@ -752,7 +769,11 @@ def __init__(
         self._backend = backend

     def reset_workspace_buffer(
-        self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor
+        self,
+        float_workspace_buffer: torch.Tensor,
+        int_workspace_buffer: torch.Tensor,
+        vector_sparse_indices_buffer: Optional[torch.Tensor] = None,
+        vector_sparse_indptr_buffer: Optional[torch.Tensor] = None,
     ) -> None:
         r"""Reset the workspace buffer.

@@ -774,6 +795,12 @@ def reset_workspace_buffer(
             pin_memory=True,
         )

+        # Enable user-defined size
+        if vector_sparse_indices_buffer is not None:
+            self._vector_sparse_indices_buffer = vector_sparse_indices_buffer
+        if vector_sparse_indptr_buffer is not None:
+            self._vector_sparse_indptr_buffer = vector_sparse_indptr_buffer
+
     def plan(
         self,
         block_mask_map: torch.Tensor,
@@ -860,14 +887,14 @@ def plan(

         # q layout: [seq_len, num_kv_heads, gqa_group_size, head_dim]
         # padded into: [seq_len * num_kv_heads, 1, gqa_group_size, head_dim]
-        qo_indptr_host = torch.cat(
+        qo_indptr = torch.cat(
             [
                 torch.zeros(1, dtype=torch.int32, device=block_row_sz.device),
                 torch.cumsum(block_row_sz.flatten(), dim=0, dtype=torch.int32),
             ],
             dim=0,
         )
-        qo_indptr = qo_indptr_host.to(block_mask_map.device, non_blocking=non_blocking)
+        qo_indptr_host = qo_indptr.to("cpu", non_blocking=non_blocking)
         last_block_len = torch.full(
             (num_blocks_row * num_kv_heads,),
             1,
@@ -926,36 +953,37 @@ def _block_mask_map_to_expanded_indices(
                 dtype=dtype_i, device=device
             )

-        kv_indptr_host, kv_indices_host = _block_mask_map_to_expanded_indices(
+        kv_indptr, kv_indices = _block_mask_map_to_expanded_indices(
             block_mask_map, block_col_sz
         )
+        kv_indptr_host = kv_indptr.to("cpu", non_blocking=non_blocking)
+        kv_indices_host = kv_indices.to("cpu", non_blocking=non_blocking)

         self._qo_indptr = qo_indptr.to(self.device, non_blocking=non_blocking)
-        self._paged_kv_indptr_buf = kv_indptr_host.to(
-            self.device, non_blocking=non_blocking
-        )
-        self._paged_kv_indices_buf = kv_indices_host.to(
+        self._paged_kv_indptr_buf = kv_indptr.to(self.device, non_blocking=non_blocking)
+        self._paged_kv_indices_buf = kv_indices.to(
             self.device, non_blocking=non_blocking
         )
         self._paged_kv_last_page_len = last_block_len.to(
             self.device, non_blocking=non_blocking
         )
+        torch.cuda.synchronize()  # for non-blocking copy
         self._mask_mode = MaskMode.CAUSAL.value if causal else MaskMode.NON_CAUSAL.value

         # Sanity check
         assert (
             num_qo_heads % num_kv_heads == 0
         ), "num_qo_heads must be a multiple of num_kv_heads"
         assert num_blocks_row * num_kv_heads + 1 == kv_indptr_host.shape[0]
-        assert kv_indptr_host[-1].item() == kv_indices_host.shape[0]
+        assert (
+            kv_indptr_host[-1].item() == kv_indices_host.shape[0]
+        ), f"{kv_indptr_host[-1].item()} != {kv_indices_host.shape[0]}"
         assert num_kv_heads == block_mask_map.shape[0]
         assert num_kv_heads == block_row_sz.shape[0]
         assert num_kv_heads == block_col_sz.shape[0]
         assert num_blocks_row == block_mask_map.shape[1]
         assert num_blocks_col == block_mask_map.shape[2]

-        kv_indptr_host = kv_indptr_host.to("cpu")
-
         if self._backend == "auto":
             self._backend = determine_attention_backend(
                 self.device,
@@ -986,8 +1014,12 @@ def _block_mask_map_to_expanded_indices(
             )

         if self._backend == "fa3":
-            self._vector_sparse_indptr_buffer[: len(kv_indptr_host)].copy_(
-                kv_indptr_host, non_blocking=non_blocking
+            if self._vector_sparse_indptr_buffer.numel() <= kv_indptr.numel():
+                raise ValueError(
+                    "_vector_sparse_indptr_buffer is not large enough. Please increase the buffer size."
+                )
+            self._vector_sparse_indptr_buffer[: len(kv_indptr)].copy_(
+                kv_indptr, non_blocking=non_blocking
             )

         self._plan_info = self._cached_module.plan(
@@ -1135,6 +1167,14 @@ def run(
             _check_shape_dtype_device(out, q.shape, self._o_dtype, q.device, "out")

         if self._backend == "fa3":
+            if (
+                self._vector_sparse_indices_buffer.numel()
+                <= self._paged_kv_indices_buf.numel()
+            ):
+                raise ValueError(
+                    "_vector_sparse_indices_buffer is not large enough. Please increase the buffer size."
+                )
+
             sparse_indices = block_sparse_indices_to_vector_sparse_offsets(
                 self._paged_kv_indices_buf,
                 self._paged_kv_indptr_buf,
```
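For reference, a hedged usage sketch of the extended `reset_workspace_buffer` signature from the diff above. The constructor call and every buffer size below are illustrative assumptions, not values recommended by the commit:

```python
import torch
import flashinfer

# Workspace for the wrapper itself (size is an illustrative assumption).
float_workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
wrapper = flashinfer.VariableBlockSparseAttentionWrapper(float_workspace, backend="fa3")

# Override the internal sparse-index buffers when the new defaults
# (128M indices / 32768 indptr entries) still do not fit your workload.
wrapper.reset_workspace_buffer(
    float_workspace_buffer=float_workspace,
    int_workspace_buffer=torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device="cuda:0"),
    vector_sparse_indices_buffer=torch.empty(
        256 * 1024 * 1024, dtype=torch.int32, device="cuda:0"
    ),
    vector_sparse_indptr_buffer=torch.empty(65536, dtype=torch.int32, device="cuda:0"),
)
```

If the optional buffers are left unset, the wrapper keeps the defaults from `__init__`, and `plan`/`run` now raise a `ValueError` instead of silently overflowing when those defaults are too small.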

include/flashinfer/attention/variants.cuh

Lines changed: 1 addition & 1 deletion
```diff
@@ -83,7 +83,7 @@ struct DefaultAttention : AttentionVariantBase {
       if (qo_idx >= qo_len || kv_idx >= kv_len) {
         mask = false;
       } else {
-        const uint32_t offset = qo_idx * kv_len + kv_idx;
+        const uint64_t offset = static_cast<uint64_t>(qo_idx) * kv_len + kv_idx;
         mask &= ((custom_mask_ptr[offset / 8] >> (offset % 8)) & 1);
       }
     }
```
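For intuition, a small Python model of the bit-packed mask lookup above. It mirrors the `offset / 8` byte index and `offset % 8` bit shift from the diff; the little-bit-order packing and the toy mask contents are assumptions for illustration only:

```python
# Mirrors the lookup pattern `(custom_mask_ptr[offset / 8] >> (offset % 8)) & 1`
# from the diff; assumes little bit order within each byte (bit i of byte j
# holds mask element 8 * j + i).
def mask_lookup(custom_mask: bytes, qo_idx: int, kv_idx: int, kv_len: int) -> bool:
    offset = qo_idx * kv_len + kv_idx  # must be 64-bit on device, hence the fix above
    return bool((custom_mask[offset // 8] >> (offset % 8)) & 1)

# Tiny 2 x 12 mask where only position (qo_idx=1, kv_idx=11) is allowed:
# offset = 1 * 12 + 11 = 23 -> byte 2, bit 7.
mask_bytes = bytes([0b00000000, 0b00000000, 0b10000000])
print(mask_lookup(mask_bytes, qo_idx=1, kv_idx=11, kv_len=12))  # True
print(mask_lookup(mask_bytes, qo_idx=0, kv_idx=5, kv_len=12))   # False
```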

include/flashinfer/quantization.cuh

Lines changed: 9 additions & 3 deletions
```diff
@@ -38,13 +38,19 @@ enum class BitOrder { kBig = 0U, kLittle = 1U };

 template <BitOrder BITORDER>
 __global__ void PackBitsKernel(bool* input, uint8_t* output, int64_t num_elements) {
-  int64_t start_offset = blockIdx.x * blockDim.x * 8, tx = threadIdx.x;
+  int64_t start_offset = static_cast<int64_t>(blockIdx.x) * blockDim.x * 8, tx = threadIdx.x;
   uint8_t ret = 0;
   bool input_vec[8];
   typedef cub::BlockLoad<bool, 256, 8, cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
   __shared__ typename BlockLoad::TempStorage temp_storage;
-  BlockLoad(temp_storage)
-      .Load(input + start_offset, input_vec, num_elements - start_offset, /*default=*/0);
+
+  // This fix the INT32_T overflow issue, which is possible in DiT video models
+  // where the kv_len could be 128K.
+  // ref:
+  // https://github.com/NVIDIA/cub/blob/0fc3c3701632a4be906765b73be20a9ad0da603d/cub/block/block_load.cuh#L711C13-L711C100
+  int block_items_end =
+      (num_elements - start_offset > INT32_MAX) ? INT32_MAX : num_elements - start_offset;
+  BlockLoad(temp_storage).Load(input + start_offset, input_vec, block_items_end, /*default=*/0);

   if constexpr (BITORDER == BitOrder::kBig) {
     ret = (input_vec[0] << 7) | (input_vec[1] << 6) | (input_vec[2] << 5) | (input_vec[3] << 4) |
```
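As a host-side sketch of what `PackBitsKernel` computes in its `BitOrder::kBig` branch (visible in the last context line above), here is a minimal NumPy reference; the helper name and the example input are made up for illustration:

```python
import numpy as np

def pack_bits_big(bits) -> np.ndarray:
    """Pack a boolean sequence into bytes, big bit order: element 8*j + i -> bit (7 - i) of byte j."""
    bits = np.asarray(bits, dtype=bool)
    pad = (-len(bits)) % 8
    # The kernel pads the trailing partial group with /*default=*/0; do the same here.
    bits = np.concatenate([bits, np.zeros(pad, dtype=bool)])
    return np.packbits(bits)  # np.packbits defaults to bitorder="big"

example = [1, 0, 1, 1, 0, 0, 0, 1, 1]
print(pack_bits_big(example))  # [177 128] -> 0b10110001, 0b10000000
```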

tests/test_block_sparse.py

Lines changed: 22 additions & 7 deletions
```diff
@@ -67,6 +67,12 @@ def bsr_attention_ref(
     return o


+def set_seed(seed: int = 42):
+    torch.cuda.manual_seed(seed)
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+
+
 @pytest.mark.parametrize("R", [1, 4, 16])
 @pytest.mark.parametrize("C", [1, 4, 16])
 @pytest.mark.parametrize("M", [64, 128, 256])
@@ -80,7 +86,10 @@ def test_block_sparse_attention(
 ):
     if num_qo_heads % num_kv_heads != 0:
         pytest.skip("num_qo_heads must be divisible by num_kv_heads")
+
+    set_seed(33)
     rng = np.random.default_rng()
+
     MB = M // R
     NB = N // C
     S = sp.sparse.random(MB, NB, density=0.25, random_state=rng).tocsr()
@@ -182,6 +191,8 @@ def test_variable_block_sparse_attention_wrapper(
     if seq_len // num_blocks_col < 1:
         pytest.skip("seq_len must be greater than num_blocks_col")

+    set_seed(330)
+
     def random_partition_batch(
         seq_len: int,
         num_blocks: int,
@@ -209,7 +220,7 @@ def random_partition_batch(
         assert sizes.max() <= seq_len
         assert torch.all(sizes.sum(dim=-1) == seq_len)

-        return sizes
+        return sizes.to(device=device)

     def _test_variable_block_sparse_attention(
         num_qo_heads: int,
@@ -260,12 +271,15 @@ def _test_variable_block_sparse_attention(
         )
         torch.testing.assert_close(o[kv_head_idx], o_ref, atol=1e-2, rtol=1e-2)

-    block_row_sz = random_partition_batch(seq_len, num_blocks_row, num_kv_heads)
-    block_col_sz = random_partition_batch(seq_len, num_blocks_col, num_kv_heads)
+    block_row_sz = random_partition_batch(
+        seq_len, num_blocks_row, num_kv_heads, device="cuda:0"
+    )
+    block_col_sz = random_partition_batch(
+        seq_len, num_blocks_col, num_kv_heads, device="cuda:0"
+    )
     block_mask_map = (
         torch.rand(num_kv_heads, num_blocks_row, num_blocks_col) > block_density
-    )
-    block_mask_map = block_mask_map.to(dtype=torch.bool, device="cpu")
+    ).to(device="cuda:0")

     _test_variable_block_sparse_attention(
         num_qo_heads,
@@ -278,5 +292,6 @@ def _test_variable_block_sparse_attention(


 if __name__ == "__main__":
-    test_block_sparse_attention(1, 1, 64, 64, 1, 1, 128, False)
-    test_block_sparse_attention(16, 16, 256, 256, 16, 16, 256, True)
+    # This test verifies the INT32_T overflow issue.
+    for seq_len in [16 * 1024, 32 * 1024, 40 * 1024, 48 * 1024, 64 * 1024]:
+        test_block_sparse_attention(128, 128, seq_len, seq_len, 1, 1, 128, False)
```
