@@ -313,7 +313,7 @@ def forward_context(self, model, inputs):
         inputs['position_ids'] = position_ids[1:]
         inputs['text_position_ids'] = text_position_ids = position_ids[0]
         transformers_version = version.parse(transformers.__version__)
-        if transformers_version >= version.parse('4.53'):
+        if transformers_version >= version.parse('4.53') and text_position_ids.shape[0] == 1:
             # https://github.com/huggingface/transformers/pull/40194
             inputs.update(get_packed_seq_params(text_position_ids))
         return super().forward_context(model, inputs)
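The new `and text_position_ids.shape[0] == 1` guard limits the packed-sequence path to the case where the whole batch has been packed into a single row, which is the only layout where cumulative sequence lengths derived from the flattened position ids are meaningful. A rough, illustrative sketch of that idea follows; it is not the actual `get_packed_seq_params` implementation, and the FlashAttention-style key names are an assumption:

    import torch

    def sketch_packed_seq_params(text_position_ids: torch.Tensor) -> dict:
        """Illustrative only: derive FlashAttention-style cumulative sequence
        lengths from a single packed row of position ids, where each packed
        sub-sequence restarts its positions at 0."""
        assert text_position_ids.shape[0] == 1, 'only meaningful for a single packed row'
        pos = text_position_ids[0]
        # Each reset of the position id to 0 marks the start of a new packed sub-sequence.
        starts = torch.nonzero(pos == 0, as_tuple=False).flatten()
        cu_seqlens = torch.cat([starts, pos.new_tensor([pos.shape[0]])]).to(torch.int32)
        max_len = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
        return {
            'cu_seq_lens_q': cu_seqlens,
            'cu_seq_lens_k': cu_seqlens,
            'max_length_q': max_len,
            'max_length_k': max_len,
        }

For example, a packed row with positions [0, 1, 2, 0, 1] would give cu_seqlens [0, 3, 5]; with a real batch dimension (shape[0] > 1) this bookkeeping would not apply, hence the guard.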
@@ -372,8 +372,7 @@ def _get_position_ids(self, inputs: Dict[str, Any]):
             inputs.get('video_grid_thw'),
             attention_mask=inputs.get('attention_mask'),
             **kwargs)
-        text_position_ids = torch.arange(inputs['input_ids'].shape[-1])
-        return torch.concat([text_position_ids[None, None], position_ids], dim=0)
+        return self._concat_text_position_ids(position_ids)
 
     def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
         res = super()._data_collator(batch, padding_to=padding_to)
@@ -591,8 +590,7 @@ def _get_position_ids(self, inputs: Dict[str, Any]):
             audio_feature_lengths,
             video_second_per_grid,
         )
-        text_position_ids = torch.arange(inputs['input_ids'].shape[-1])
-        return torch.concat([text_position_ids[None, None], position_ids], dim=0)
+        return self._concat_text_position_ids(position_ids)
 
     def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
         res = super()._data_collator_mm_data(batch)
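Both `_get_position_ids` call sites now delegate the duplicated "arange text positions, then prepend to the mrope position ids" logic to a shared `_concat_text_position_ids` helper. A presumed reconstruction of that helper from the two removed blocks is sketched below; the real method may differ in detail (e.g. device or dtype handling):

    import torch

    class PositionIdsSketch:
        """Presumed shape of the shared helper, reconstructed from the removed
        lines above; not the actual implementation."""

        def _concat_text_position_ids(self, position_ids: torch.Tensor) -> torch.Tensor:
            # Plain 0..seq_len-1 text positions, prepended as an extra leading entry
            # so forward_context can later split them off again
            # (text_position_ids = position_ids[0], mrope ids = position_ids[1:]).
            text_position_ids = torch.arange(position_ids.shape[-1], device=position_ids.device)
            return torch.concat([text_position_ids[None, None], position_ids], dim=0)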