Skip to content

Commit 71f3ba4

Browse files
CrownStar7, Jintao-Huang
authored and committed
[template] fix/pixtral/pixel_values & image_sizes (#4982)
1 parent 9fe2c76 commit 71f3ba4

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

swift/llm/template/template/pixtral.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
22
from typing import Any, Dict, List, Optional
33

4+
import torch
5+
46
from ..base import Template
57
from ..constant import MLLMTemplateType
68
from ..register import TemplateMeta, register_template
@@ -22,8 +24,8 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
2224
idx_list = findall(input_ids, 10)
2325
if idx_list:
2426
image_inputs = processor.image_processor(images, patch_size=processor.patch_size, return_tensors='pt')
25-
encoded['pixel_values'] = image_inputs['pixel_values'][0]
26-
image_sizes = image_inputs['image_sizes'][0]
27+
encoded['pixel_values'] = image_inputs['pixel_values']
28+
encoded['image_sizes'] = image_sizes = image_inputs['image_sizes']
2729

2830
def _get_new_tokens(i):
2931
height, width = image_sizes[i]
@@ -44,9 +46,14 @@ def _get_new_tokens(i):
4446

4547
def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
4648
pixel_values = self.gather_list(batch, 'pixel_values')
49+
image_sizes = self.gather_list(batch, 'image_sizes')
4750
res = super()._data_collator(batch, padding_to=padding_to)
4851
if pixel_values:
52+
pixel_values = torch.stack(pixel_values)
4953
res['pixel_values'] = pixel_values
54+
if image_sizes:
55+
image_sizes = torch.stack(image_sizes)
56+
res['image_sizes'] = image_sizes
5057
return res
5158

5259

0 commit comments

Comments
 (0)