@@ -598,6 +598,9 @@ def warmup(self, resource_manager: ResourceManager) -> None:
         self._run_torch_compile_warmup(resource_manager)
         self._run_autotuner_warmup(resource_manager)
         self._run_cuda_graph_warmup(resource_manager)
+        if not kv_cache_manager.is_estimating_kv_cache and not self.is_draft_model:
+            # Run an extra general warmup to warm up the memory pool before running real requests.
+            self._general_warmup(resource_manager, reverse=True)

         # Set the value back to the original value after all warmups are complete
         self.enable_spec_decode = self.is_spec_decode
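
A minimal sketch of the warmup ordering after this hunk, written as a hypothetical standalone function with the names taken from the diff: the extra pool warmup only runs for the target model once KV cache estimation is finished.

```python
# Sketch only; `engine` stands in for the model engine instance in the diff.
def run_warmups(engine, resource_manager, kv_cache_manager):
    engine._run_torch_compile_warmup(resource_manager)
    engine._run_autotuner_warmup(resource_manager)
    engine._run_cuda_graph_warmup(resource_manager)
    # Draft models and KV-cache estimation passes skip the extra pool warmup.
    if not kv_cache_manager.is_estimating_kv_cache and not engine.is_draft_model:
        engine._general_warmup(resource_manager, reverse=True)
```
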
@@ -612,8 +615,8 @@ def _general_warmup(self,
             self.original_max_draft_len), self.max_num_tokens,
             self.batch_size * (self.max_seq_len - 1))
         max_batch_size = min(
-            self.batch_size,
-            curr_max_num_tokens // (1 + self.runtime_draft_len))
+            self.batch_size, curr_max_num_tokens //
+            (1 + self.runtime_draft_len) // self.max_beam_width)

         warmup_requests_configs = {
             (1, 1),  # Specialize for 1 token.
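
A worked example of the new batch-size cap, using illustrative numbers that are assumptions rather than values from the patch: with beam search, each warmup request consumes `max_beam_width` beams' worth of the token budget, so the cap can shrink.

```python
# Hypothetical values for illustration only.
batch_size = 64
curr_max_num_tokens = 512
runtime_draft_len = 3   # one target token plus three draft tokens per request
max_beam_width = 4

# Old cap: min(64, 512 // 4) = 64
old_cap = min(batch_size, curr_max_num_tokens // (1 + runtime_draft_len))

# New cap: min(64, 512 // 4 // 4) = 32
new_cap = min(batch_size,
              curr_max_num_tokens // (1 + runtime_draft_len) // max_beam_width)

print(old_cap, new_cap)  # 64 32
```
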
@@ -936,8 +939,8 @@ def _create_warmup_request(

         blocks_to_use = num_full_seqs * math.ceil(
             max_seq_len / kv_cache_manager.tokens_per_block) + math.ceil(
-                num_left_over_tokens /
-                kv_cache_manager.tokens_per_block) + num_gen_requests
+                num_left_over_tokens / kv_cache_manager.tokens_per_block
+            ) + num_gen_requests * self.max_beam_width

         if blocks_to_use > available_blocks:
             return None
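
To make the updated block estimate concrete, here is a worked example with illustrative numbers (assumptions, not taken from the change): each generation request now reserves one KV cache block per beam instead of a single block.

```python
import math

# Hypothetical values for illustration only.
num_full_seqs = 2
max_seq_len = 100
tokens_per_block = 32
num_left_over_tokens = 10
num_gen_requests = 3
max_beam_width = 4

blocks_to_use = (num_full_seqs * math.ceil(max_seq_len / tokens_per_block) +
                 math.ceil(num_left_over_tokens / tokens_per_block) +
                 num_gen_requests * max_beam_width)

print(blocks_to_use)  # 2*4 + 1 + 3*4 = 21 (the old formula would give 12)
```
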
@@ -2586,16 +2589,14 @@ def previous_seq_slots_device():
         num_generation_requests = len(gen_request_seq_slots)
         # Cache indirection is only used for beam search on generation requests
         if self.use_beam_search and num_generation_requests > 0:
-            # CUDA Graph needs to set beam width during warmup (where the graph is captured), to ensure that the cache indirection buffer is correctly picked up by the CUDA graph
-            is_cuda_graph_during_warmup = self.is_warmup and attn_metadata.is_cuda_graph
             if cache_indirection_buffer is not None:
                 # Copy cache indirection to local buffer with offsets changing: seq_slots[i] -> i
                 # Convert to GPU tensor to avoid implicit sync
                 gen_request_seq_slots_tensor = torch.tensor(
                     gen_request_seq_slots, dtype=torch.long, device='cuda')
                 self.cache_indirection_attention[:num_generation_requests].copy_(
                     cache_indirection_buffer[gen_request_seq_slots_tensor])
-            if cache_indirection_buffer is not None or is_cuda_graph_during_warmup:
+            if cache_indirection_buffer is not None or self.is_warmup:
                 attn_metadata.beam_width = self.max_beam_width
             else:
                 attn_metadata.beam_width = 1
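
A minimal sketch of the beam-width decision after this hunk, as a hypothetical standalone function (in the source the logic is inline): any warmup pass, not only CUDA graph capture, now requests the full beam width, so warmup buffers match real beam-search generation.

```python
def beam_width_for_generation(cache_indirection_buffer, is_warmup: bool,
                              max_beam_width: int) -> int:
    """Beam width recorded in attention metadata when beam search is active
    and at least one generation request is present (sketch of the branch above)."""
    if cache_indirection_buffer is not None or is_warmup:
        return max_beam_width
    return 1
```
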