4
4
from typing import (Any , AsyncGenerator , Callable , Dict , Iterable , List ,
5
5
Mapping , Optional , Set , Tuple , Type , Union )
6
6
7
- from typing_extensions import assert_never
8
-
9
7
import vllm .envs as envs
10
8
from vllm .config import (DecodingConfig , EngineConfig , LoRAConfig , ModelConfig ,
11
9
ParallelConfig , SchedulerConfig )
12
10
from vllm .core .scheduler import SchedulerOutputs
13
11
from vllm .engine .arg_utils import AsyncEngineArgs
14
12
from vllm .engine .async_timeout import asyncio_timeout
15
- from vllm .engine .llm_engine import (DecoderPromptComponents , LLMEngine ,
16
- PromptComponents , SchedulerOutputState )
13
+ from vllm .engine .llm_engine import LLMEngine , SchedulerOutputState
17
14
from vllm .engine .metrics_types import StatLoggerBase
18
15
from vllm .executor .executor_base import ExecutorAsyncBase
19
16
from vllm .executor .ray_utils import initialize_ray_cluster
20
- from vllm .inputs import (EncoderDecoderLLMInputs , LLMInputs , PromptInputs ,
21
- SingletonPromptInputs )
22
- from vllm .inputs .parse import is_explicit_encoder_decoder_prompt
17
+ from vllm .inputs import PromptInputs
23
18
from vllm .logger import init_logger
24
19
from vllm .lora .request import LoRARequest
25
20
from vllm .model_executor .layers .sampler import SamplerOutput
@@ -403,139 +398,6 @@ async def stop_remote_worker_execution_loop_async(self) -> None:
403
398
"""Stop the remote worker execution loop."""
404
399
await self .model_executor .stop_remote_worker_execution_loop_async ()
405
400
406
- async def _tokenize_prompt_async (
407
- self ,
408
- prompt : str ,
409
- request_id : str ,
410
- lora_request : Optional [LoRARequest ],
411
- ) -> List [int ]:
412
- """Async version of :meth:`_tokenize_prompt`."""
413
- tokenizer = self .get_tokenizer_group (
414
- missing_msg = "prompts must be None if skip_tokenizer_init is True" )
415
-
416
- return await tokenizer .encode_async (request_id = request_id ,
417
- prompt = prompt ,
418
- lora_request = lora_request )
419
-
420
- async def _extract_prompt_components_async (
421
- self ,
422
- inputs : SingletonPromptInputs ,
423
- request_id : str ,
424
- lora_request : Optional [LoRARequest ] = None ,
425
- ) -> PromptComponents :
426
- """Async version of :meth:`_extract_prompt_components`."""
427
- if isinstance (inputs , str ):
428
- prompt = inputs
429
- prompt_token_ids = await self ._tokenize_prompt_async (
430
- prompt ,
431
- request_id = request_id ,
432
- lora_request = lora_request ,
433
- )
434
- multi_modal_data = None
435
- elif isinstance (inputs , dict ):
436
- if "prompt_token_ids" in inputs :
437
- prompt = None
438
- prompt_token_ids = inputs ["prompt_token_ids" ]
439
- else :
440
- # NOTE: This extra assignment is required to pass mypy
441
- prompt = parsed_prompt = inputs ["prompt" ]
442
- prompt_token_ids = await self ._tokenize_prompt_async (
443
- parsed_prompt ,
444
- request_id = request_id ,
445
- lora_request = lora_request ,
446
- )
447
-
448
- multi_modal_data = inputs .get ("multi_modal_data" )
449
- else :
450
- assert_never (inputs )
451
-
452
- return prompt , prompt_token_ids , multi_modal_data
453
-
454
- async def _process_encoder_decoder_prompt_async (
455
- self ,
456
- inputs : PromptInputs ,
457
- request_id : str ,
458
- ) -> EncoderDecoderLLMInputs :
459
- """Async version of :meth:`_process_encoder_decoder_prompt`."""
460
- encoder_comps : PromptComponents
461
- decoder_comps : DecoderPromptComponents
462
-
463
- if is_explicit_encoder_decoder_prompt (inputs ):
464
- encoder_task = self ._extract_prompt_components_async (
465
- inputs ["encoder_prompt" ],
466
- request_id = request_id ,
467
- )
468
-
469
- if (decoder_input := inputs ["decoder_prompt" ]) is None :
470
- encoder_comps = await encoder_task
471
- decoder_comps = None , None , None
472
- else :
473
- decoder_task = self ._extract_prompt_components_async (
474
- decoder_input ,
475
- request_id = request_id ,
476
- )
477
-
478
- encoder_comps , decoder_comps = await asyncio .gather (
479
- encoder_task , decoder_task )
480
- else :
481
- encoder_comps = await self ._extract_prompt_components_async (
482
- inputs ,
483
- request_id = request_id ,
484
- )
485
-
486
- decoder_comps = None , None , None
487
-
488
- return self ._build_enc_dec_llm_inputs (encoder_comps , decoder_comps )
489
-
490
- async def _process_decoder_only_prompt_async (
491
- self ,
492
- inputs : SingletonPromptInputs ,
493
- request_id : str ,
494
- lora_request : Optional [LoRARequest ] = None ,
495
- prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
496
- ) -> LLMInputs :
497
- """Async version of :meth:`_process_decoder_only_prompt`."""
498
- prompt_comps = await self ._extract_prompt_components_async (
499
- inputs ,
500
- request_id = request_id ,
501
- lora_request = lora_request ,
502
- )
503
-
504
- return self ._build_decoder_only_llm_inputs (
505
- prompt_comps ,
506
- prompt_adapter_request = prompt_adapter_request ,
507
- )
508
-
509
- async def process_model_inputs_async (
510
- self ,
511
- inputs : PromptInputs ,
512
- request_id : str ,
513
- lora_request : Optional [LoRARequest ] = None ,
514
- prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
515
- ) -> Union [LLMInputs , EncoderDecoderLLMInputs ]:
516
- """Async version of :meth:`process_model_inputs`."""
517
- if self .is_encoder_decoder_model ():
518
- # Encoder-decoder model requires special mapping of
519
- # input prompts to encoder & decoder
520
- model_inputs = await self ._process_encoder_decoder_prompt_async (
521
- inputs ,
522
- request_id = request_id ,
523
- )
524
- else :
525
- if is_explicit_encoder_decoder_prompt (inputs ):
526
- raise ValueError ("Cannot pass encoder-decoder prompt "
527
- "to decoder-only models" )
528
-
529
- # Decoder-only operation
530
- model_inputs = await self ._process_decoder_only_prompt_async (
531
- inputs ,
532
- request_id = request_id ,
533
- lora_request = lora_request ,
534
- prompt_adapter_request = prompt_adapter_request ,
535
- )
536
-
537
- return self .input_processor (model_inputs )
538
-
539
401
async def add_request_async (
540
402
self ,
541
403
request_id : str ,
@@ -553,12 +415,13 @@ async def add_request_async(
553
415
if arrival_time is None :
554
416
arrival_time = time .time ()
555
417
556
- processed_inputs = await self .process_model_inputs_async (
418
+ preprocessed_inputs = await self .input_preprocessor . preprocess_async (
557
419
inputs ,
558
420
request_id = request_id ,
559
421
lora_request = lora_request ,
560
422
prompt_adapter_request = prompt_adapter_request ,
561
423
)
424
+ processed_inputs = self .input_processor (preprocessed_inputs )
562
425
563
426
self ._add_processed_request (
564
427
request_id = request_id ,
0 commit comments