@@ -311,19 +311,22 @@ def _get_new_tokens(i):
     def forward_context(self, model, inputs):
         if 'real_position_ids' not in inputs:
             return super().forward_context(model, inputs)
-        position_ids = inputs['position_ids']
+        text_position_ids = inputs['position_ids']
         inputs['position_ids'] = inputs.pop('real_position_ids')
-        transformers_ge_453 = version.parse(transformers.__version__) >= version.parse('4.53')
-        if transformers_ge_453:
-            inputs.update(get_packed_seq_params(position_ids))
+        transformers_version = version.parse(transformers.__version__)
+        if transformers_version >= version.parse('4.53'):
+            if transformers_version >= version.parse('4.56'):
+                inputs['position_ids'] = torch.concat([text_position_ids[None], inputs['position_ids']], dim=0)
+            else:
+                inputs.update(get_packed_seq_params(text_position_ids))
             return super().forward_context(model, inputs)
         if self.version == 'v2':
             from transformers.models.qwen2_vl import modeling_qwen2_vl as modeling_module
         elif self.version == 'v2_5':
             from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl as modeling_module
         elif self.version == 'omni':
             from transformers.models.qwen2_5_omni import modeling_qwen2_5_omni as modeling_module
-        return self._patch_flash_attention_forward(modeling_module, position_ids)
+        return self._patch_flash_attention_forward(modeling_module, text_position_ids)

     def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
         if not self.is_training:
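For reference, a minimal shape sketch of what the new `>= 4.56` branch builds with `torch.concat`. The shapes below are assumptions chosen for illustration (packed text position ids as `(batch, seq_len)`, multimodal M-RoPE position ids as `(3, batch, seq_len)`); they are not taken from this commit.

import torch

# Illustrative shapes only (assumed, not stated in the diff):
# text position ids: (batch, seq_len); M-RoPE position ids: (3, batch, seq_len).
text_position_ids = torch.arange(8).repeat(2, 1)             # (2, 8)
mrope_position_ids = torch.zeros(3, 2, 8, dtype=torch.long)  # (3, 2, 8)

merged = torch.concat([text_position_ids[None], mrope_position_ids], dim=0)
print(merged.shape)  # torch.Size([4, 2, 8]): text row stacked on top of the three M-RoPE rows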