Commit 93b88bb

Receiling authored and Pradyun Ramadorai committed
Add tokenization_kwargs to encode for embedding model truncation (vllm-project#21033)
1 parent 7989183 commit 93b88bb

File tree

3 files changed: +20 −3 lines

vllm/engine/async_llm_engine.py
vllm/entrypoints/llm.py
vllm/v1/engine/async_llm.py
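
In practice, the new parameter lets a caller hand tokenizer options straight to LLM.encode, e.g. to truncate over-long inputs before embedding. A minimal usage sketch — the model name is illustrative, and the "truncation"/"max_length" keys are assumed to follow the usual Hugging Face tokenizer kwargs, which is how this dict is presumably consumed:

from vllm import LLM

# Usage sketch, not taken from this commit: pass tokenizer options
# through encode() so over-long inputs are truncated before pooling.
llm = LLM(model="intfloat/e5-small-v2", task="embed")  # model is illustrative

outputs = llm.encode(
    ["some very long document ..."],
    # Forwarded to the tokenizer; "truncation"/"max_length" are standard
    # HF tokenizer kwargs, used here as an assumed example.
    tokenization_kwargs={"truncation": True, "max_length": 512},
)
print(len(outputs[0].outputs.data))  # embedding vector for the prompt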

vllm/engine/async_llm_engine.py

Lines changed: 6 additions & 0 deletions
@@ -438,6 +438,7 @@ async def add_request_async(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
         data_parallel_rank: Optional[int] = None,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> None:
         """
         Async version of
@@ -468,6 +469,7 @@ async def add_request_async(
             prompt,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
+            tokenization_kwargs=tokenization_kwargs,
         )

         if isinstance(params, SamplingParams) and \
@@ -862,6 +864,7 @@ async def add_request(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
         data_parallel_rank: Optional[int] = None,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
         if not self.is_running:
             if self.start_engine_loop:
@@ -889,6 +892,7 @@ async def add_request(
             prompt_adapter_request=prompt_adapter_request,
             priority=priority,
             data_parallel_rank=data_parallel_rank,
+            tokenization_kwargs=tokenization_kwargs,
         )

         return stream.generator()
@@ -996,6 +1000,7 @@ async def encode(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """Generate outputs for a request from a pooling model.

@@ -1070,6 +1075,7 @@ async def encode(
             lora_request=lora_request,
             trace_headers=trace_headers,
             priority=priority,
+            tokenization_kwargs=tokenization_kwargs,
         ):
             yield LLMEngine.validate_output(output, PoolingRequestOutput)
     except asyncio.CancelledError:
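
On the async path the options simply ride along: encode() hands them to add_request(), which forwards them when the prompt is preprocessed. A rough sketch of a server-side caller — engine construction is elided, and the call signature is read off this diff, so treat it as an assumption:

from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.pooling_params import PoolingParams

async def embed_one(engine: AsyncLLMEngine, text: str):
    final = None
    # tokenization_kwargs flows encode() -> add_request() -> tokenizer;
    # the truncation keys are assumed HF tokenizer kwargs, as above.
    async for out in engine.encode(
        text,
        PoolingParams(),
        request_id="embed-0",
        tokenization_kwargs={"truncation": True, "max_length": 512},
    ):
        final = out  # the last yielded output carries the finished result
    return final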

vllm/entrypoints/llm.py

Lines changed: 12 additions & 3 deletions
@@ -965,6 +965,7 @@ def encode(
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         pooling_task: PoolingTask = "encode",
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         ...

@@ -981,6 +982,7 @@ def encode(
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         pooling_task: PoolingTask = "encode",
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         ...

@@ -997,6 +999,7 @@ def encode(
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         pooling_task: PoolingTask = "encode",
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         ...

@@ -1014,6 +1017,7 @@ def encode(
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         pooling_task: PoolingTask = "encode",
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         ...

@@ -1031,6 +1035,7 @@ def encode(
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         pooling_task: PoolingTask = "encode",
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         ...

@@ -1046,6 +1051,7 @@ def encode(
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         pooling_task: PoolingTask = "encode",
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         ...

@@ -1066,6 +1072,7 @@ def encode(
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         pooling_task: PoolingTask = "encode",
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         """Apply pooling to the hidden states corresponding to the input
         prompts.
@@ -1131,9 +1138,11 @@ def encode(
         for pooling_param in pooling_params:
             pooling_param.verify(pooling_task, model_config)

-        tokenization_kwargs = dict[str, Any]()
-        _validate_truncation_size(model_config.max_model_len,
-                                  truncate_prompt_tokens, tokenization_kwargs)
+        if tokenization_kwargs is None:
+            tokenization_kwargs = dict[str, Any]()
+            _validate_truncation_size(model_config.max_model_len,
+                                      truncate_prompt_tokens,
+                                      tokenization_kwargs)

         self._validate_and_add_requests(
             prompts=parsed_prompts,
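
The only behavioral change in this file is the last hunk: the default truncation dict is now built only when the caller did not supply one, so an explicit tokenization_kwargs takes precedence over truncate_prompt_tokens. A standalone sketch of that precedence rule — the helper name and the "-1 means truncate to max_model_len" convention are illustrative restatements, not vLLM internals:

from typing import Any, Optional

def resolve_tokenization_kwargs(
    tokenization_kwargs: Optional[dict[str, Any]],
    truncate_prompt_tokens: Optional[int],
    max_model_len: int,
) -> dict[str, Any]:
    # Mirrors the new control flow: a caller-supplied dict is used as-is
    # and the derived truncation settings are skipped entirely.
    if tokenization_kwargs is not None:
        return tokenization_kwargs
    kwargs: dict[str, Any] = {}
    if truncate_prompt_tokens is not None:
        # Assumed convention: -1 means "truncate to the model's max length".
        limit = (max_model_len if truncate_prompt_tokens == -1
                 else truncate_prompt_tokens)
        kwargs["truncation"] = True
        kwargs["max_length"] = limit
    return kwargs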

vllm/v1/engine/async_llm.py

Lines changed: 2 additions & 0 deletions
@@ -437,6 +437,7 @@ async def encode(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """
         Main function called by the API server to kick off a request
@@ -465,6 +466,7 @@ async def encode(
             lora_request=lora_request,
             trace_headers=trace_headers,
             priority=priority,
+            tokenization_kwargs=tokenization_kwargs,
         )

         # The output_handler task pushes items into the queue.
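
Since this method is the entry point the API server calls, a serving layer would presumably translate an API-level truncation setting into this dict before calling encode(). A small illustrative mapping — the helper and its key names are assumptions, not part of this commit:

from typing import Any, Optional

def truncation_to_tokenization_kwargs(
    truncate_prompt_tokens: Optional[int],
) -> Optional[dict[str, Any]]:
    # Illustrative only: map an API-level truncation limit onto
    # HF-tokenizer-style kwargs for AsyncLLM.encode().
    if truncate_prompt_tokens is None:
        return None  # let encode() fall back to its defaults
    return {"truncation": True, "max_length": truncate_prompt_tokens}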
