
Commit 751ef2e

compatibility with vLLM v1 engine

tjohnson31415 authored and dtrifiro committed

Signed-off-by: Travis Johnson <[email protected]>
1 parent 8c9c80c commit 751ef2e

File tree

1 file changed: +38 -17 lines changed


src/vllm_tgis_adapter/grpc/grpc_server.py

Lines changed: 38 additions & 17 deletions
@@ -6,6 +6,7 @@
 import time
 import uuid
 from collections.abc import Callable, Coroutine
+from functools import lru_cache
 from typing import TYPE_CHECKING, Any, TypeVar
 
 import grpc
@@ -16,7 +17,7 @@
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.multiprocessing import MQEngineDeadError
 from vllm.entrypoints.openai.serving_completion import merge_async_iterators
-from vllm.inputs import token_inputs
+from vllm.inputs import TokensPrompt, token_inputs
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.tracing import (
     contains_trace_headers,
@@ -96,6 +97,12 @@ def with_default(value: _T, default: _T) -> _T:
     return value if value else default
 
 
+@lru_cache
+def _has_argument(generate_func: _F, arg_name: str) -> bool:
+    signature = inspect.signature(generate_func)
+    return arg_name in signature.parameters
+
+
 async def _handle_exception(
     e: Exception,
     func: Callable,
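
The new _has_argument helper lets the adapter detect at runtime which keyword the installed engine's generate method accepts, so a single adapter build can work against both engine generations. A minimal standalone sketch of the same signature-introspection pattern (the generate_v0/generate_v1 functions below are hypothetical stand-ins, not vLLM code):

    import inspect
    from functools import lru_cache

    @lru_cache
    def has_argument(func, arg_name: str) -> bool:
        # Signature inspection is not free, so cache the result per (func, arg_name).
        return arg_name in inspect.signature(func).parameters

    def generate_v0(inputs, sampling_params, request_id): ...   # legacy-style signature
    def generate_v1(prompt, sampling_params, request_id): ...   # V1-style signature

    print(has_argument(generate_v0, "inputs"))  # True  -> take the token_inputs() path
    print(has_argument(generate_v1, "inputs"))  # False -> take the TokensPrompt path
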
@@ -213,6 +220,28 @@ async def post_init(self) -> None:
             health_pb2.HealthCheckResponse.SERVING,
         )
 
+    def _make_generator(
+        self,
+        prompt: str,
+        prompt_token_ids: list[int],
+        **kwargs: dict[str, Any],
+    ) -> _F:
+        # V1 removed inputs in favor of prompt
+        if _has_argument(self.engine.generate, "inputs"):
+            prompt_kwarg = {
+                "inputs": token_inputs(
+                    prompt=prompt,
+                    prompt_token_ids=prompt_token_ids,
+                )
+            }
+        else:
+            prompt_kwarg = {"prompt": TokensPrompt(prompt_token_ids=prompt_token_ids)}
+
+        return self.engine.generate(
+            **prompt_kwarg,
+            **kwargs,
+        )
+
     @log_rpc_handler_errors
     async def Generate(
         self,
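
From the call sites below, _make_generator hides the API difference: engines whose generate still takes inputs receive a token_inputs(...) dict carrying both the text and the token IDs, while V1 engines receive a TokensPrompt. A rough sketch of the two keyword sets it ends up forwarding (values are illustrative, not taken from this change):

    from vllm.inputs import TokensPrompt, token_inputs

    # Pre-V1 engines: engine.generate(inputs=..., ...)
    legacy_kwargs = {
        "inputs": token_inputs(prompt="hello world", prompt_token_ids=[15496, 995]),
    }

    # V1 engines: engine.generate(prompt=..., ...); only the token IDs are passed through
    v1_kwargs = {
        "prompt": TokensPrompt(prompt_token_ids=[15496, 995]),
    }
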
@@ -245,10 +274,6 @@ async def Generate(
             )
             request_id_i = f"{request_id}-{i}"
 
-            inputs = token_inputs(
-                prompt=req.text,
-                prompt_token_ids=input_ids,
-            )
             is_tracing_enabled = await self.engine.is_tracing_enabled()
             headers = dict(context.invocation_metadata())
             logs.set_correlation_id(request_id_i, headers.get(CORRELATION_ID_HEADER))
@@ -257,12 +282,13 @@ async def Generate(
             elif contains_trace_headers(headers):
                 log_tracing_disabled_warning()
             generators.append(
-                self.engine.generate(
-                    inputs=inputs,
+                self._make_generator(
+                    prompt=req.text,
+                    prompt_token_ids=input_ids,
                     sampling_params=sampling_params,
                     request_id=request_id_i,
                     **kwargs,
-                ),
+                )
             )
 
         result_generator: AsyncIterator[tuple[int, RequestOutput]] = (
@@ -337,10 +363,6 @@ async def GenerateStream( # noqa: PLR0915, C901
             context,
         )
 
-        inputs = token_inputs(
-            prompt=request.request.text,
-            prompt_token_ids=input_ids,
-        )
         kwargs = {}
         is_tracing_enabled = await self.engine.is_tracing_enabled()
         headers = dict(context.invocation_metadata())
@@ -351,10 +373,9 @@ async def GenerateStream( # noqa: PLR0915, C901
         if CORRELATION_ID_HEADER in headers:
             logs.set_correlation_id(request_id, headers.get(CORRELATION_ID_HEADER))
 
-        result_generator = self.engine.generate(
-            # prompt is supplied for observability, the text is not
-            # re-tokenized when `prompt_token_ids` is supplied
-            inputs=inputs,
+        result_generator = self._make_generator(
+            prompt=request.request.text,
+            prompt_token_ids=input_ids,
             sampling_params=sampling_params,
             request_id=request_id,
             **adapter_kwargs,
@@ -619,7 +640,7 @@ async def _validate_and_convert_params(
             min_tokens=min_new_tokens,
             repetition_penalty=with_default(decoding.repetition_penalty, 1.0),
             logits_processors=logits_processors,
-            stop=with_default(stopping.stop_sequences, None),
+            stop=with_default(list(stopping.stop_sequences), None),
             include_stop_str_in_output=stopping.include_stop_sequence
             if stopping.HasField("include_stop_sequence")
             else self.default_include_stop_seqs,
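
The added list(...) around stopping.stop_sequences converts a protobuf repeated field, which is a sequence-like container object rather than a plain list, into a real list before it reaches SamplingParams; stricter type handling in newer vLLM code may not accept the raw container. A hypothetical stand-in class (not the real protobuf type) showing the distinction:

    from collections.abc import Sequence

    class RepeatedField(Sequence):
        """Stand-in for protobuf's RepeatedScalarContainer: sequence-like, but not a list."""
        def __init__(self, items):
            self._items = list(items)
        def __getitem__(self, index):
            return self._items[index]
        def __len__(self):
            return len(self._items)

    stop = RepeatedField(["</s>", "\n\n"])
    print(isinstance(stop, list))        # False -> strict list checks would reject it
    print(isinstance(list(stop), list))  # True  -> list(...) yields a plain list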
