 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
-from .utils import (filter_weights, init_vllm_registered_model,
+from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
                     merge_multimodal_embeddings)
 
 logger = init_logger(__name__)
@@ -47,15 +47,16 @@ class LlavaNextImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: Union[torch.Tensor, List[torch.Tensor]]
     """
-    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
+    Shape:
+    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
 
-    Note that `num_patches` may be different for each batch, in which case
-    the data is passed as a list instead of a batched tensor.
+    Note that `num_patches` may be different per batch and image,
+    in which case the data is passed as a list instead of a batched tensor.
     """
 
     image_sizes: NotRequired[torch.Tensor]
     """
-    Shape: `(batch_size, 2)`
+    Shape: `(batch_size * num_images, 2)`
 
     This should be in `(height, width)` format.
     """
@@ -64,7 +65,7 @@ class LlavaNextImagePixelInputs(TypedDict):
 class LlavaNextImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
     data: torch.Tensor
-    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
 
     `hidden_size` must match the hidden size of language model backbone.
     """
@@ -315,10 +316,19 @@ def __init__(self,
             torch.empty(config.text_config.hidden_size))
 
     def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
-        if list(data.shape[1:]) != [2]:
-            raise ValueError(
-                f"The expected image sizes shape is batch dimension plus "
-                f"{[2]}. You supplied {data.shape}.")
+        expected_dims = (2, )
+
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = tuple(d.shape)
+
+            if actual_dims != expected_dims:
+                expected_expr = str(expected_dims)
+                raise ValueError(
+                    f"The expected shape of image sizes per image per batch "
+                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
+
+        for d in data:
+            _validate_shape(d)
 
         return data
 
@@ -335,7 +345,7 @@ def _validate_shape(d: torch.Tensor):
             if actual_dims != expected_dims:
                 expected_expr = ("num_patches", *map(str, expected_dims))
                 raise ValueError(
-                    "The expected shape of pixel values in each batch element "
+                    "The expected shape of pixel values per image per batch "
                     f"is {expected_expr}. You supplied {tuple(d.shape)}.")
 
         for d in data:
@@ -357,35 +367,25 @@ def _parse_and_validate_image_input(
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
-            if not isinstance(image_sizes, torch.Tensor):
+            if not isinstance(image_sizes, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image sizes. "
                                  f"Got type: {type(image_sizes)}")
 
-            # Remove the N dimension until multiple images are supported.
-            if isinstance(pixel_values, torch.Tensor):
-                pixel_values = pixel_values.squeeze(1)
-            else:
-                pixel_values = [t.squeeze(0) for t in pixel_values]
-
-            image_sizes = image_sizes.squeeze(1)
-
             return LlavaNextImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(pixel_values),
-                image_sizes=self._validate_image_sizes(image_sizes),
+                data=self._validate_pixel_values(flatten_bn(pixel_values)),
+                image_sizes=self._validate_image_sizes(
+                    flatten_bn(image_sizes, concat=True)),
             )
 
         if image_embeds is not None:
             if not isinstance(image_embeds, torch.Tensor):
                 raise ValueError("Incorrect type of image embeds. "
                                  f"Got type: {type(image_embeds)}")
 
-            # Remove the N dimension until multiple images are supported.
-            image_embeds = image_embeds.squeeze(1)
-
             return LlavaNextImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds),
+                data=flatten_bn(image_embeds),
             )
 
         raise AssertionError("This line should be unreachable.")
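
Note on `flatten_bn`: the new import from `.utils` is what replaces the removed
`squeeze(1)` calls above. The name and the `concat` keyword are taken from the
diff itself; the body below is only a minimal sketch of the assumed behavior,
not the actual helper in `.utils`, which may differ in details.

    from typing import List, Union

    import torch


    def flatten_bn(x: Union[torch.Tensor, List[torch.Tensor]],
                   *,
                   concat: bool = False
                   ) -> Union[torch.Tensor, List[torch.Tensor]]:
        # Sketch only (assumed behavior): collapse the batch and num_images
        # dimensions into one leading dimension.
        if isinstance(x, torch.Tensor):
            # (batch_size, num_images, ...) -> (batch_size * num_images, ...)
            return x.flatten(0, 1)
        if concat:
            # List of per-prompt tensors -> one concatenated tensor,
            # as used here for `image_sizes`.
            return torch.cat(x)
        # Nested per-prompt data -> flat list of per-image tensors.
        return [x_n for x_b in x for x_n in x_b]

With something like this, `pixel_values` that presumably arrives as
`(batch_size, num_images, 1 + num_patches, num_channels, height, width)` (or a
nested list when `num_patches` varies) reaches `_validate_pixel_values` with a
single `batch_size * num_images` leading dimension, matching the docstrings
updated above.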