@@ -124,6 +124,19 @@ def build_model(self):
124124 IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
125125 self .vlm_model .img_context_token_id = self .tokenizer .convert_tokens_to_ids (IMG_CONTEXT_TOKEN ) # noqa
126126
127+ self .default_image_prompt_template = {
128+ 'single' : '<image>\n ' ,
129+ 'multiple' : 'Image-<|idx|>: <image>\n '
130+ }
131+ logger .warning (
132+ f'InternVL2 default_image_prompt_template: { self .default_image_prompt_template } '
133+ )
134+ logger .warning (
135+ 'Default template refer to the link https://huggingface.co/OpenGVLab/InternVL2-2B. '
136+ 'If you want a custom template, you can change it. '
137+ 'Besides, you can also put the <image> into your calib dataset.'
138+ )
139+
127140 def batch_process (self , img_qas ):
128141 questions = []
129142 pixel_values_list = []
@@ -141,8 +154,22 @@ def batch_process(self, img_qas):
141154 pixel_values_list .append (pixel_values )
142155 _num_patches_i .append (pixel_values .size (0 ))
143156 num_patches_list .append (_num_patches_i )
157+ if img_path is not None :
158+ if img_qas [idx ]['question' ].count ('<image>' ) == 0 :
159+ prefix = ''
160+ if len (img_path ) == 1 :
161+ prefix = self .default_image_prompt_template ['single' ]
162+ else :
163+ for n in range (len (img_path )):
164+ prefix = prefix + self .default_image_prompt_template ['multiple' ].replace ('<|idx|>' , f'{ n + 1 } ' ) # noqa
165+ img_qas [idx ]['question' ] = prefix + img_qas [idx ]['question' ]
166+ else :
167+ assert img_qas [idx ]['question' ].count ('<image>' ) == len (img_path ), f"{ img_qas [idx ]['img' ]} this data prompt is wrong." # noqa
144168 questions .append (img_qas [idx ]['question' ])
145- pixel_values = torch .cat (pixel_values_list , dim = 0 ) if len (pixel_values_list ) > 0 else None
169+
170+ pixel_values = (
171+ torch .cat (pixel_values_list , dim = 0 ) if len (pixel_values_list ) > 0 else None
172+ )
146173 generation_config = dict ()
147174
148175 IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
0 commit comments