 from vllm.config import ModelConfig, VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
 from vllm.inputs import PromptType
 from vllm.inputs.preprocess import InputPreprocessor
@@ -348,13 +349,23 @@ async def generate(
         # to handle startup failure gracefully in the OpenAI server.
         self._run_output_handler()

+        tokenization_kwargs: dict[str, Any] = {}
+        truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
+
+        _validate_truncation_size(
+            self.model_config.max_model_len,
+            truncate_prompt_tokens,
+            tokenization_kwargs,
+        )
+
         q = await self.add_request(
             request_id,
             prompt,
             sampling_params,
             lora_request=lora_request,
             trace_headers=trace_headers,
             priority=priority,
+            tokenization_kwargs=tokenization_kwargs,
             data_parallel_rank=data_parallel_rank,
         )

@@ -481,6 +492,7 @@ async def encode(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
+        truncate_prompt_tokens: Optional[int] = None,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """
@@ -503,6 +515,14 @@ async def encode(
         # to handle startup failure gracefully in the OpenAI server.
         self._run_output_handler()

+        if tokenization_kwargs is None:
+            tokenization_kwargs = dict[str, Any]()
+        _validate_truncation_size(
+            self.model_config.max_model_len,
+            truncate_prompt_tokens,
+            tokenization_kwargs,
+        )
+
         q = await self.add_request(
             request_id,
             prompt,
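
For context, here is a minimal standalone sketch of how the truncation plumbing added above is used. The helper `validate_truncation_size` below is a hypothetical stand-in mirroring the assumed contract of `vllm.entrypoints.utils._validate_truncation_size` (its real behavior is not shown in this diff): validate the requested truncation length against the model context window and record it in the `tokenization_kwargs` dict that is forwarded to `add_request`.

```python
from typing import Any, Optional


def validate_truncation_size(
    max_model_len: int,
    truncate_prompt_tokens: Optional[int],
    tokenization_kwargs: dict[str, Any],
) -> Optional[int]:
    # Hypothetical stand-in for _validate_truncation_size (assumed behavior).
    # No truncation requested: leave the kwargs untouched.
    if truncate_prompt_tokens is None:
        return None
    # A negative value is assumed to mean "truncate to the full context".
    if truncate_prompt_tokens < 0:
        truncate_prompt_tokens = max_model_len
    if truncate_prompt_tokens > max_model_len:
        raise ValueError(
            f"truncate_prompt_tokens ({truncate_prompt_tokens}) exceeds "
            f"max_model_len ({max_model_len})")
    tokenization_kwargs["truncation"] = True
    tokenization_kwargs["max_length"] = truncate_prompt_tokens
    return truncate_prompt_tokens


# Mirrors the call sites added in generate()/encode() above.
tokenization_kwargs: dict[str, Any] = {}
validate_truncation_size(4096, 512, tokenization_kwargs)
assert tokenization_kwargs == {"truncation": True, "max_length": 512}
```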