flashinfer/decode.py (3 additions & 4 deletions)
@@ -1982,7 +1982,6 @@ def trtllm_batch_decode_with_kv_cache(
     workspace_buffer: torch.Tensor,
     block_tables: torch.Tensor,
     seq_lens: torch.Tensor,
-    max_seq_len: int,
     bmm1_scale: float,
     bmm2_scale: float,  # todo(Yingyi): add dynamic scale tensor later
     window_left: int = -1,
@@ -2011,9 +2010,6 @@ def trtllm_batch_decode_with_kv_cache(
     seq_lens : torch.Tensor
         A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``

-    max_seq_len : int
-        max sequence length for kv_cache
-
     bmm1_scale : float
         fused scale for bmm1 input.
@@ -2110,6 +2106,9 @@ def trtllm_batch_decode_with_kv_cache(
     else:
         raise ValueError(f"Invalid out_dtype: {out_dtype}")

+    page_size = k_cache.shape[3]
Contributor (critical):

There appears to be an indexing error in the calculation of page_size. Based on the function's docstring and the tensor unpacking logic, k_cache has a shape of [num_pages, num_kv_heads, page_size, head_dim] when kv_layout is HND.

The current implementation uses k_cache.shape[3], which corresponds to head_dim, not page_size. The correct index for page_size is 2.

This will cause an incorrect max_seq_len to be calculated, which is a critical bug that could lead to incorrect kernel behavior or memory access errors.

Suggested change:
-    page_size = k_cache.shape[3]
+    page_size = k_cache.shape[2]
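
For illustration, a minimal standalone sketch (tensor sizes are hypothetical, assuming the HND layout ``[num_pages, num_kv_heads, page_size, head_dim]`` described in the docstring) showing which dimension each index actually reads:

import torch

# Hypothetical sizes, assuming the HND layout from the docstring:
# [num_pages, num_kv_heads, page_size, head_dim]
num_pages, num_kv_heads, page_size, head_dim = 128, 8, 16, 64
k_cache = torch.empty(num_pages, num_kv_heads, page_size, head_dim)

assert k_cache.shape[2] == page_size  # 16 -- the intended value
assert k_cache.shape[3] == head_dim   # 64 -- what shape[3] reads instead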

+    num_pages = block_tables.shape[1]
+    max_seq_len = num_pages * page_size
     run_func(
         out,
         out_scale_factor,
flashinfer/prefill.py (3 additions & 3 deletions)
@@ -3052,7 +3052,6 @@ def trtllm_batch_context_with_kv_cache(
     block_tables: torch.Tensor,
     seq_lens: torch.Tensor,
     max_q_len: int,
-    max_kv_len: int,
     bmm1_scale: float,
     bmm2_scale: float,
     batch_size: int,
@@ -3081,8 +3080,6 @@ def trtllm_batch_context_with_kv_cache(
         A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``
     max_q_len : int
         max sequence length for query
-    max_kv_len : int
-        max sequence length for kv_cache
     bmm1_scale : float
         fused scale for bmm1 input.
     bmm2_scale : float
@@ -3178,6 +3175,9 @@ def trtllm_batch_context_with_kv_cache(
     else:
         raise ValueError(f"Invalid out_dtype: {out_dtype}")

+    page_size = k_cache.shape[3]
Contributor (critical):

There is an indexing error when calculating page_size. According to the function's docstring and tensor unpacking logic, k_cache has a shape of [num_pages, num_kv_heads, page_size, head_dim] when kv_layout is HND. Therefore, page_size should be accessed via k_cache.shape[2].

The current code uses k_cache.shape[3], which incorrectly retrieves the head_dim. This will lead to an incorrect max_kv_len and is a critical bug.

Suggested change:
-    page_size = k_cache.shape[3]
+    page_size = k_cache.shape[2]
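
To illustrate the downstream effect, a small sketch (sizes hypothetical; it assumes ``block_tables`` has shape ``[batch_size, num_pages]``, as implied by the use of ``block_tables.shape[1]`` in the new code) comparing the ``max_kv_len`` each index produces:

import torch

# Hypothetical sizes: page_size=16, head_dim=64, 32 pages per sequence.
k_cache = torch.empty(128, 8, 16, 64)  # [num_pages, num_kv_heads, page_size, head_dim]
block_tables = torch.zeros(4, 32, dtype=torch.int32)  # [batch_size, num_pages]

num_pages = block_tables.shape[1]
max_kv_len_wrong = num_pages * k_cache.shape[3]  # 32 * 64 = 2048 (uses head_dim)
max_kv_len_right = num_pages * k_cache.shape[2]  # 32 * 16 = 512  (uses page_size)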

+    num_pages = block_tables.shape[1]
+    max_kv_len = num_pages * page_size
     run_func(
         out,
         out_scale_factor,
tests/test_trtllm_gen_context.py (0 additions & 2 deletions)
@@ -191,7 +191,6 @@ def test_trtllm_batch_context_wrapper(
         block_tables=block_tables,
         seq_lens=seq_lens,
         max_q_len=qo_len,
-        max_kv_len=kv_len,
         bmm1_scale=q_scale / math.sqrt(head_dim),
         bmm2_scale=1,
         batch_size=batch_size,
@@ -383,7 +382,6 @@ def test_trtllm_batch_prefill(
         block_tables,
         seq_lens_gpu,
         max_q_len,
-        max_seq_len,
         q_scale * k_scale * sm_scale,  # bmm1_scale
         v_scale / o_scale,  # bmm2_scale
         batch_size,
tests/test_trtllm_gen_decode.py (0 additions & 1 deletion)
@@ -276,7 +276,6 @@ def test_trtllm_batch_decode_fmha(
         workspace_buffer,
         block_tables,
         seq_lens_gpu,
-        max_seq_len,
         q_scale * k_scale * sm_scale,  # bmm1_scale
         v_scale / o_scale,  # bmm2_scale
         window_left,  # window_left