File tree Expand file tree Collapse file tree 1 file changed +7
-1
lines changed
tensorrt_llm/_torch/attention_backend Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Original file line number Diff line number Diff line change @@ -839,6 +839,10 @@ def prepare(self) -> None:
839839 self .prepare_flash_mla ()
840840 # number of tokens needed in the kv cache for each sequence after the next pass
841841 kv_lens = cached_token_lens + self .seq_lens_kv if cached_token_lens is not None else self .seq_lens_kv
842+ # Store actual KV length (without extra tokens) for use in kv_lens_runtime.
843+ # num_extra_kv_tokens are for internal cache management but should not be reported
844+ # as actual past KV length in host_past_key_value_lengths.
845+ self .kv_lens_actual = kv_lens .clone ()
842846 # self.kv_lens is the valid kv cache length, while the self.kv_lens_cuda is
843847 # the sequence length including the cached tokens and the input tokens.
844848 self .kv_lens [:self .num_seqs ].copy_ (
@@ -881,7 +885,9 @@ def prepare(self) -> None:
881885 ) <= self .kv_cache_manager .max_seq_len , error_message
882886
883887 self .kv_lens_cuda_runtime = self .kv_lens_cuda [:self .num_seqs ]
884- self .kv_lens_runtime = self .kv_lens [:self .num_seqs ]
888+ # Use actual KV length (without extra tokens) for kv_lens_runtime,
889+ # which becomes host_past_key_value_lengths and eventually mMaxSeqLenKv.
890+ self .kv_lens_runtime = self .kv_lens_actual [:self .num_seqs ]
885891 self .prompt_lens_cuda_runtime = self .prompt_lens_cuda [:self .num_seqs ]
886892 self .prompt_lens_cpu_runtime = self .prompt_lens_cpu [:self .num_seqs ]
887893 self .host_request_types_runtime = self .host_request_types [:self .
You can’t perform that action at this time.
0 commit comments