Skip to content

Commit dba7c55

Browse files
Support InternVL2 calibration samples with multiple images and a single text prompt at batch size 1 (#190)
1 parent a1b4ba9 commit dba7c55

File tree

2 files changed: 25 additions and 8 deletions.

llmc/data/dataset/specified_preproc.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -108,7 +108,11 @@ def vlm_general(calib_dataset, tokenizer, batch_process, n_samples):
108108
img_qas = json.load(fp)
109109
for idx in range(len(img_qas)):
110110
if 'img' in img_qas[idx]:
111-
img_qas[idx]['img'] = os.path.join(calib_dataset, img_qas[idx]['img'])
111+
if isinstance(img_qas[idx]['img'], list):
112+
for img_idx in range(len(img_qas[idx]['img'])):
113+
img_qas[idx]['img'][img_idx] = os.path.join(calib_dataset, img_qas[idx]['img'][img_idx]) # noqa
114+
else:
115+
img_qas[idx]['img'] = os.path.join(calib_dataset, img_qas[idx]['img'])
112116
else:
113117
img_qas[idx]['img'] = None
114118
random.shuffle(img_qas)

llmc/models/internvl2.py

Lines changed: 20 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -184,18 +184,31 @@ def batch_process(self, img_qas):
184184
}
185185
return inputs
186186

187-
def single_process(self, img_qas):
187+
def single_process(self, img_qa):
188188
tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
189-
if img_qas['img'] is not None:
190-
pixel_values = load_image(img_qas['img'], max_num=12).to(
191-
next(self.vlm_model.parameters()).dtype
192-
)
189+
num_patches_list = None
190+
pixel_values_list = []
191+
if img_qa['img'] is not None:
192+
if isinstance(img_qa['img'], list):
193+
num_patches_list = []
194+
for img_idx in range(len(img_qa['img'])):
195+
pixel_values = load_image(img_qa['img'][img_idx], max_num=12).to(
196+
next(self.vlm_model.parameters()).dtype
197+
)
198+
pixel_values_list.append(pixel_values)
199+
num_patches_list.append(pixel_values.size(0))
200+
pixel_values = torch.cat(pixel_values_list, dim=0)
201+
else:
202+
pixel_values = load_image(img_qa['img'], max_num=12).to(
203+
next(self.vlm_model.parameters()).dtype
204+
)
193205
else:
194206
pixel_values = None
195-
question = img_qas['question']
207+
question = img_qa['question']
196208
if pixel_values is not None and '<image>' not in question:
197209
question = '<image>\n' + question
198-
num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
210+
if num_patches_list is None:
211+
num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
199212
generation_config = dict()
200213

201214
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'

0 commit comments

Comments
 (0)