Skip to content

Commit 10a1143

Browse files
authored
fix qwen2-vl position_ids (#2461)
1 parent: 5c22918 · commit: 10a1143

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

swift/llm/utils/template.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1627,18 +1627,20 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
16271627
+ 1:]
16281628
added_tokens_len += token_len - 1
16291629
data.update(media_inputs)
1630-
1631-
inputs['input_ids'] = input_ids
1630+
# The architecture will be optimized in ms-swift3.0
1631+
data['input_ids'] = input_ids
16321632
inputs['labels'] = labels
1633-
data['input_ids'] = torch.tensor(input_ids)[None]
16341633
inputs['_data'] = data
1634+
inputs.update(data)
16351635
return inputs, {}
16361636

16371637
def _post_encode(self, model, data: Any) -> Dict[str, Any]:
1638+
if not self._is_training:
1639+
return data
16381640
_model = model.model
16391641
if not hasattr(_model, 'embed_tokens'):
16401642
_model = _model.model # LoRA
1641-
input_ids = data['input_ids']
1643+
input_ids = torch.tensor(data['input_ids'], device=model.device)[None]
16421644
pixel_values = data.get('pixel_values')
16431645
pixel_values_videos = data.get('pixel_values_videos')
16441646
inputs_embeds = _model.embed_tokens(input_ids)

0 commit comments

Comments
 (0)