Skip to content

Commit 9b91523

Browse files
update models (#212)
1 parent d685e91 commit 9b91523

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

llmc/models/internvl2.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,19 @@ def build_model(self):
124124
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
125125
self.vlm_model.img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) # noqa
126126

127+
self.default_image_prompt_template = {
128+
'single': '<image>\n',
129+
'multiple': 'Image-<|idx|>: <image>\n'
130+
}
131+
logger.warning(
132+
f'InternVL2 default_image_prompt_template: {self.default_image_prompt_template}'
133+
)
134+
logger.warning(
135+
'Default template refer to the link https://huggingface.co/OpenGVLab/InternVL2-2B. '
136+
'If you want a custom template, you can change it. '
137+
'Besides, you can also put the <image> into your calib dataset.'
138+
)
139+
127140
def batch_process(self, img_qas):
128141
questions = []
129142
pixel_values_list = []
@@ -141,8 +154,22 @@ def batch_process(self, img_qas):
141154
pixel_values_list.append(pixel_values)
142155
_num_patches_i.append(pixel_values.size(0))
143156
num_patches_list.append(_num_patches_i)
157+
if img_path is not None:
158+
if img_qas[idx]['question'].count('<image>') == 0:
159+
prefix = ''
160+
if len(img_path) == 1:
161+
prefix = self.default_image_prompt_template['single']
162+
else:
163+
for n in range(len(img_path)):
164+
prefix = prefix + self.default_image_prompt_template['multiple'].replace('<|idx|>', f'{n+1}') # noqa
165+
img_qas[idx]['question'] = prefix + img_qas[idx]['question']
166+
else:
167+
assert img_qas[idx]['question'].count('<image>') == len(img_path), f"{img_qas[idx]['img']} this data prompt is wrong." # noqa
144168
questions.append(img_qas[idx]['question'])
145-
pixel_values = torch.cat(pixel_values_list, dim=0) if len(pixel_values_list) > 0 else None
169+
170+
pixel_values = (
171+
torch.cat(pixel_values_list, dim=0) if len(pixel_values_list) > 0 else None
172+
)
146173
generation_config = dict()
147174

148175
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'

llmc/models/qwen2vl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ def build_model(self):
5252
self.max_pixels = 1280 * 28 * 28
5353
logger.warning(f'min_pixels is set to: {self.min_pixels}')
5454
logger.warning(f'max_pixels is set to: {self.max_pixels}')
55-
logger.warning('You can refer the link https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct '
56-
'to get more info of image Resolution for performance boost.')
55+
logger.warning('You can refer to the link https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct '
56+
'to get more info of image resolution for performance boost.')
5757
self.processor = AutoProcessor.from_pretrained(
5858
self.model_path,
5959
min_pixels=self.min_pixels,

0 commit comments

Comments
 (0)