 import time
 import uuid
 from collections.abc import Callable, Coroutine
+from functools import lru_cache
 from typing import TYPE_CHECKING, Any, TypeVar
 
 import grpc
@@ -16,7 +17,7 @@
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.multiprocessing import MQEngineDeadError
 from vllm.entrypoints.openai.serving_completion import merge_async_iterators
-from vllm.inputs import token_inputs
+from vllm.inputs import TokensPrompt, token_inputs
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.tracing import (
     contains_trace_headers,
@@ -96,6 +97,12 @@ def with_default(value: _T, default: _T) -> _T:
     return value if value else default
 
 
+@lru_cache
+def _has_argument(generate_func: _F, arg_name: str) -> bool:
+    signature = inspect.signature(generate_func)
+    return arg_name in signature.parameters
+
+
 async def _handle_exception(
     e: Exception,
     func: Callable,
@@ -213,6 +220,28 @@ async def post_init(self) -> None:
             health_pb2.HealthCheckResponse.SERVING,
         )
 
+    def _make_generator(
+        self,
+        prompt: str,
+        prompt_token_ids: list[int],
+        **kwargs: dict[str, Any],
+    ) -> _F:
+        # V1 removed inputs in favor of prompt
+        if _has_argument(self.engine.generate, "inputs"):
+            prompt_kwarg = {
+                "inputs": token_inputs(
+                    prompt=prompt,
+                    prompt_token_ids=prompt_token_ids,
+                )
+            }
+        else:
+            prompt_kwarg = {"prompt": TokensPrompt(prompt_token_ids=prompt_token_ids)}
+
+        return self.engine.generate(
+            **prompt_kwarg,
+            **kwargs,
+        )
+
     @log_rpc_handler_errors
     async def Generate(
         self,
@@ -245,10 +274,6 @@ async def Generate(
             )
             request_id_i = f"{request_id}-{i}"
 
-            inputs = token_inputs(
-                prompt=req.text,
-                prompt_token_ids=input_ids,
-            )
             is_tracing_enabled = await self.engine.is_tracing_enabled()
             headers = dict(context.invocation_metadata())
             logs.set_correlation_id(request_id_i, headers.get(CORRELATION_ID_HEADER))
@@ -257,12 +282,13 @@ async def Generate(
             elif contains_trace_headers(headers):
                 log_tracing_disabled_warning()
             generators.append(
-                self.engine.generate(
-                    inputs=inputs,
+                self._make_generator(
+                    prompt=req.text,
+                    prompt_token_ids=input_ids,
                     sampling_params=sampling_params,
                     request_id=request_id_i,
                     **kwargs,
-                ),
+                )
             )
 
         result_generator: AsyncIterator[tuple[int, RequestOutput]] = (
@@ -337,10 +363,6 @@ async def GenerateStream( # noqa: PLR0915, C901
             context,
         )
 
-        inputs = token_inputs(
-            prompt=request.request.text,
-            prompt_token_ids=input_ids,
-        )
        kwargs = {}
         is_tracing_enabled = await self.engine.is_tracing_enabled()
         headers = dict(context.invocation_metadata())
@@ -351,10 +373,9 @@ async def GenerateStream( # noqa: PLR0915, C901
         if CORRELATION_ID_HEADER in headers:
             logs.set_correlation_id(request_id, headers.get(CORRELATION_ID_HEADER))
 
-        result_generator = self.engine.generate(
-            # prompt is supplied for observability, the text is not
-            # re-tokenized when `prompt_token_ids` is supplied
-            inputs=inputs,
+        result_generator = self._make_generator(
+            prompt=request.request.text,
+            prompt_token_ids=input_ids,
             sampling_params=sampling_params,
             request_id=request_id,
             **adapter_kwargs,
@@ -619,7 +640,7 @@ async def _validate_and_convert_params(
             min_tokens=min_new_tokens,
             repetition_penalty=with_default(decoding.repetition_penalty, 1.0),
             logits_processors=logits_processors,
-            stop=with_default(stopping.stop_sequences, None),
+            stop=with_default(list(stopping.stop_sequences), None),
             include_stop_str_in_output=stopping.include_stop_sequence
             if stopping.HasField("include_stop_sequence")
             else self.default_include_stop_seqs,