Skip to content

Commit 3562266

Browse files
committed
[bugfix] fix text_position_ids (#5692)
1 parent 1bc3adc commit 3562266

File tree

3 files changed: +13 additions, -9 deletions

swift/llm/template/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2003,3 +2003,9 @@ def _get_inputs_embeds_hf(inputs_embeds, inputs, visual, processor, config):
20032003
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
20042004
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
20052005
return inputs_embeds
2006+
2007+
@staticmethod
2008+
def _concat_text_position_ids(position_ids):
2009+
seq_len = position_ids.shape[-1]
2010+
text_position_ids = torch.arange(seq_len, device=position_ids.device).expand(1, *position_ids.shape[1:])
2011+
return torch.concat([text_position_ids, position_ids], dim=0)

swift/llm/template/template/glm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -317,15 +317,15 @@ def _get_position_ids(self, inputs: Dict[str, Any]):
317317
inputs.get('image_grid_thw'),
318318
inputs.get('video_grid_thw'),
319319
attention_mask=inputs.get('attention_mask'))
320-
text_position_ids = torch.arange(inputs['input_ids'].shape[-1])
321-
return torch.concat([text_position_ids[None, None], position_ids], dim=0)
320+
return self._concat_text_position_ids(position_ids)
322321

323322
def forward_context(self, model, inputs):
324323
position_ids = inputs['position_ids']
325324
inputs['position_ids'] = position_ids[1:]
326-
inputs['text_position_ids'] = position_ids[0]
325+
inputs['text_position_ids'] = text_position_ids = position_ids[0]
327326
# https://github.com/huggingface/transformers/pull/40194
328-
inputs.update(get_packed_seq_params(inputs['text_position_ids']))
327+
if text_position_ids.shape[0] == 1:
328+
inputs.update(get_packed_seq_params(text_position_ids))
329329
return super().forward_context(model, inputs)
330330

331331
def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:

swift/llm/template/template/qwen.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def forward_context(self, model, inputs):
313313
inputs['position_ids'] = position_ids[1:]
314314
inputs['text_position_ids'] = text_position_ids = position_ids[0]
315315
transformers_version = version.parse(transformers.__version__)
316-
if transformers_version >= version.parse('4.53'):
316+
if transformers_version >= version.parse('4.53') and text_position_ids.shape[0] == 1:
317317
# https://github.com/huggingface/transformers/pull/40194
318318
inputs.update(get_packed_seq_params(text_position_ids))
319319
return super().forward_context(model, inputs)
@@ -372,8 +372,7 @@ def _get_position_ids(self, inputs: Dict[str, Any]):
372372
inputs.get('video_grid_thw'),
373373
attention_mask=inputs.get('attention_mask'),
374374
**kwargs)
375-
text_position_ids = torch.arange(inputs['input_ids'].shape[-1])
376-
return torch.concat([text_position_ids[None, None], position_ids], dim=0)
375+
return self._concat_text_position_ids(position_ids)
377376

378377
def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
379378
res = super()._data_collator(batch, padding_to=padding_to)
@@ -591,8 +590,7 @@ def _get_position_ids(self, inputs: Dict[str, Any]):
591590
audio_feature_lengths,
592591
video_second_per_grid,
593592
)
594-
text_position_ids = torch.arange(inputs['input_ids'].shape[-1])
595-
return torch.concat([text_position_ids[None, None], position_ids], dim=0)
593+
return self._concat_text_position_ids(position_ids)
596594

597595
def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
598596
res = super()._data_collator_mm_data(batch)

Comments: 0 commit comments