11# Copyright (c) Alibaba, Inc. and its affiliates.
22from typing import Any , Dict , List , Optional
33
4+ import torch
5+
46from ..base import Template
57from ..constant import MLLMTemplateType
68from ..register import TemplateMeta , register_template
@@ -22,8 +24,8 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
2224 idx_list = findall (input_ids , 10 )
2325 if idx_list :
2426 image_inputs = processor .image_processor (images , patch_size = processor .patch_size , return_tensors = 'pt' )
25- encoded ['pixel_values' ] = image_inputs ['pixel_values' ][ 0 ]
26- image_sizes = image_inputs ['image_sizes' ][ 0 ]
27+ encoded ['pixel_values' ] = image_inputs ['pixel_values' ]
28+ encoded [ ' image_sizes' ] = image_sizes = image_inputs ['image_sizes' ]
2729
2830 def _get_new_tokens (i ):
2931 height , width = image_sizes [i ]
@@ -44,9 +46,14 @@ def _get_new_tokens(i):
4446
def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
    """Collate a batch and attach the stacked image tensors to the result.

    The per-sample ``'pixel_values'`` and ``'image_sizes'`` entries are
    collected with ``self.gather_list`` before the parent collator runs.
    Each non-empty collection is then combined with ``torch.stack`` and
    stored in the parent's result under the same key.

    Args:
        batch: Per-sample dictionaries to collate.
        padding_to: Forwarded unchanged to the parent collator.

    Returns:
        The dictionary produced by the parent collator, with
        ``'pixel_values'`` and/or ``'image_sizes'`` added when the batch
        provided any.
    """
    # Gather the image fields first, in this order, and only then delegate to
    # the parent collator (same call sequence as before).
    # NOTE(review): gather_list is assumed to return a list of equally-shaped
    # tensors, since torch.stack needs matching shapes -- confirm.
    image_fields = {key: self.gather_list(batch, key) for key in ('pixel_values', 'image_sizes')}
    res = super()._data_collator(batch, padding_to=padding_to)

    for key, tensors in image_fields.items():
        # An empty gather result means no sample carried this field; skip it
        # so the key stays absent from the output.
        if tensors:
            res[key] = torch.stack(tensors)
    return res
5158
5259
0 commit comments