@@ -857,15 +857,16 @@ def _create_warmup_request(
             return None
 
         num_extra_decoding_steps = self._get_num_extra_decoding_steps()
-        if num_extra_decoding_steps > 0:
-            return None  # Disable autotuning for fused drafting loops for now.
 
         num_ctx_tokens = num_tokens - num_gen_tokens
         num_ctx_requests = 0
         ctx_requests = []
         gen_requests = []
 
-        max_seq_len = self.max_seq_len - 1
+        # For drafting loops, reduce max_seq_len to leave room for extra decoding steps
+        max_seq_len = self.max_seq_len - 1 - num_extra_decoding_steps
+        if max_seq_len < 1:
+            return None  # Not enough sequence length for drafting loop
         num_full_seqs = 0
         num_left_over_tokens = 0
 
@@ -896,7 +897,8 @@ def _create_warmup_request(
             token_nums=ctx_token_nums,
             is_gen=False,
             max_num_draft_tokens=self.runtime_draft_len,
-            use_mrope=self.use_mrope)
+            use_mrope=self.use_mrope,
+            num_extra_decoding_steps=num_extra_decoding_steps)
 
         if spec_resource_manager is not None:
             spec_resource_manager.add_dummy_requests(
@@ -909,7 +911,8 @@ def _create_warmup_request(
             token_nums=[1] * num_gen_tokens,
             is_gen=True,
             max_num_draft_tokens=self.max_total_draft_tokens,
-            use_mrope=self.use_mrope)
+            use_mrope=self.use_mrope,
+            num_extra_decoding_steps=num_extra_decoding_steps)
         if spec_resource_manager is not None:
             spec_resource_manager.add_dummy_requests(request_ids=list(
                 range(num_ctx_requests, num_ctx_requests + num_gen_tokens)))
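
For context, the sketch below illustrates the sizing logic this change introduces, in isolation from the real `_create_warmup_request`: the usable warmup sequence length shrinks by `num_extra_decoding_steps` so fused drafting loops have room to append their extra tokens, warmup is skipped when nothing is left, and the step count is threaded through to every dummy request. `plan_warmup` and `DummyRequest` are hypothetical stand-ins for the engine's warmup path and `kv_cache_manager.add_dummy_requests`; only the arithmetic mirrors the diff.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class DummyRequest:
    # Hypothetical stand-in for the requests built by add_dummy_requests.
    request_id: int
    token_num: int
    is_gen: bool
    num_extra_decoding_steps: int


def plan_warmup(num_tokens: int,
                num_gen_tokens: int,
                model_max_seq_len: int,
                num_extra_decoding_steps: int) -> Optional[List[DummyRequest]]:
    """Sketch of warmup request sizing when a fused drafting loop is active."""
    # Reserve one token (as before) plus room for the extra decoding steps
    # that the drafting loop will append to each sequence.
    max_seq_len = model_max_seq_len - 1 - num_extra_decoding_steps
    if max_seq_len < 1:
        # Not enough sequence length left for the drafting loop: skip warmup.
        return None

    num_ctx_tokens = num_tokens - num_gen_tokens
    requests: List[DummyRequest] = []

    # Split context tokens into full-length sequences plus one remainder.
    num_full_seqs, num_left_over_tokens = divmod(num_ctx_tokens, max_seq_len)
    ctx_token_nums = [max_seq_len] * num_full_seqs
    if num_left_over_tokens > 0:
        ctx_token_nums.append(num_left_over_tokens)

    for i, token_num in enumerate(ctx_token_nums):
        requests.append(
            DummyRequest(request_id=i,
                         token_num=token_num,
                         is_gen=False,
                         num_extra_decoding_steps=num_extra_decoding_steps))

    # One single-token generation request per in-flight generation token,
    # again carrying the extra-step count so KV-cache space is sized for it.
    for j in range(num_gen_tokens):
        requests.append(
            DummyRequest(request_id=len(ctx_token_nums) + j,
                         token_num=1,
                         is_gen=True,
                         num_extra_decoding_steps=num_extra_decoding_steps))
    return requests


if __name__ == "__main__":
    # With model_max_seq_len=16 and 3 extra decoding steps, 30 context tokens
    # split into sequences of at most 16 - 1 - 3 = 12 tokens: [12, 12, 6].
    reqs = plan_warmup(num_tokens=34, num_gen_tokens=4,
                       model_max_seq_len=16, num_extra_decoding_steps=3)
    print([(r.token_num, r.is_gen) for r in reqs])
```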