@@ -598,6 +598,9 @@ def warmup(self, resource_manager: ResourceManager) -> None:
         self._run_torch_compile_warmup(resource_manager)
         self._run_autotuner_warmup(resource_manager)
         self._run_cuda_graph_warmup(resource_manager)
+        if not kv_cache_manager.is_estimating_kv_cache:
+            # Run an extra general warmup to warm up the memory pool before
+            # running real requests.
+            self._general_warmup(resource_manager, reverse=True)
 
         # Set the value back to the original value after all warmups are complete
         self.enable_spec_decode = self.is_spec_decode
@@ -612,8 +615,8 @@ def _general_warmup(self,
             self.original_max_draft_len), self.max_num_tokens,
             self.batch_size * (self.max_seq_len - 1))
         max_batch_size = min(
-            self.batch_size,
-            curr_max_num_tokens // (1 + self.runtime_draft_len))
+            self.batch_size, curr_max_num_tokens //
+            (1 + self.runtime_draft_len) // self.max_beam_width)
 
         warmup_requests_configs = {
             (1, 1),  # Specialize for 1 token.
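The new divisor accounts for beam search: with a beam width greater than 1, each request occupies `max_beam_width` sequence slots, so the token budget supports proportionally fewer requests. A worked example of the revised bound; all numbers below are hypothetical, chosen only to illustrate the arithmetic:

    # Hypothetical values to illustrate the revised max_batch_size formula.
    batch_size = 256
    curr_max_num_tokens = 8192
    runtime_draft_len = 3   # each step carries 1 target + 3 draft tokens
    max_beam_width = 4

    # Before this change: 8192 // (1 + 3) = 2048, so the cap was batch_size.
    # After: 8192 // 4 // 4 = 512, still capped at 256 here; with a larger
    # beam width (e.g. 32), the beam-aware bound kicks in: 8192 // 4 // 32 = 64.
    max_batch_size = min(
        batch_size,
        curr_max_num_tokens // (1 + runtime_draft_len) // max_beam_width)
    print(max_batch_size)  # 256 with these numbers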
@@ -936,8 +939,8 @@ def _create_warmup_request(
 
         blocks_to_use = num_full_seqs * math.ceil(
             max_seq_len / kv_cache_manager.tokens_per_block) + math.ceil(
-                num_left_over_tokens /
-                kv_cache_manager.tokens_per_block) + num_gen_requests
+                num_left_over_tokens / kv_cache_manager.tokens_per_block
+            ) + num_gen_requests * self.max_beam_width
 
         if blocks_to_use > available_blocks:
             return None
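The block estimate is adjusted the same way: each generation request now reserves one KV block per beam rather than a single block, since every beam writes its own generated tokens. A minimal sketch of the accounting with hypothetical values:

    import math

    # Hypothetical values for the blocks_to_use formula above.
    num_full_seqs = 2
    num_left_over_tokens = 10
    num_gen_requests = 3
    max_seq_len = 100
    tokens_per_block = 32
    max_beam_width = 4

    blocks_to_use = (num_full_seqs * math.ceil(max_seq_len / tokens_per_block)
                     + math.ceil(num_left_over_tokens / tokens_per_block)
                     + num_gen_requests * max_beam_width)
    # 2 * ceil(100/32) + ceil(10/32) + 3 * 4 = 8 + 1 + 12 = 21 blocks
    print(blocks_to_use)  # 21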