15 changes: 8 additions & 7 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -598,6 +598,9 @@ def warmup(self, resource_manager: ResourceManager) -> None:
self._run_torch_compile_warmup(resource_manager)
self._run_autotuner_warmup(resource_manager)
self._run_cuda_graph_warmup(resource_manager)
if not kv_cache_manager.is_estimating_kv_cache and not self.is_draft_model:
# Run an extra general warmup to warm up the memory pool before running real requests.
self._general_warmup(resource_manager, reverse=True)

# Set the value back to the original value after all warmups are complete
self.enable_spec_decode = self.is_spec_decode
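The new call in `warmup` runs one extra general pass so the CUDA caching allocator has already reserved the memory that real requests will need; it is skipped during the KV-cache estimation pass and for the draft model. The sketch below is a standalone illustration of that idea, not the PR's code: the shapes and the largest-first ordering (my reading of `reverse=True`) are assumptions.

```python
# Illustrative sketch only: pre-growing PyTorch's CUDA caching allocator so later
# allocations of similar sizes are served from the already-reserved pool.
# Shapes are made up; the real pass derives its configs from max_num_tokens,
# batch_size and runtime_draft_len as in _general_warmup.
import torch

def warm_memory_pool(shapes, device="cuda"):
    for shape in sorted(shapes, reverse=True):  # largest first (assumed meaning of reverse=True)
        buf = torch.empty(shape, device=device)
        del buf  # returned to the caching allocator, not to the driver
    return torch.cuda.memory_reserved(device)

if torch.cuda.is_available():
    reserved = warm_memory_pool([(1, 4096), (8, 4096), (64, 4096)])
    print(f"reserved after warmup: {reserved / 2**20:.1f} MiB")
```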
@@ -612,8 +615,8 @@ def _general_warmup(self,
self.original_max_draft_len), self.max_num_tokens,
self.batch_size * (self.max_seq_len - 1))
max_batch_size = min(
self.batch_size,
curr_max_num_tokens // (1 + self.runtime_draft_len))
self.batch_size, curr_max_num_tokens //
(1 + self.runtime_draft_len) // self.max_beam_width)

warmup_requests_configs = {
(1, 1), # Specialize for 1 token.
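With beam search enabled, every generation request fans out into `max_beam_width` sequences, so the warmup batch size has to shrink by the same factor to stay inside the token budget. A worked example of the new bound, using assumed values rather than any real configuration:

```python
# Assumed values for illustration only.
max_num_tokens = 8192
runtime_draft_len = 3      # each request processes 1 + runtime_draft_len tokens per step
max_beam_width = 4
batch_size = 2048

# Old bound: ignores the beam fan-out.
old_max_batch = min(batch_size, max_num_tokens // (1 + runtime_draft_len))        # 2048
# New bound: each request also multiplies into max_beam_width beams.
new_max_batch = min(batch_size,
                    max_num_tokens // (1 + runtime_draft_len) // max_beam_width)  # 512
print(old_max_batch, new_max_batch)
```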
@@ -936,8 +939,8 @@ def _create_warmup_request(

blocks_to_use = num_full_seqs * math.ceil(
max_seq_len / kv_cache_manager.tokens_per_block) + math.ceil(
num_left_over_tokens /
kv_cache_manager.tokens_per_block) + num_gen_requests
num_left_over_tokens / kv_cache_manager.tokens_per_block
) + num_gen_requests * self.max_beam_width

if blocks_to_use > available_blocks:
return None
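The block estimate in `_create_warmup_request` now reserves one KV-cache block per beam for each generation request, since every beam writes its own cache entries. The same arithmetic with assumed numbers (not taken from any real run):

```python
import math

# Assumed values for illustration only.
tokens_per_block = 64
max_seq_len = 1000
num_full_seqs = 2
num_left_over_tokens = 300
num_gen_requests = 3
max_beam_width = 4

blocks_to_use = (num_full_seqs * math.ceil(max_seq_len / tokens_per_block)   # 2 * 16 = 32
                 + math.ceil(num_left_over_tokens / tokens_per_block)        # + 5
                 + num_gen_requests * max_beam_width)                        # + 12 (was + 3)
print(blocks_to_use)  # 49; the pre-change formula would have reserved 40
```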
@@ -2586,16 +2589,14 @@ def previous_seq_slots_device():
num_generation_requests = len(gen_request_seq_slots)
# Cache indirection is only used for beam search on generation requests
if self.use_beam_search and num_generation_requests > 0:
# CUDA Graph needs to set beam width during warmup (where the graph is captured), to ensure that cache indirection buffer is correctly picked up by the CUDA graph
is_cuda_graph_during_warmup = self.is_warmup and attn_metadata.is_cuda_graph
if cache_indirection_buffer is not None:
# Copy cache indirection to local buffer with offsets changing: seq_slots[i] -> i
# Convert to GPU tensor to avoid implicit sync
gen_request_seq_slots_tensor = torch.tensor(
gen_request_seq_slots, dtype=torch.long, device='cuda')
self.cache_indirection_attention[:num_generation_requests].copy_(
cache_indirection_buffer[gen_request_seq_slots_tensor])
if cache_indirection_buffer is not None or is_cuda_graph_during_warmup:
if cache_indirection_buffer is not None or self.is_warmup:
attn_metadata.beam_width = self.max_beam_width
else:
attn_metadata.beam_width = 1
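Two things are visible in this hunk: the warmup condition is broadened from `is_cuda_graph_during_warmup` to `self.is_warmup`, so every warmup request (graph capture or not) sets `attn_metadata.beam_width = self.max_beam_width`, and the existing GPU-side gather of cache indirection rows is kept. The snippet below is a standalone illustration of that gather pattern with made-up shapes; only the indexing idiom is taken from the diff.

```python
# Illustrative only: gather per-request cache indirection rows by sequence slot.
import torch

if torch.cuda.is_available():
    max_requests, max_beam_width, max_seq_len = 8, 4, 16
    cache_indirection_buffer = torch.zeros(
        max_requests, max_beam_width, max_seq_len, dtype=torch.int32, device="cuda")
    local_buffer = torch.empty_like(cache_indirection_buffer)

    gen_request_seq_slots = [5, 2, 7]
    # Index with a CUDA tensor (not a Python list) so the advanced-indexing gather
    # stays on-device, mirroring the "avoid implicit sync" comment in the diff.
    slots = torch.tensor(gen_request_seq_slots, dtype=torch.long, device="cuda")
    local_buffer[:len(gen_request_seq_slots)].copy_(cache_indirection_buffer[slots])
```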
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -194,6 +194,7 @@ def __init__(
idx: offset
for offset, idx in enumerate(self.pp_layers)
}
self.is_estimating_kv_cache = is_estimating_kv_cache

self.kv_connector_manager = kv_connector_manager
