diff --git a/llmc/data/dataset/specified_preproc.py b/llmc/data/dataset/specified_preproc.py index 4302004c9..56649db10 100644 --- a/llmc/data/dataset/specified_preproc.py +++ b/llmc/data/dataset/specified_preproc.py @@ -108,7 +108,11 @@ def vlm_general(calib_dataset, tokenizer, batch_process, n_samples): img_qas = json.load(fp) for idx in range(len(img_qas)): if 'img' in img_qas[idx]: - img_qas[idx]['img'] = os.path.join(calib_dataset, img_qas[idx]['img']) + if isinstance(img_qas[idx]['img'], list): + for img_idx in range(len(img_qas[idx]['img'])): + img_qas[idx]['img'][img_idx] = os.path.join(calib_dataset, img_qas[idx]['img'][img_idx]) # noqa + else: + img_qas[idx]['img'] = os.path.join(calib_dataset, img_qas[idx]['img']) else: img_qas[idx]['img'] = None random.shuffle(img_qas) diff --git a/llmc/models/internvl2.py b/llmc/models/internvl2.py index b3b90c660..7eefd95fb 100644 --- a/llmc/models/internvl2.py +++ b/llmc/models/internvl2.py @@ -184,18 +184,31 @@ def batch_process(self, img_qas): } return inputs - def single_process(self, img_qas): + def single_process(self, img_qa): tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True) - if img_qas['img'] is not None: - pixel_values = load_image(img_qas['img'], max_num=12).to( - next(self.vlm_model.parameters()).dtype - ) + num_patches_list = None + pixel_values_list = [] + if img_qa['img'] is not None: + if isinstance(img_qa['img'], list): + num_patches_list = [] + for img_idx in range(len(img_qa['img'])): + pixel_values = load_image(img_qa['img'][img_idx], max_num=12).to( + next(self.vlm_model.parameters()).dtype + ) + pixel_values_list.append(pixel_values) + num_patches_list.append(pixel_values.size(0)) + pixel_values = torch.cat(pixel_values_list, dim=0) + else: + pixel_values = load_image(img_qa['img'], max_num=12).to( + next(self.vlm_model.parameters()).dtype + ) else: pixel_values = None - question = img_qas['question'] + question = img_qa['question'] if pixel_values is not None and '' not in question: question = '\n' + question - num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else [] + if num_patches_list is None: + num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else [] generation_config = dict() IMG_CONTEXT_TOKEN = ''