
Commit 22f3a4b

[Bugfix] Ensure multistep lookahead allocation is compatible with cuda graph max capture (vllm-project#8340)
Parent: b1f3e18

File tree: 1 file changed (+11, -1)

vllm/attention/backends/flash_attn.py

Lines changed: 11 additions & 1 deletion
@@ -471,9 +471,19 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             # The shape of graph_block_tables is
             # [max batch size, max context len // block size].
             input_block_tables = self.runner.graph_block_tables[:batch_size]
+            max_blocks = input_block_tables.shape[1]
             for i, block_table in enumerate(self.block_tables):
                 if block_table:
-                    input_block_tables[i, :len(block_table)] = block_table
+                    num_blocks = len(block_table)
+                    if num_blocks <= max_blocks:
+                        input_block_tables[i, :num_blocks] = block_table
+                    else:
+                        # It may be possible to have more blocks allocated due
+                        # to lookahead slots of multi-step, however, they are
+                        # not used anyway, so can be safely ignored.
+                        input_block_tables[
+                            i, :max_blocks] = block_table[:max_blocks]
+
             block_tables = torch.from_numpy(input_block_tables).to(
                 device=device, non_blocking=True)
         else:
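
For context, below is a minimal, standalone sketch of the clamping logic this commit introduces. The helper name fill_graph_block_tables, the array shapes, and the example inputs are illustrative assumptions rather than part of the vLLM API; graph_block_tables stands in for self.runner.graph_block_tables.

from typing import List

import numpy as np


def fill_graph_block_tables(
        graph_block_tables: np.ndarray,
        block_tables: List[List[int]]) -> np.ndarray:
    # Copy per-sequence block tables into the pre-allocated array used
    # during CUDA graph replay, truncating any table that is wider than
    # the captured shape (e.g. extra lookahead blocks from multi-step).
    batch_size = len(block_tables)
    input_block_tables = graph_block_tables[:batch_size]
    max_blocks = input_block_tables.shape[1]
    for i, block_table in enumerate(block_tables):
        if not block_table:
            continue
        num_blocks = len(block_table)
        if num_blocks <= max_blocks:
            input_block_tables[i, :num_blocks] = block_table
        else:
            # The extra (lookahead) blocks are never read during replay,
            # so dropping them is safe.
            input_block_tables[i, :max_blocks] = block_table[:max_blocks]
    return input_block_tables


if __name__ == "__main__":
    # Graph captured for at most 4 blocks per sequence; the second
    # sequence has 6 blocks allocated because of lookahead slots.
    graph_block_tables = np.zeros((8, 4), dtype=np.int32)
    out = fill_graph_block_tables(graph_block_tables,
                                  [[1, 2], [3, 4, 5, 6, 7, 8]])
    print(out)  # row 0 -> [1 2 0 0], row 1 -> [3 4 5 6]

Truncating instead of resizing keeps the block table at the shape captured with the CUDA graph; the dropped entries correspond to lookahead slots allocated by multi-step scheduling that the captured graph never reads.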
