@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Literal, Optional, TypedDict, Union, cast
+from typing import Annotated, Any, Literal, Optional, Union, cast

 import numpy as np
 import torch
@@ -41,6 +41,7 @@
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
                          SupportsTranscription)
@@ -54,17 +55,28 @@
 TOKENS_PER_AUDIO = 188


-class Gemma3nImagePixelInputs(TypedDict):
-    pixel_values: torch.Tensor
-    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
+class Gemma3nImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height of each patch
+        - w: Width of each patch
+    """
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


-class Gemma3nAudioInputs(TypedDict):
-    input_features: Union[torch.Tensor, list[torch.Tensor]]
-    input_features_padded: torch.Tensor
-    """Shape: `(batch_size * num_audio, seq_length, num_features)`"""
-    input_features_mask: torch.Tensor
-    """Shape: `(batch_size * num_audio, seq_length)`"""
+class Gemma3nAudioInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of audios
+        - s: seq_length
+        - f: num_features
+    """
+    type: Literal["audio"] = "audio"
+    input_features_padded: Annotated[torch.Tensor, TensorShape("bn", "s", "f")]
+    input_features_mask: Annotated[torch.Tensor, TensorShape("bn", "s")]


 Gemma3nImageInputs = Gemma3nImagePixelInputs
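Note on the hunk above: migrating from `TypedDict` to `TensorSchema` moves shape checking out of hand-written validation helpers and into the `Annotated[..., TensorShape(...)]` declarations, which is why the manual `_validate_pixel_values` path is deleted further down. Below is a minimal, self-contained sketch of that pattern; `MiniTensorShape` and `MiniTensorSchema` are simplified stand-ins written for illustration only, not vLLM's actual `vllm.utils.tensor_schema` implementation.

```python
from typing import Annotated, get_type_hints

import torch


class MiniTensorShape:
    """Stand-in for TensorShape: str entries are free (named) dims, ints are fixed."""

    def __init__(self, *dims):
        self.dims = dims


class MiniTensorSchema:
    """Stand-in base class: checks annotated shapes at construction time."""

    def __init__(self, **kwargs):
        hints = get_type_hints(type(self), include_extras=True)
        for name, value in kwargs.items():
            shape = hints[name].__metadata__[0]
            if value.ndim != len(shape.dims):
                raise ValueError(f"{name}: expected {len(shape.dims)} dims, "
                                 f"got {value.ndim}")
            for actual, expected in zip(value.shape, shape.dims):
                # Fixed (int) dims must match exactly; named dims are free.
                if isinstance(expected, int) and actual != expected:
                    raise ValueError(f"{name}: expected size {expected}, "
                                     f"got {actual}")
            setattr(self, name, value)


class ImagePixelInputs(MiniTensorSchema):
    pixel_values: Annotated[torch.Tensor, MiniTensorShape("bn", 3, "h", "w")]


ImagePixelInputs(pixel_values=torch.zeros(2, 3, 224, 224))  # passes
try:
    ImagePixelInputs(pixel_values=torch.zeros(2, 4, 224, 224))
except ValueError as e:
    print(e)  # the channel dim is fixed to 3, so this batch is rejected
```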
@@ -212,9 +224,9 @@ def _get_mm_fields_config(
         return dict(
             pixel_values=MultiModalFieldConfig.batched("image"),
-            input_features=MultiModalFieldConfig.batched("audio"),
             input_features_padded=MultiModalFieldConfig.batched("audio"),
-            input_features_mask=MultiModalFieldConfig.batched("audio"))
+            input_features_mask=MultiModalFieldConfig.batched("audio"),
+        )

     def _get_prompt_updates(
         self,
@@ -422,6 +434,7 @@ def forward(
                                         dummy_inputs=Gemma3nDummyInputsBuilder)
 class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
                                       SupportsTranscription):
+    merge_by_field_config = True
     supported_languages = ISO639_1_SUPPORTED_LANGS

     packed_modules_mapping = {
@@ -482,14 +495,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             device=self.language_model.model.embed_tokens.weight.device,
             dtype=self.language_model.model.embed_tokens.weight.dtype)

-    @property
-    def dtype(self):
-        return next(self.parameters()).dtype
-
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        # TODO check if there are any
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Gemma3nImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -499,34 +504,22 @@ def _parse_and_validate_image_input(
         if pixel_values is None:
             return None

-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
-
-        pixel_values = flatten_bn(pixel_values, concat=True)
-        pixel_values = pixel_values.contiguous()
-
-        return Gemma3nImagePixelInputs(
-            pixel_values=self._validate_pixel_values(pixel_values), )
+        return Gemma3nImagePixelInputs(pixel_values=pixel_values)

     def _parse_and_validate_audio_input(
             self, **kwargs: object) -> Optional[Gemma3nAudioInputs]:
-        input_features = kwargs.pop("input_features", None)
-        if input_features is None:
+
+        input_features_padded = kwargs.pop("input_features_padded", None)
+        if input_features_padded is None:
             return None

         input_features_mask = kwargs.pop("input_features_mask", None)
         if input_features_mask is None:
             return None

-        input_features_padded = kwargs.pop("input_features_padded", None)
-        if input_features_padded is None:
-            return None
-
         return Gemma3nAudioInputs(
-            input_features=input_features,
-            input_features_mask=input_features_mask,
             input_features_padded=input_features_padded,
+            input_features_mask=input_features_mask,
         )

     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
@@ -539,7 +532,7 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
             ) and "image" not in mm_input_by_modality:
                 mm_input_by_modality[
                     "image"] = self._parse_and_validate_image_input(**kwargs)
-            if input_key == "input_features" \
+            if input_key == "input_features_padded" \
                     and "audio" not in mm_input_by_modality:
                 mm_input_by_modality[
                     "audio"] = self._parse_and_validate_audio_input(**kwargs)
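With the raw `input_features` field gone, `input_features_padded` becomes the tensor that signals the presence of audio, both in `_parse_and_validate_audio_input` and in the modality dispatch in the final hunk. For intuition about the `(bn, s, f)` / `(bn, s)` shapes declared in `Gemma3nAudioInputs`, here is a hypothetical helper (not part of vLLM) showing how a padded batch and its mask relate; the convention that `True` marks valid frames is an assumption for illustration.

```python
import torch


def pad_with_mask(
        features: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    """Pad variable-length (s_i, f) features to (bn, s_max, f) plus a (bn, s_max) mask."""
    s_max = max(f.shape[0] for f in features)
    num_feats = features[0].shape[1]
    padded = torch.zeros(len(features), s_max, num_feats)
    mask = torch.zeros(len(features), s_max, dtype=torch.bool)
    for i, f in enumerate(features):
        padded[i, :f.shape[0]] = f   # copy real frames
        mask[i, :f.shape[0]] = True  # mark them as valid
    return padded, mask


padded, mask = pad_with_mask([torch.randn(100, 128), torch.randn(80, 128)])
print(padded.shape, mask.shape)  # torch.Size([2, 100, 128]) torch.Size([2, 100])
```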