diff --git a/llmc/models/internvl2.py b/llmc/models/internvl2.py
index 766632e54..7269c1056 100644
--- a/llmc/models/internvl2.py
+++ b/llmc/models/internvl2.py
@@ -124,6 +124,19 @@ def build_model(self):
         IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
         self.vlm_model.img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)  # noqa
 
+        self.default_image_prompt_template = {
+            'single': '<image>\n',
+            'multiple': 'Image-<|idx|>: <image>\n'
+        }
+        logger.warning(
+            f'InternVL2 default_image_prompt_template: {self.default_image_prompt_template}'
+        )
+        logger.warning(
+            'The default template follows https://huggingface.co/OpenGVLab/InternVL2-2B. '
+            'If you want a custom template, you can change it. '
+            'Alternatively, you can put the <image> tokens into your calib dataset yourself.'
+        )
+
     def batch_process(self, img_qas):
         questions = []
         pixel_values_list = []
@@ -141,8 +154,22 @@ def batch_process(self, img_qas):
                 pixel_values_list.append(pixel_values)
                 _num_patches_i.append(pixel_values.size(0))
             num_patches_list.append(_num_patches_i)
+            if img_path is not None:
+                if img_qas[idx]['question'].count('<image>') == 0:
+                    prefix = ''
+                    if len(img_path) == 1:
+                        prefix = self.default_image_prompt_template['single']
+                    else:
+                        for n in range(len(img_path)):
+                            prefix = prefix + self.default_image_prompt_template['multiple'].replace('<|idx|>', f'{n + 1}')  # noqa
+                    img_qas[idx]['question'] = prefix + img_qas[idx]['question']
+                else:
+                    assert img_qas[idx]['question'].count('<image>') == len(img_path), f"{img_qas[idx]['img']}: the number of <image> tags must match the number of images."  # noqa
             questions.append(img_qas[idx]['question'])
-        pixel_values = torch.cat(pixel_values_list, dim=0) if len(pixel_values_list) > 0 else None
+
+        pixel_values = (
+            torch.cat(pixel_values_list, dim=0) if len(pixel_values_list) > 0 else None
+        )
 
         generation_config = dict()
         IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
diff --git a/llmc/models/qwen2vl.py b/llmc/models/qwen2vl.py
index ca71f92c2..bb6e4e263 100644
--- a/llmc/models/qwen2vl.py
+++ b/llmc/models/qwen2vl.py
@@ -52,8 +52,8 @@ def build_model(self):
         self.max_pixels = 1280 * 28 * 28
         logger.warning(f'min_pixels is set to: {self.min_pixels}')
         logger.warning(f'max_pixels is set to: {self.max_pixels}')
-        logger.warning('You can refer the link https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct '
-                       'to get more info of image Resolution for performance boost.')
+        logger.warning('You can refer to https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct '
+                       'for more information on image resolution and its effect on performance.')
         self.processor = AutoProcessor.from_pretrained(
             self.model_path,
             min_pixels=self.min_pixels,
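For reference, the prompt-prefixing rule that the internvl2.py hunk adds can be exercised in isolation. The sketch below is a minimal, self-contained rendering of that logic, assuming the reconstructed `<image>` placeholders above; the helper name `apply_image_prompt_template` is hypothetical and not part of llmc.

default_image_prompt_template = {
    'single': '<image>\n',
    'multiple': 'Image-<|idx|>: <image>\n',
}


def apply_image_prompt_template(question, img_paths):
    # No explicit <image> tags in the prompt: synthesize a prefix, one tag per image.
    if question.count('<image>') == 0:
        if len(img_paths) == 1:
            prefix = default_image_prompt_template['single']
        else:
            prefix = ''.join(
                default_image_prompt_template['multiple'].replace('<|idx|>', f'{n + 1}')
                for n in range(len(img_paths))
            )
        return prefix + question
    # Tags already present: their count must match the number of images.
    assert question.count('<image>') == len(img_paths), 'prompt/image count mismatch'
    return question


# Two images, no explicit tags -> a numbered prefix is synthesized:
print(apply_image_prompt_template('Describe both images.', ['a.jpg', 'b.jpg']))
# Image-1: <image>
# Image-2: <image>
# Describe both images.

Single-image prompts get a bare '<image>\n' prefix, while multi-image prompts get numbered 'Image-N: <image>' lines; prompts that already contain '<image>' tags are instead validated against the image count, mirroring the assert in the diff.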