
Commit 389cb14

[None][feat] Run extra general warmup to warm up memory pool
Signed-off-by: Jin Li <[email protected]>
1 parent 6c1abf2 commit 389cb14
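
Background on the change: PyTorch's CUDA caching allocator grows its memory pool lazily, so without an extra warmup pass the first real requests pay for cudaMalloc calls and possible pool fragmentation. Running one more round of representative dummy batches after the other warmups lets the pool reach roughly its steady-state size up front. The snippet below is a minimal, standalone sketch of that pattern, assuming a generic model that accepts a token-id tensor; the helper name and model signature are illustrative, not the actual `_general_warmup` implementation.

import torch

def warm_up_memory_pool(model, batch_sizes, seq_len, device="cuda"):
    """Run throwaway forward passes so the CUDA caching allocator
    pre-grows its pool before real traffic arrives.

    Illustrative sketch only: the real _general_warmup builds proper
    warmup requests through the resource manager instead of raw tensors.
    """
    model.eval()
    with torch.inference_mode():
        for bs in sorted(batch_sizes, reverse=True):  # largest batch first
            dummy_batch = torch.zeros(bs, seq_len, dtype=torch.long, device=device)
            model(dummy_batch)  # output is discarded; only the allocations matter
    torch.cuda.synchronize()

The reverse=True argument in the call added by this commit suggests the warmup configurations are walked from largest to smallest; the sketch mirrors that ordering, but that reading is an inference from the call site, not from the diff body.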

2 files changed (+8, -4 lines)

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 7 additions & 4 deletions
@@ -598,6 +598,9 @@ def warmup(self, resource_manager: ResourceManager) -> None:
         self._run_torch_compile_warmup(resource_manager)
         self._run_autotuner_warmup(resource_manager)
         self._run_cuda_graph_warmup(resource_manager)
+        if not kv_cache_manager.is_estimating_kv_cache:
+            # Run an extra general warmup to warm up the memory pool before running real requests.
+            self._general_warmup(resource_manager, reverse=True)
 
         # Set the value back to the original value after all warmups are complete
         self.enable_spec_decode = self.is_spec_decode
@@ -612,8 +615,8 @@ def _general_warmup(self,
                 self.original_max_draft_len), self.max_num_tokens,
             self.batch_size * (self.max_seq_len - 1))
         max_batch_size = min(
-            self.batch_size,
-            curr_max_num_tokens // (1 + self.runtime_draft_len))
+            self.batch_size, curr_max_num_tokens //
+            (1 + self.runtime_draft_len) // self.max_beam_width)
 
         warmup_requests_configs = {
             (1, 1),  # Specialize for 1 token.
@@ -936,8 +939,8 @@ def _create_warmup_request(
 
         blocks_to_use = num_full_seqs * math.ceil(
             max_seq_len / kv_cache_manager.tokens_per_block) + math.ceil(
-                num_left_over_tokens /
-                kv_cache_manager.tokens_per_block) + num_gen_requests
+                num_left_over_tokens / kv_cache_manager.tokens_per_block
+            ) + num_gen_requests * self.max_beam_width
 
         if blocks_to_use > available_blocks:
             return None
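
Both hunks above fold max_beam_width into the warmup sizing: the general-warmup batch size now divides the token budget by the beam width as well, and the KV-cache block estimate charges one generation block per beam rather than per request. Below is a hedged, standalone restatement of that block accounting; the function name and parameter list are mine for illustration, not identifiers from the file.

import math

def warmup_kv_block_budget(num_full_seqs, max_seq_len, num_left_over_tokens,
                           num_gen_requests, tokens_per_block, max_beam_width):
    # Full context sequences round up to whole blocks, leftover tokens take
    # one more partial block, and each generation request reserves a block
    # per beam (the adjustment introduced in this commit).
    return (num_full_seqs * math.ceil(max_seq_len / tokens_per_block)
            + math.ceil(num_left_over_tokens / tokens_per_block)
            + num_gen_requests * max_beam_width)

A warmup request is only created when this total fits within available_blocks; otherwise _create_warmup_request returns None, as the last two context lines of the diff show.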

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 1 addition & 0 deletions
@@ -193,6 +193,7 @@ def __init__(
             idx: offset
             for offset, idx in enumerate(self.pp_layers)
         }
+        self.is_estimating_kv_cache = is_estimating_kv_cache
 
         self.kv_connector_manager = kv_connector_manager
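
This one-line change records whether the manager was built only to estimate KV-cache size, which is what the engine's warmup() (first hunk in model_engine.py) checks before running the extra memory-pool pass. A compressed, hypothetical illustration of that flag flow follows; the class and function names are placeholders, not the real ones.

class KVCacheManagerSketch:
    def __init__(self, is_estimating_kv_cache: bool = False):
        # Remember whether this manager exists only to estimate KV-cache size.
        self.is_estimating_kv_cache = is_estimating_kv_cache


def maybe_run_extra_warmup(engine, resource_manager, kv_cache_manager):
    if not kv_cache_manager.is_estimating_kv_cache:
        # Only the real (non-estimation) cache pays for the extra warmup pass.
        engine._general_warmup(resource_manager, reverse=True)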
