Skip to content

Commit 8fa324d

Browse files
committed
https://nvbugs/5590408: Fix the setting of mMaxSeqLenKv
Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
1 parent 24f5cd7 commit 8fa324d

File tree

1 file changed

+7
-1
lines changed
  • tensorrt_llm/_torch/attention_backend

1 file changed

+7
-1
lines changed

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -839,6 +839,10 @@ def prepare(self) -> None:
839839
self.prepare_flash_mla()
840840
# number of tokens needed in the kv cache for each sequence after the next pass
841841
kv_lens = cached_token_lens + self.seq_lens_kv if cached_token_lens is not None else self.seq_lens_kv
842+
# Store actual KV length (without extra tokens) for use in kv_lens_runtime.
843+
# num_extra_kv_tokens are for internal cache management but should not be reported
844+
# as actual past KV length in host_past_key_value_lengths.
845+
self.kv_lens_actual = kv_lens.clone()
842846
# self.kv_lens is the valid kv cache length, while the self.kv_lens_cuda is
843847
# the sequence length including the cached tokens and the input tokens.
844848
self.kv_lens[:self.num_seqs].copy_(
@@ -881,7 +885,9 @@ def prepare(self) -> None:
881885
) <= self.kv_cache_manager.max_seq_len, error_message
882886

883887
self.kv_lens_cuda_runtime = self.kv_lens_cuda[:self.num_seqs]
884-
self.kv_lens_runtime = self.kv_lens[:self.num_seqs]
888+
# Use actual KV length (without extra tokens) for kv_lens_runtime,
889+
# which becomes host_past_key_value_lengths and eventually mMaxSeqLenKv.
890+
self.kv_lens_runtime = self.kv_lens_actual[:self.num_seqs]
885891
self.prompt_lens_cuda_runtime = self.prompt_lens_cuda[:self.num_seqs]
886892
self.prompt_lens_cpu_runtime = self.prompt_lens_cpu[:self.num_seqs]
887893
self.host_request_types_runtime = self.host_request_types[:self.

0 commit comments

Comments (0)