refactor: update trtllm-gen attn interface #1407
base: main
```diff
@@ -3052,7 +3052,6 @@ def trtllm_batch_context_with_kv_cache(
     block_tables: torch.Tensor,
     seq_lens: torch.Tensor,
     max_q_len: int,
-    max_kv_len: int,
     bmm1_scale: float,
     bmm2_scale: float,
     batch_size: int,
```
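For context, a hedged sketch of what this signature change means at a call site. Everything below is a placeholder: the real function takes more arguments than this diff shows, and the `max_kv_len=...` line illustrates only one plausible way callers supplied the removed argument.

```python
import torch

# Placeholder paged-attention inputs; shapes mirror the docstring below
# ([batch_size, num_pages] block table, [batch_size] sequence lengths).
batch_size, pages_per_seq = 2, 8
block_tables = torch.zeros(batch_size, pages_per_seq, dtype=torch.int32)
seq_lens = torch.tensor([100, 37], dtype=torch.int32)  # docstring says uint32

# Before this PR, a caller had to supply the kv-length bound explicitly:
#   trtllm_batch_context_with_kv_cache(..., max_q_len=max_q_len,
#                                      max_kv_len=int(seq_lens.max()), ...)
# After this PR, the max_kv_len keyword is gone: the wrapper derives the
# bound internally from block_tables and the KV cache's page size.
```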
```diff
@@ -3081,8 +3080,6 @@ def trtllm_batch_context_with_kv_cache(
         A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``
     max_q_len : int
         max sequence length for query
-    max_kv_len : int
-        max sequence length for kv_cache
     bmm1_scale : float
         fused scale for bmm1 input.
     bmm2_scale : float
```
```diff
@@ -3178,6 +3175,9 @@ def trtllm_batch_context_with_kv_cache(
     else:
         raise ValueError(f"Invalid out_dtype: {out_dtype}")
 
+    page_size = k_cache.shape[3]
+    num_pages = block_tables.shape[1]
+    max_kv_len = num_pages * page_size
     run_func(
         out,
         out_scale_factor,
```
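To make the new derivation concrete, here is a minimal standalone sketch, assuming the HND cache layout `[num_pages, num_kv_heads, page_size, head_dim]` discussed in the review comment below (and therefore reading `page_size` from index 2, which the comment argues is the correct index):

```python
import torch

# Paged KV cache in HND layout: [num_pages, num_kv_heads, page_size, head_dim].
# All sizes are illustrative placeholders.
k_cache = torch.zeros(32, 4, 16, 128)
block_tables = torch.zeros(2, 8, dtype=torch.int32)  # [batch_size, pages_per_seq]

# Each request addresses at most block_tables.shape[1] pages, and each page
# holds page_size tokens, so their product bounds every kv sequence length.
page_size = k_cache.shape[2]        # 16 (index 2, per the review comment)
num_pages = block_tables.shape[1]   # 8
max_kv_len = num_pages * page_size  # 128
assert max_kv_len == 128
```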
Review comment on `page_size = k_cache.shape[3]`:

There appears to be an indexing error in the calculation of `page_size`. Based on the function's docstring and the tensor unpacking logic, `k_cache` has a shape of `[num_pages, num_kv_heads, page_size, head_dim]` when `kv_layout` is HND. The current implementation uses `k_cache.shape[3]`, which corresponds to `head_dim`, not `page_size`. The correct index for `page_size` is `2`. This will cause an incorrect `max_kv_len` to be calculated, which is a critical bug that could lead to incorrect kernel behavior or memory access errors.

Suggested change:

```diff
-    page_size = k_cache.shape[3]
+    page_size = k_cache.shape[2]
```
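A quick standalone check of the comment's claim, using deliberately distinct placeholder sizes so the two indices cannot be confused:

```python
import torch

# HND layout: [num_pages, num_kv_heads, page_size, head_dim].
k_cache = torch.zeros(32, 4, 16, 128)

assert k_cache.shape[3] == 128  # head_dim: what the PR currently reads
assert k_cache.shape[2] == 16   # page_size: what the comment says to read

# With the wrong index, max_kv_len becomes num_pages * head_dim instead of
# num_pages * page_size, an 8x overestimate with these sizes, which can
# push the kernel past the real end of each sequence's pages.
```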