@@ -21,6 +21,7 @@
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.utils import strtobool
 
+from swift.llm import to_device
 from swift.utils import get_env_args, get_logger
 from ..utils import Processor, ProcessorMixin
 from .template_inputs import InferRequest, StdTemplateInputs, TemplateInputs
@@ -1349,13 +1350,12 @@ def post_process_generate_response(self, response: str, inputs: StdTemplateInputs):
         return response
 
     def pre_forward_hook(self, model: nn.Module, args, kwargs):
-        from swift.llm import to_device
        old_kwargs = to_device(kwargs, model.device)
        kwargs = to_device(self._post_encode(model, old_kwargs), model.device)
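+        # Re-attach pass-through keys that _post_encode dropped; note the FA2 varlen
+        # kwargs below use the upstream names cu_seq_lens_q/k and max_length_q/k
+        # (cu_seq_lens_q/k replace the older cumulative_seqlens_q/k).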
         for k, v in old_kwargs.items():
             if k in {
                     'input_ids', 'attention_mask', 'labels', 'position_ids', 'output_hidden_states', 'logits_to_keep',
-                    'cumulative_seqlens_q', 'cumulative_seqlens_k', 'max_length_q', 'max_length_k'
+                    'max_length_q', 'max_length_k', 'cu_seq_lens_q', 'cu_seq_lens_k'
             } and k not in kwargs:
                 kwargs[k] = v
         if 'inputs_embeds' in kwargs:
@@ -1629,7 +1629,7 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int]
         res = {}
         if self.padding_free:
             assert len(batch) == 1, f'batch: {batch}'
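+            # padding_free: upstream packing has already flattened the whole batch into
+            # a single concatenated sample, hence exactly one element here.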
-            for k in ['input_ids', 'labels', 'position_ids', 'loss_scale', 'channel', 'real_position_ids']:
+            for k in ['input_ids', 'labels', 'position_ids', 'loss_scale', 'channel']:
                 v = batch[0].get(k)
                 if v is not None:
                     res[k] = v if k == 'channel' else [v]
@@ -1651,10 +1651,15 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int]
                 res[key] = val
 
         keys = [
-            'input_ids', 'inputs_embeds', 'attention_mask', 'labels', 'loss_scale', 'position_ids', 'token_type_ids',
-            'real_position_ids'
+            'input_ids',
+            'inputs_embeds',
+            'attention_mask',
+            'labels',
+            'loss_scale',
+            'position_ids',
+            'token_type_ids',
         ]
-        pad_values = [self.tokenizer.pad_token_id, 0., 0, -100, 0., 0., 0, 0.]
+        pad_values = [self.tokenizer.pad_token_id, 0., 0, -100, 0., 0., 0]
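+        # One pad value per key above: pad_token_id (input_ids), 0. (inputs_embeds),
+        # 0 (attention_mask), -100, the loss ignore index (labels), 0. (loss_scale),
+        # 0. (position_ids), 0 (token_type_ids).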
         # Convert to tensor and remove unnecessary dimensions.
         seq_lens = None
         for key in keys:
@@ -1681,16 +1686,13 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int]
         if self.padding_free:
             cp_size = self.sequence_parallel_size
             if cp_size > 1:
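+                # Presumably for sequence/context parallelism: pad position_ids with a
+                # repeating arange(cp_size * 2) pattern so the packed length divides
+                # evenly into 2 * cp_size chunks.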
-                for key in ['position_ids', 'real_position_ids']:
-                    if key not in res:
-                        continue
-                    padding_len = padding_to - seq_lens[0]
-                    position_ids = res[key][0]
-                    extended_position_ids = torch.arange(cp_size * 2).repeat(padding_len // (cp_size * 2))
-                    if position_ids.ndim == 3:  # compat mrope
-                        extended_position_ids = extended_position_ids[None, None, :].expand(position_ids.shape[0], 1, -1)
-                    res[key] = [torch.concat([position_ids, extended_position_ids], dim=-1)]
+                padding_len = padding_to - seq_lens[0]
+                position_ids = res['position_ids'][0]
+                extended_position_ids = torch.arange(cp_size * 2).repeat(padding_len // (cp_size * 2))
+                if position_ids.ndim == 3:  # compat mrope
+                    extended_position_ids = extended_position_ids[None, None, :].expand(position_ids.shape[0], 1, -1)
+                res['position_ids'] = [torch.concat([position_ids, extended_position_ids], dim=-1)]
         else:
             seq_len = max(seq_lens) if padding_to is None else padding_to
             res['attention_mask'] = torch.tril(torch.ones(
@@ -1704,13 +1706,13 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int]
                 continue
             if self.use_megatron and not self.padding_free and key == 'attention_mask':
                 continue
-            if padding_to is not None and not (self.padding_free and key in {'position_ids', 'real_position_ids'}
+            if padding_to is not None and not (self.padding_free and key == 'position_ids'
                                                and self.sequence_parallel_size > 1):
                 padding_len = padding_to - seq_lens[0]
                 if padding_len > 0:
                     res[key][0] = F.pad(res[key][0], (0, padding_len) if padding_right else (padding_len, 0),
                                         'constant', pad_value)
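+            # 3-D position_ids come from M-RoPE models (see the 'compat mrope' branch
+            # above) and are joined along the sequence dim rather than padded as a batch.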
-            if key == 'real_position_ids':
+            if key == 'position_ids' and res[key][0].ndim == 3:
                 res[key] = torch.concat(res[key], dim=-1)
             else:
                 res[key] = self._pad_sequence(res[key], pad_value)
@@ -1951,3 +1953,53 @@ def _flash_attention_forward(*args, **kwargs):
             yield
         finally:
             modeling_module._flash_attention_forward = _origin_flash_attention_forward
+
+    @staticmethod
+    def _get_inputs_embeds_hf(inputs_embeds, inputs, visual, processor, config):
+        input_ids = inputs['input_ids']
+        pixel_values = inputs.get('pixel_values')
+        pixel_values_videos = inputs.get('pixel_values_videos')
+        image_grid_thw = inputs.get('image_grid_thw')
+        video_grid_thw = inputs.get('video_grid_thw')
+        dtype = visual.dtype
+        if pixel_values is None and pixel_values_videos is None:  # plain-text
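+            # Plain-text batch: run a dummy 32x32 black image through the visual tower
+            # and add its mean scaled by zero, keeping the tower's parameters in the
+            # autograd graph (a common workaround for unused-parameter errors under
+            # DDP / DeepSpeed ZeRO-3) without changing inputs_embeds.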
+            images = [Image.new('RGB', (32, 32), (0, 0, 0))]
+            media_inputs = processor.image_processor(images=images, return_tensors='pt')
+            media_inputs = to_device(media_inputs, input_ids.device)
+            pixel_values = media_inputs['pixel_values'].type(dtype)
+            image_embeds = visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
+            inputs_embeds = inputs_embeds + image_embeds.mean() * 0.
+        else:
+            if pixel_values is None:
+                pixel_values_mixed = pixel_values_videos
+                grid_thw = video_grid_thw
+            elif pixel_values_videos is None:
+                pixel_values_mixed = pixel_values
+                grid_thw = image_grid_thw
+            else:
+                pixel_values_mixed = torch.concat([pixel_values, pixel_values_videos], dim=0)
+                grid_thw = torch.concat([image_grid_thw, video_grid_thw], dim=0)
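+            # Encode images and videos through the visual tower in a single call; when
+            # both are present, the combined output is split back apart below by the
+            # number of image tokens.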
+            pixel_values_mixed = pixel_values_mixed.type(dtype)
+            mixed_embeds = visual(pixel_values_mixed, grid_thw=grid_thw)
+            if pixel_values is None:
+                image_embeds = None
+                video_embeds = mixed_embeds
+            elif pixel_values_videos is None:
+                image_embeds = mixed_embeds
+                video_embeds = None
+            else:
+                merge_length = processor.image_processor.merge_size**2
+                image_tokens = (image_grid_thw.prod(dim=-1) // merge_length).sum()
+                image_embeds = mixed_embeds[:image_tokens]
+                video_embeds = mixed_embeds[image_tokens:]
+
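+        # Scatter the visual embeddings into the placeholder positions: each image/video
+        # token id in input_ids marks a row to be filled, and masked_scatter fills those
+        # rows in order.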
+        if image_embeds is not None:
+            image_mask = (input_ids == config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+        if video_embeds is not None:
+            video_mask = (input_ids == config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+        return inputs_embeds