@@ -25,7 +25,7 @@
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems, NestedTensors)
+                                    MultiModalKwargsItems)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -39,7 +39,7 @@

 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix)

@@ -304,7 +304,7 @@ def _call_hf_processor(
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
         tok_kwargs: Mapping[str, object],
-    ) -> Mapping[str, NestedTensors]:
+    ) -> BatchFeature:
         mm_data = dict(mm_data)
         videos = mm_data.pop("videos", [])
         images = mm_data.pop("images", [])
@@ -342,7 +342,7 @@ def _call_hf_processor(
                                                   image_placeholder, 1)

         num_patches = [len(item) for item in image_pixel_values]
-        image_outputs: dict[str, NestedTensors] = {
+        image_outputs = {
             "pixel_values": torch.concat(image_pixel_values),
             "image_num_patches": torch.tensor(num_patches),
             "image_token_id": torch.tensor(hf_processor.image_token_id),
@@ -370,7 +370,7 @@ def _call_hf_processor(
                                                   video_placeholder, 1)

         num_frames = [len(item) for item in video_pixel_values]
-        video_outputs: dict[str, NestedTensors] = {
+        video_outputs = {
             "pixel_values_videos": torch.concat(video_pixel_values),
             "video_num_patches": torch.tensor(num_frames),
             "video_token_id": torch.tensor(video_token_id),
@@ -382,16 +382,11 @@ def _call_hf_processor(
                                                          prompt)
         text_outputs = tokenizer(prompt, **tok_kwargs, return_tensors="pt")

-        combined_outputs = dict(
-            **text_outputs,
-            **image_outputs,
-            **video_outputs,
-        )
-        return BatchFeature(combined_outputs)
+        return BatchFeature({**text_outputs, **image_outputs, **video_outputs})

     def _get_mm_fields_config(
         self,
-        hf_inputs: Mapping[str, NestedTensors],
+        hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
@@ -487,6 +482,7 @@ def get_replacement_interns1_video(item_idx: int):
                                         dummy_inputs=InternS1DummyInputsBuilder)
 class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP, SupportsLoRA):
+    merge_by_field_config = True

     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
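Reviewer note, not part of the patch: merge_by_field_config = True opts the model into vLLM's field-config-driven merging of multimodal kwargs, which is what makes the flatten_bn calls and isinstance checks removed below redundant. A minimal runnable sketch of the merge behavior for a flat field, assuming the per-patch layout this model already declares in _get_mm_fields_config (names and shapes here are illustrative, not vLLM internals):

import torch

def merge_flat_field(per_request: list[torch.Tensor]) -> torch.Tensor:
    # The engine concatenates each request's patch tensors along dim 0,
    # so the model no longer calls flatten_bn(..., concat=True) itself.
    return torch.concat(per_request)

# Two requests contributing 3 and 5 image patches respectively:
merged = merge_flat_field(
    [torch.zeros(3, 3, 448, 448), torch.zeros(5, 3, 448, 448)])
print(merged.shape)  # torch.Size([8, 3, 448, 448])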
@@ -561,7 +557,7 @@ def _init_vision_model(
             prefix=prefix,
         )

-    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
+    def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
         return InternS1MultiModalProjector(config)

     def pixel_shuffle(self, x, scale_factor=0.5):
@@ -599,31 +595,16 @@ def _parse_and_validate_image_input(
             return None

         if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
-
             return InternS1ImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds),
+                data=image_embeds,
             )

         image_token_id = kwargs["image_token_id"]
         assert isinstance(image_token_id, torch.Tensor)
         self.img_context_token_id = image_token_id.flatten().unique().item()

         if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
-            if not isinstance(image_num_patches, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image_num_patches. "
-                                 f"Got type: {type(image_num_patches)}")
-
-            pixel_values = flatten_bn(pixel_values, concat=True)
-            image_num_patches = flatten_bn(image_num_patches, concat=True)
-
             h, w = self.config.vision_config.image_size
             return InternS1ImagePixelInputs(
                 type="pixel_values",
@@ -638,7 +619,7 @@ def _parse_and_validate_image_input(
         raise AssertionError("This line should be unreachable.")

     def _parse_and_validate_video_input(
-            self, **kwargs: object) -> Optional[InternS1VideoPixelInputs]:
+            self, **kwargs: object) -> Optional[InternS1VideoInputs]:
         pixel_values_flat_video = kwargs.pop("pixel_values_videos", None)
         video_num_patches = kwargs.pop("video_num_patches", None)
         video_embeds = kwargs.pop("video_embeds", None)
@@ -647,32 +628,16 @@ def _parse_and_validate_video_input(
             return None

         if video_embeds is not None:
-            if not isinstance(video_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of video embeddings. "
-                                 f"Got type: {type(video_embeds)}")
-
-            return InternS1ImageEmbeddingInputs(
+            return InternS1VideoEmbeddingInputs(
                 type="video_embeds",
-                data=flatten_bn(video_embeds),
+                data=video_embeds,
             )

         video_token_id = kwargs["video_token_id"]
         assert isinstance(video_token_id, torch.Tensor)
         self.video_context_token_id = video_token_id.flatten().unique().item()

         if pixel_values_flat_video is not None:
-            if not isinstance(pixel_values_flat_video, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values_flat_video)}")
-
-            if not isinstance(video_num_patches, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image_num_patches. "
-                                 f"Got type: {type(video_num_patches)}")
-
-            pixel_values_flat_video = flatten_bn(pixel_values_flat_video,
-                                                 concat=True)
-            video_num_patches = flatten_bn(video_num_patches, concat=True)
-
             h, w = self.config.vision_config.image_size
             return InternS1VideoPixelInputs(
                 type="pixel_values_videos",
@@ -686,11 +651,12 @@ def _parse_and_validate_video_input(

         raise AssertionError("This line should be unreachable.")

-    def _process_image_input(
+    def _process_vision_input(
         self,
-        image_input: Union[InternS1ImageInputs, InternS1VideoPixelInputs],
+        image_input: Union[InternS1ImageInputs, InternS1VideoInputs],
     ) -> tuple[torch.Tensor, ...]:
-        if image_input["type"] == "image_embeds":
+        if (image_input["type"] == "image_embeds"
+                or image_input["type"] == "video_embeds"):
             return image_input["data"]

         assert self.vision_tower is not None
@@ -753,11 +719,11 @@ def get_multimodal_embeddings(self,
         for modality in modalities:
             if modality == "images":
                 image_input = modalities["images"]
-                vision_embeddings = self._process_image_input(image_input)
+                vision_embeddings = self._process_vision_input(image_input)
                 multimodal_embeddings += vision_embeddings
             if modality == "videos":
                 video_input = modalities["videos"]
-                video_embeddings = self._process_image_input(video_input)
+                video_embeddings = self._process_vision_input(video_input)
                 multimodal_embeddings += video_embeddings

         return multimodal_embeddings
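Reviewer note, not part of the patch: after the rename, a single helper serves both modalities. A minimal runnable sketch of the dispatch it now performs, with a stand-in for the vision tower and projector (all names below are illustrative, not the file's exact internals):

import torch

def process_vision_input(vision_input: dict) -> tuple[torch.Tensor, ...]:
    # Precomputed embeddings for either modality pass straight through,
    # matching the widened two-way type check in the diff above.
    if vision_input["type"] in ("image_embeds", "video_embeds"):
        return vision_input["data"]
    # Pixel inputs: encode, then split back into one tensor per item
    # using the per-item patch counts carried alongside pixel_values.
    encoded = vision_input["pixel_values"]  # stand-in for tower + projector
    sizes = vision_input["num_patches"].tolist()
    return tuple(encoded.split(sizes, dim=0))

# Two images contributing 3 and 5 patches respectively:
out = process_vision_input({
    "type": "pixel_values",
    "pixel_values": torch.zeros(8, 3, 448, 448),
    "num_patches": torch.tensor([3, 5]),
})
print([t.shape[0] for t in out])  # [3, 5]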