Skip to content

Commit 0652b35

Browse files
committed
[None][fix] Set token IDs on request after router tokenization to avoid re-tokenization
KvCacheAwareRouter now sets prompt_token_ids (ChatCompletionRequest) or replaces prompt with token IDs (CompletionRequest) after tokenizing, so the downstream worker server skips redundant tokenization. Also adds proper ChatCompletionRequest handling via apply_chat_template.

Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
1 parent 0467fcd commit 0652b35

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

tensorrt_llm/serve/router.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -638,14 +638,17 @@ def _tokenize(self, request: OpenAIRequest) -> list[list[int]]:
638638
if isinstance(request, ChatCompletionRequest):
639639
if request.prompt_token_ids is not None:
640640
return [request.prompt_token_ids]
641-
# TODO: send tokenize-only request instead of tokenizing locally
642641
tokenizer = self._get_tokenizer(request.model)
643-
messages = [{"role": m["role"], "content": m.get("content", "")}
644-
for m in request.messages
645-
if "role" in m]
646-
text = tokenizer.apply_chat_template(
647-
messages, add_generation_prompt=True, tokenize=False)
648-
token_ids = tokenizer.encode(text, add_special_tokens=False)
642+
token_ids = tokenizer.apply_chat_template(
643+
[
644+
msg if isinstance(msg, dict) else dict(msg)
645+
for msg in request.messages
646+
],
647+
add_generation_prompt=request.add_generation_prompt,
648+
tokenize=True,
649+
)
650+
# Set prompt_token_ids so the worker server skips re-tokenization
651+
request.prompt_token_ids = token_ids
649652
return [token_ids]
650653

651654
# Handle CompletionRequest (has prompt)
@@ -659,9 +662,12 @@ def _tokenize(self, request: OpenAIRequest) -> list[list[int]]:
659662
else:
660663
assert isinstance(prompts, list) and isinstance(prompts[0], str)
661664

662-
# TODO: send tokenize-only request instead of tokenizing locally
663665
tokenizer = self._get_tokenizer(request.model)
664-
return [tokenizer(prompt)["input_ids"] for prompt in prompts]
666+
token_lists = [tokenizer(prompt)["input_ids"] for prompt in prompts]
667+
# Replace string prompts with token IDs so the worker server
668+
# skips re-tokenization
669+
request.prompt = token_lists if len(token_lists) > 1 else token_lists[0]
670+
return token_lists
665671

666672
async def get_next_server(
667673
self,

0 commit comments

Comments
 (0)