@@ -3560,6 +3560,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
35603560 labels = inputs ['labels' ]
35613561 idx_list = _findall (input_ids , - 100 )
35623562 processor = self .tokenizer .processor
3563+ inputs = {'_data' : {}}
35633564 if images :
35643565 image_inputs = processor .image_processor (images , cut_enable = cut_enable , return_tensors = 'pt' )
35653566 added_tokens_len = 0
@@ -3579,21 +3580,23 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
35793580 _range = torch .arange (len (input_ids ))[:, None ]
35803581 matrix = (_range > image_token_idx ).sum (dim = 1 )
35813582 media_offset = torch .stack ([torch .zeros (matrix .shape [0 ], dtype = torch .long ), matrix ], dim = - 1 )[None ]
3582- inputs ['_data' ] = {'pixel_values' : image_inputs ['pixel_values' ]}
3583- inputs ['media_offset' ] = media_offset
3584- inputs ['num_images' ] = image_inputs ['pixel_values' ].shape [0 ]
3585- inputs ['input_ids' ] = input_ids
3583+ inputs ['_data' ].update ({
3584+ 'pixel_values' : image_inputs ['pixel_values' ],
3585+ 'media_offset' : media_offset ,
3586+ })
3587+ inputs ['_data' ]['input_ids' ] = input_ids
35863588 inputs ['labels' ] = labels
35873589 return inputs , {}
35883590
35893591 def _post_encode (self , model , data : Any ) -> Dict [str , Any ]:
3590- image_embeds = model .forward_image (data ['pixel_values' ])
3591- return {'image_embeds' : image_embeds }
3592+ if 'pixel_values' in data :
3593+ pixel_values = data .pop ('pixel_values' )
3594+ data ['image_embeds' ] = model .forward_image (pixel_values )
3595+ return data
35923596
35933597 def data_collator (self , batch : List [Dict [str , Any ]], padding_to : Optional [int ] = None ) -> Dict [str , Any ]:
35943598 res = super ().data_collator (batch , padding_to )
35953599 image_embeds = [b ['image_embeds' ] for b in batch if 'image_embeds' in b ]
3596- num_images = [b ['num_images' ] if 'num_images' in b else 0 for b in batch ]
35973600 if image_embeds :
35983601 res ['image_embeds' ] = torch .concat (image_embeds )
35993602 media_offset = []
@@ -3609,7 +3612,7 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] =
36093612 curr_media_offset .shape [2 ])
36103613 curr_media_offset = torch .concat ([curr_media_offset , padding ], dim = 1 )
36113614 media_offset .append (curr_media_offset + cusum_offset )
3612- cusum_offset += num_images [bi ]
3615+ cusum_offset += image_embeds [bi ]. shape [ 0 ]
36133616
36143617 # media_offset = [b['media_offset'] for b in batch if 'media_offset' in b]
36153618
0 commit comments