@@ -1627,18 +1627,20 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
                                                + 1:]
                    added_tokens_len += token_len - 1
                data.update(media_inputs)
-
-        inputs['input_ids'] = input_ids
+        # The architecture will be optimized in ms-swift3.0
+        data['input_ids'] = input_ids
         inputs['labels'] = labels
-        data['input_ids'] = torch.tensor(input_ids)[None]
         inputs['_data'] = data
+        inputs.update(data)
         return inputs, {}
 
     def _post_encode(self, model, data: Any) -> Dict[str, Any]:
+        if not self._is_training:
+            return data
         _model = model.model
         if not hasattr(_model, 'embed_tokens'):
             _model = _model.model  # LoRA
-        input_ids = data['input_ids']
+        input_ids = torch.tensor(data['input_ids'], device=model.device)[None]
         pixel_values = data.get('pixel_values')
         pixel_values_videos = data.get('pixel_values_videos')
         inputs_embeds = _model.embed_tokens(input_ids)
@@ -1685,10 +1687,6 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] =
         res['position_ids'] = position_ids.contiguous()
         return res
 
-    @staticmethod
-    def _get_generate_ids(generate_ids: List[int], input_token_len: int) -> List[int]:
-        return generate_ids
-
 
 class Qwen2VLTemplate(_Qwen2VLTemplateMixin, QwenTemplate):
     pass
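
Below is a minimal sketch of the pattern the change above adopts: `_encode` keeps `input_ids` as a plain Python list inside `data` and mirrors it into `inputs`, while `_post_encode` returns early outside of training and only tensorizes and embeds the ids on the model's device when it actually runs. The class name, method arguments, and the use of `get_input_embeddings` here are illustrative assumptions, not the real ms-swift Template API.

# Illustrative sketch only; not the actual ms-swift classes or signatures.
from typing import Any, Dict, List, Tuple

import torch


class LazyTensorizeTemplate:
    """Hypothetical template: _encode returns plain lists; _post_encode tensorizes lazily."""
    _is_training = True

    def _encode(self, input_ids: List[int], labels: List[int]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        data = {'input_ids': input_ids}        # keep a python list, no torch.tensor here
        inputs = {'labels': labels, '_data': data}
        inputs.update(data)                    # expose the raw fields at the top level too
        return inputs, {}

    def _post_encode(self, model, data: Dict[str, Any]) -> Dict[str, Any]:
        if not self._is_training:              # inference path: hand the data back unchanged
            return data
        # training path: build the batch-of-1 tensor on the model's device and embed it
        input_ids = torch.tensor(data['input_ids'], device=model.device)[None]
        inputs_embeds = model.get_input_embeddings()(input_ids)
        return {'inputs_embeds': inputs_embeds}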