
Commit 05d839c (parent 6597d7a)

Fix(async): Add support for truncate_prompt_tokens in AsyncLLM (vllm-project#23800)

1 file changed: 20 additions, 0 deletions

vllm/v1/engine/async_llm.py

@@ -15,6 +15,7 @@
 from vllm.config import ModelConfig, VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
 from vllm.inputs import PromptType
 from vllm.inputs.preprocess import InputPreprocessor
@@ -348,13 +349,23 @@ async def generate(
         # to handle startup failure gracefully in the OpenAI server.
         self._run_output_handler()

+        tokenization_kwargs: dict[str, Any] = {}
+        truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
+
+        _validate_truncation_size(
+            self.model_config.max_model_len,
+            truncate_prompt_tokens,
+            tokenization_kwargs,
+        )
+
         q = await self.add_request(
             request_id,
             prompt,
             sampling_params,
             lora_request=lora_request,
             trace_headers=trace_headers,
             priority=priority,
+            tokenization_kwargs=tokenization_kwargs,
             data_parallel_rank=data_parallel_rank,
         )

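The hunk above makes generate() read truncate_prompt_tokens off the request's SamplingParams, validate it against max_model_len, and forward it to the tokenizer through tokenization_kwargs. A minimal usage sketch of what this enables follows; the model name, prompt, and request id are illustrative assumptions, not taken from the commit:

# Hedged sketch: truncate_prompt_tokens on SamplingParams is now honored
# by AsyncLLM.generate. Model name and prompt are assumptions.
import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM


async def main() -> None:
    engine = AsyncLLM.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))  # assumed model
    # Keep only the last 128 prompt tokens; generate() now validates this
    # against max_model_len and passes it on via tokenization_kwargs.
    params = SamplingParams(max_tokens=32, truncate_prompt_tokens=128)
    async for out in engine.generate("a very long prompt ...", params, "req-0"):
        if out.finished:
            print(out.outputs[0].text)


asyncio.run(main())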
@@ -481,6 +492,7 @@ async def encode(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
+        truncate_prompt_tokens: Optional[int] = None,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """
@@ -503,6 +515,14 @@ async def encode(
         # to handle startup failure gracefully in the OpenAI server.
         self._run_output_handler()

+        if tokenization_kwargs is None:
+            tokenization_kwargs = dict[str, Any]()
+        _validate_truncation_size(
+            self.model_config.max_model_len,
+            truncate_prompt_tokens,
+            tokenization_kwargs,
+        )
+
         q = await self.add_request(
             request_id,
             prompt,

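The encode() path differs from generate(): there is no SamplingParams to carry the setting, so the commit adds truncate_prompt_tokens as an explicit keyword argument, validated and folded into tokenization_kwargs before add_request. A hedged sketch of calling it; the embedding model, task flag, and output handling are assumptions for illustration:

# Hedged sketch of the encode path with the keyword this commit adds.
import asyncio

from vllm import PoolingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM


async def main() -> None:
    # Model and task="embed" are assumed; any pooling model would do.
    engine = AsyncLLM.from_engine_args(
        AsyncEngineArgs(model="BAAI/bge-small-en-v1.5", task="embed"))
    async for out in engine.encode(
            "some long document ...",
            PoolingParams(),
            "embed-0",
            truncate_prompt_tokens=256,  # keyword added by this commit
    ):
        if out.finished:
            print(out.outputs.data.shape)  # assumed PoolingOutput shape check


asyncio.run(main())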