16
16
from vllm .transformers_utils .tokenizer import AnyTokenizer
17
17
18
18
from .data import (DecoderOnlyInputs , EmbedsInputs , EmbedsPrompt ,
19
- EncoderDecoderInputs , ProcessorInputs , PromptType ,
20
- SingletonInputs , SingletonPrompt , TextPrompt , TokenInputs ,
21
- TokensPrompt , embeds_inputs , token_inputs )
19
+ EncoderDecoderInputs , ExplicitEncoderDecoderPrompt ,
20
+ ProcessorInputs , PromptType , SingletonInputs ,
21
+ SingletonPrompt , TextPrompt , TokenInputs , TokensPrompt ,
22
+ embeds_inputs , token_inputs )
22
23
from .parse import is_explicit_encoder_decoder_prompt , parse_singleton_prompt
23
24
24
25
logger = init_logger (__name__ )
@@ -322,7 +323,7 @@ def _process_tokens(
322
323
mm_uuids = mm_uuids ,
323
324
)
324
325
else :
325
- inputs = token_inputs (prompt_token_ids = prompt_token_ids )
326
+ inputs = token_inputs (prompt_token_ids )
326
327
327
328
if cache_salt := parsed_content .get ("cache_salt" ):
328
329
inputs ["cache_salt" ] = cache_salt
@@ -352,10 +353,7 @@ def _process_text(
352
353
prompt_text ,
353
354
tokenization_kwargs = tokenization_kwargs ,
354
355
)
355
- inputs = token_inputs (
356
- prompt = prompt_text ,
357
- prompt_token_ids = prompt_token_ids ,
358
- )
356
+ inputs = token_inputs (prompt_token_ids )
359
357
360
358
if cache_salt := parsed_content .get ("cache_salt" ):
361
359
inputs ["cache_salt" ] = cache_salt
@@ -473,22 +471,17 @@ def _split_enc_dec_mm_inputs(
473
471
decoder_inputs : SingletonInputs
474
472
475
473
if inputs ["type" ] == "multimodal" : # Multimodal data inputs
476
- if not ("encoder_prompt" in inputs
477
- and "encoder_prompt_token_ids" in inputs ):
474
+ if "encoder_prompt_token_ids" not in inputs :
478
475
raise RuntimeError ("You should register an encoder-decoder "
479
476
"multi-modal processor for encoder-decoder "
480
477
"models." )
481
478
inputs = cast (MultiModalEncDecInputs , inputs )
482
479
483
- encoder_inputs = token_inputs (
484
- prompt = inputs ["encoder_prompt" ],
485
- prompt_token_ids = inputs ["encoder_prompt_token_ids" ],
486
- )
480
+ encoder_inputs = token_inputs (inputs ["encoder_prompt_token_ids" ])
487
481
488
482
decoder_prompt_inputs = decoder_inputs_to_override or inputs
489
483
decoder_inputs = MultiModalInputs (
490
484
type = "multimodal" ,
491
- prompt = decoder_prompt_inputs .get ("prompt" , "" ),
492
485
prompt_token_ids = decoder_prompt_inputs ["prompt_token_ids" ],
493
486
mm_kwargs = inputs ["mm_kwargs" ],
494
487
mm_hashes = inputs ["mm_hashes" ],
@@ -498,7 +491,7 @@ def _split_enc_dec_mm_inputs(
498
491
decoder_inputs ["cache_salt" ] = cache_salt
499
492
500
493
elif inputs ["type" ] == "token" : # Text-only inputs
501
- encoder_inputs = token_inputs (prompt = "" , prompt_token_ids = [])
494
+ encoder_inputs = token_inputs (prompt_token_ids = [])
502
495
decoder_inputs = decoder_inputs_to_override or inputs
503
496
else :
504
497
assert_never (inputs ) # type: ignore[arg-type]
@@ -549,12 +542,14 @@ def _process_encoder_decoder_prompt(
549
542
decoder_inputs : Optional [SingletonInputs ]
550
543
551
544
if is_explicit_encoder_decoder_prompt (prompt ):
545
+ # `cast` is needed for mypy, but not pyright
546
+ prompt_ = cast (ExplicitEncoderDecoderPrompt , prompt )
552
547
encoder_inputs = self ._prompt_to_llm_inputs (
553
- prompt ["encoder_prompt" ],
548
+ prompt_ ["encoder_prompt" ],
554
549
tokenization_kwargs = tokenization_kwargs ,
555
550
mm_uuids = mm_uuids ,
556
551
)
557
- if (decoder_input := prompt ["decoder_prompt" ]) is None :
552
+ if (decoder_input := prompt_ ["decoder_prompt" ]) is None :
558
553
decoder_inputs = None
559
554
else :
560
555
decoder_inputs = self ._prompt_to_llm_inputs (decoder_input )
@@ -565,8 +560,9 @@ def _process_encoder_decoder_prompt(
565
560
self ._split_enc_dec_mm_inputs (encoder_inputs ,
566
561
decoder_inputs ))
567
562
else :
563
+ # `cast` is needed for mypy, but not pyright
568
564
inputs = self ._prompt_to_llm_inputs (
569
- prompt ,
565
+ cast ( SingletonPrompt , prompt ) ,
570
566
tokenization_kwargs = tokenization_kwargs ,
571
567
mm_uuids = mm_uuids ,
572
568
)
@@ -641,8 +637,9 @@ def preprocess(
641
637
"to decoder-only models" )
642
638
643
639
# Decoder-only operation
640
+ # `cast` is needed for mypy, but not pyright
644
641
return self ._process_decoder_only_prompt (
645
- prompt ,
642
+ cast ( SingletonPrompt , prompt ) ,
646
643
tokenization_kwargs = tokenization_kwargs ,
647
644
mm_uuids = mm_uuids ,
648
645
)
0 commit comments