Commit 4317859
[TRTLLM-10143][feat] Reuse previous draft requests if possible (#10263)
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent c4b36d3 commit 4317859

1 file changed (+26 -6)

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 26 additions & 6 deletions
@@ -174,6 +174,16 @@ def _create_accepted_tokens_request(self, request: LlmRequest,
                                         input_tokens) - num_accepted_tokens - 1
         return new_request
 
+    def _get_previous_draft_request(
+            self, request: LlmRequest) -> Optional[LlmRequest]:
+        """Get the previous draft request for the given request."""
+        if self.previous_draft_batch is None:
+            return None
+        for req in self.previous_draft_batch.all_requests():
+            if req.py_request_id == request.py_request_id:
+                return req
+        return None
+
     def _create_accepted_tokens_request_for_trtllm_attn(
             self, request: LlmRequest, input_tokens: Any,
             num_accepted_tokens: int) -> LlmRequest:
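
For reference, a minimal self-contained sketch of what the new _get_previous_draft_request lookup does. FakeRequest and FakeBatch below are hypothetical stand-ins for LlmRequest and the previous draft batch, modeling only the fields the helper touches:

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class FakeRequest:
    # Stand-in for LlmRequest; only the request id matters for the lookup.
    py_request_id: int


@dataclass
class FakeBatch:
    # Stand-in for the previous draft batch.
    requests: List[FakeRequest] = field(default_factory=list)

    def all_requests(self) -> List[FakeRequest]:
        return self.requests


def get_previous_draft_request(previous_draft_batch: Optional[FakeBatch],
                               request: FakeRequest) -> Optional[FakeRequest]:
    # Linear scan by request id; None when there is no previous batch
    # or no matching draft request.
    if previous_draft_batch is None:
        return None
    for req in previous_draft_batch.all_requests():
        if req.py_request_id == request.py_request_id:
            return req
    return None


previous = FakeBatch([FakeRequest(7), FakeRequest(9)])
assert get_previous_draft_request(previous, FakeRequest(9)) is previous.requests[1]
assert get_previous_draft_request(None, FakeRequest(9)) is None
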
@@ -186,14 +196,24 @@ def _create_accepted_tokens_request_for_trtllm_attn(
         # because at most max_draft_len draft tokens are accepted.
         input_tokens.extend(
             0 for _ in range(self.max_draft_len - num_accepted_tokens))
-        new_request = self._create_draft_request(request, input_tokens)
-        new_request.state = LlmRequestState.GENERATION_IN_PROGRESS
-        new_request.py_num_accepted_draft_tokens = request.py_num_accepted_draft_tokens
-        new_request.py_is_first_draft = True
+
+        # Reuse the previous draft request if it exists.
+        # This can reduce host overhead significantly.
+        draft_request = self._get_previous_draft_request(request)
+        if draft_request is not None:
+            generated_tokens = input_tokens[draft_request.py_prompt_len:]
+            draft_request.set_generated_tokens([generated_tokens])
+        else:
+            draft_request = self._create_draft_request(request, input_tokens)
+
+        draft_request.state = LlmRequestState.GENERATION_IN_PROGRESS
+        draft_request.py_num_accepted_draft_tokens = request.py_num_accepted_draft_tokens
+        draft_request.py_is_first_draft = True
         # For tree decoding, we need to store the accepted tokens indices for these requests,
         # which will be used to update the hidden_states_read_indices.
-        new_request.py_num_accepted_draft_tokens_indices = request.py_num_accepted_draft_tokens_indices
-        return new_request
+        draft_request.py_num_accepted_draft_tokens_indices = request.py_num_accepted_draft_tokens_indices
+
+        return draft_request
 
     def _create_draft_request_for_request(
             self, request: LlmRequest) -> Optional[LlmRequest]:
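
The host-overhead saving comes from the reuse branch: when a draft request from the previous drafting iteration is found, only the tokens generated past the prompt are rewritten via set_generated_tokens, instead of constructing a new LlmRequest from scratch. A rough sketch of that pattern follows; SketchDraftRequest and reuse_or_create are hypothetical stand-ins that mirror the names in the diff (py_prompt_len, set_generated_tokens) but not the actual TensorRT-LLM implementation:

from typing import List, Optional


class SketchDraftRequest:
    """Hypothetical stand-in for LlmRequest; only the fields the reuse
    path touches are modeled."""

    def __init__(self, input_tokens: List[int], prompt_len: int):
        self.py_prompt_len = prompt_len
        self.tokens = list(input_tokens)

    def set_generated_tokens(self, beams: List[List[int]]) -> None:
        # Replace the generated suffix in place; the prompt prefix and
        # the rest of the request state are kept as-is.
        self.tokens = self.tokens[:self.py_prompt_len] + beams[0]


def reuse_or_create(previous: Optional[SketchDraftRequest],
                    input_tokens: List[int],
                    prompt_len: int) -> SketchDraftRequest:
    if previous is not None:
        # Reuse path: only the tokens past the prompt change between
        # iterations, so rewriting them is much cheaper on the host
        # than allocating and initializing a new request object.
        previous.set_generated_tokens([input_tokens[previous.py_prompt_len:]])
        return previous
    # First iteration (or lookup miss): build a fresh draft request.
    return SketchDraftRequest(input_tokens, prompt_len)


first = reuse_or_create(None, [1, 2, 3, 0, 0], prompt_len=3)
second = reuse_or_create(first, [1, 2, 3, 4, 5], prompt_len=3)
assert second is first and second.tokens == [1, 2, 3, 4, 5]

Note that in the diff the per-request fields (state, py_num_accepted_draft_tokens, py_is_first_draft, and the accepted-token indices) are still refreshed on every call, so reuse changes object allocation, not the per-iteration bookkeeping.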
