@@ -846,10 +846,6 @@ def _create_warmup_request(
         if num_tokens > self.max_num_tokens or num_tokens > available_tokens:
             return None
 
-        num_extra_decoding_steps = self._get_num_extra_decoding_steps()
-        if num_extra_decoding_steps > 0:
-            return None  # Disable autotuning for fused drafting loops for now.
-
         num_ctx_tokens = num_tokens - num_gen_tokens
         num_ctx_requests = 0
         ctx_requests = []
@@ -868,10 +864,16 @@ def _create_warmup_request(
         if num_ctx_requests + num_gen_tokens > self.batch_size:
             return None  # Not enough batch size to fill the request
 
+        # For fused drafting loops, each generation request needs extra blocks
+        # for the tokens that will be generated during the loop
+        num_extra_decoding_steps = self._get_num_extra_decoding_steps()
+        tokens_per_gen = 1 + num_extra_decoding_steps
+        blocks_per_gen = math.ceil(tokens_per_gen /
+                                   kv_cache_manager.tokens_per_block)
         blocks_to_use = num_full_seqs * math.ceil(
             max_seq_len / kv_cache_manager.tokens_per_block) + math.ceil(
-                num_left_over_tokens /
-                kv_cache_manager.tokens_per_block) + num_gen_tokens
+                num_left_over_tokens / kv_cache_manager.tokens_per_block
+            ) + num_gen_tokens * blocks_per_gen
 
         if blocks_to_use > available_blocks:
             return None
@@ -899,7 +901,8 @@ def _create_warmup_request(
             token_nums=[1] * num_gen_tokens,
             is_gen=True,
             max_num_draft_tokens=self.max_total_draft_tokens,
-            use_mrope=self.use_mrope)
+            use_mrope=self.use_mrope,
+            num_extra_decoding_steps=num_extra_decoding_steps)
         if spec_resource_manager is not None:
             spec_resource_manager.add_dummy_requests(request_ids=list(
                 range(num_ctx_requests, num_ctx_requests + num_gen_tokens)))
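
# --- Sketch (not part of the diff above) ---
# A standalone illustration of the KV-cache block accounting introduced in the
# second hunk. `estimate_blocks_to_use` is a hypothetical helper; the parameter
# names mirror the diff, but nothing below exists in the actual change.
import math


def estimate_blocks_to_use(num_full_seqs: int, max_seq_len: int,
                           num_left_over_tokens: int, num_gen_tokens: int,
                           num_extra_decoding_steps: int,
                           tokens_per_block: int) -> int:
    """Blocks needed by a warmup request.

    Each full context sequence needs ceil(max_seq_len / tokens_per_block)
    blocks, leftover context tokens need one partial allocation, and each
    generation request needs enough blocks for its first token plus the
    tokens produced by the fused drafting loop.
    """
    tokens_per_gen = 1 + num_extra_decoding_steps
    blocks_per_gen = math.ceil(tokens_per_gen / tokens_per_block)
    return (num_full_seqs * math.ceil(max_seq_len / tokens_per_block) +
            math.ceil(num_left_over_tokens / tokens_per_block) +
            num_gen_tokens * blocks_per_gen)


# Example: with 32-token blocks and 3 extra decoding steps, each generation
# request still fits its 1 + 3 = 4 tokens in a single block.
assert estimate_blocks_to_use(num_full_seqs=2, max_seq_len=100,
                              num_left_over_tokens=10, num_gen_tokens=4,
                              num_extra_decoding_steps=3,
                              tokens_per_block=32) == 2 * 4 + 1 + 4 * 1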