@@ -174,6 +174,16 @@ def _create_accepted_tokens_request(self, request: LlmRequest,
174174 input_tokens ) - num_accepted_tokens - 1
175175 return new_request
176176
177+ def _get_previous_draft_request (
178+ self , request : LlmRequest ) -> Optional [LlmRequest ]:
179+ """Get the previous draft request for the given request."""
180+ if self .previous_draft_batch is None :
181+ return None
182+ for req in self .previous_draft_batch .all_requests ():
183+ if req .py_request_id == request .py_request_id :
184+ return req
185+ return None
186+
177187 def _create_accepted_tokens_request_for_trtllm_attn (
178188 self , request : LlmRequest , input_tokens : Any ,
179189 num_accepted_tokens : int ) -> LlmRequest :
@@ -186,14 +196,24 @@ def _create_accepted_tokens_request_for_trtllm_attn(
186196 # because at most max_draft_len draft tokens are accepted.
187197 input_tokens .extend (
188198 0 for _ in range (self .max_draft_len - num_accepted_tokens ))
189- new_request = self ._create_draft_request (request , input_tokens )
190- new_request .state = LlmRequestState .GENERATION_IN_PROGRESS
191- new_request .py_num_accepted_draft_tokens = request .py_num_accepted_draft_tokens
192- new_request .py_is_first_draft = True
199+
200+ # Reuse the previous draft request if it exists.
201+ # This can reduce host overhead significantly.
202+ draft_request = self ._get_previous_draft_request (request )
203+ if draft_request is not None :
204+ generated_tokens = input_tokens [draft_request .py_prompt_len :]
205+ draft_request .set_generated_tokens ([generated_tokens ])
206+ else :
207+ draft_request = self ._create_draft_request (request , input_tokens )
208+
209+ draft_request .state = LlmRequestState .GENERATION_IN_PROGRESS
210+ draft_request .py_num_accepted_draft_tokens = request .py_num_accepted_draft_tokens
211+ draft_request .py_is_first_draft = True
193212 # For tree decoding, we need to store the accepted tokens indices for these requests,
194213 # which will be used to update the hidden_states_read_indices.
195- new_request .py_num_accepted_draft_tokens_indices = request .py_num_accepted_draft_tokens_indices
196- return new_request
214+ draft_request .py_num_accepted_draft_tokens_indices = request .py_num_accepted_draft_tokens_indices
215+
216+ return draft_request
197217
198218 def _create_draft_request_for_request (
199219 self , request : LlmRequest ) -> Optional [LlmRequest ]:
0 commit comments