@@ -129,21 +129,25 @@ def build_model(self):
129129 self .vlm_model .img_context_token_id = tokenizer .convert_tokens_to_ids (IMG_CONTEXT_TOKEN )
130130
131131 def batch_process (self , img_qas ):
132- if len (img_qas ) == 1 :
133- return self .single_process (img_qas [0 ])
134132 tokenizer = AutoTokenizer .from_pretrained (self .model_path , trust_remote_code = True )
135133 questions = []
136134 pixel_values_list = []
137135 num_patches_list = []
138136 for idx in range (len (img_qas )):
139137 img_path = img_qas [idx ]['img' ]
140- pixel_values = load_image (img_path , max_num = 12 ).to (
141- next (self .vlm_model .parameters ()).dtype
142- )
143- pixel_values_list .append (pixel_values )
144- num_patches_list .append (pixel_values .size (0 ))
145- questions .append (f"<image>\n { img_qas [idx ]['question' ]} " )
146- pixel_values = torch .cat (pixel_values_list , dim = 0 )
138+ _num_patches_i = []
139+ if img_path is not None :
140+ if not isinstance (img_path , list ):
141+ img_path = [img_path ]
142+ for img_idx in range (len (img_path )):
143+ pixel_values = load_image (img_path [img_idx ], max_num = 12 ).to (
144+ next (self .vlm_model .parameters ()).dtype
145+ )
146+ pixel_values_list .append (pixel_values )
147+ _num_patches_i .append (pixel_values .size (0 ))
148+ num_patches_list .append (_num_patches_i )
149+ questions .append (img_qas [idx ]['question' ])
150+ pixel_values = torch .cat (pixel_values_list , dim = 0 ) if len (pixel_values_list ) > 0 else None
147151 generation_config = dict ()
148152
149153 IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
@@ -163,12 +167,10 @@ def batch_process(self, img_qas):
163167 template .append_message (template .roles [0 ], question )
164168 template .append_message (template .roles [1 ], None )
165169 query = template .get_prompt ()
166- image_tokens = (IMG_START_TOKEN +
167- IMG_CONTEXT_TOKEN * self .vlm_model .num_image_token * num_patches +
168- IMG_END_TOKEN )
169- query = query .replace ('<image>' , image_tokens , 1 )
170+ for _num_patches_i in num_patches :
171+ image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self .vlm_model .num_image_token * _num_patches_i + IMG_END_TOKEN # noqa
172+ query = query .replace ('<image>' , image_tokens , 1 )
170173 queries .append (query )
171-
172174 tokenizer .padding_side = 'left'
173175 model_inputs = tokenizer (queries , return_tensors = 'pt' , padding = True )
174176 input_ids = model_inputs ['input_ids' ]
@@ -183,64 +185,3 @@ def batch_process(self, img_qas):
183185 ** generation_config
184186 }
185187 return inputs
186-
187- def single_process (self , img_qa ):
188- tokenizer = AutoTokenizer .from_pretrained (self .model_path , trust_remote_code = True )
189- num_patches_list = None
190- pixel_values_list = []
191- if img_qa ['img' ] is not None :
192- if isinstance (img_qa ['img' ], list ):
193- num_patches_list = []
194- for img_idx in range (len (img_qa ['img' ])):
195- pixel_values = load_image (img_qa ['img' ][img_idx ], max_num = 12 ).to (
196- next (self .vlm_model .parameters ()).dtype
197- )
198- pixel_values_list .append (pixel_values )
199- num_patches_list .append (pixel_values .size (0 ))
200- pixel_values = torch .cat (pixel_values_list , dim = 0 )
201- else :
202- pixel_values = load_image (img_qa ['img' ], max_num = 12 ).to (
203- next (self .vlm_model .parameters ()).dtype
204- )
205- else :
206- pixel_values = None
207- question = img_qa ['question' ]
208- if pixel_values is not None and '<image>' not in question :
209- question = '<image>\n ' + question
210- if num_patches_list is None :
211- num_patches_list = [pixel_values .shape [0 ]] if pixel_values is not None else []
212- generation_config = dict ()
213-
214- IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
215- IMG_START_TOKEN = '<img>'
216- IMG_END_TOKEN = '</img>'
217- try :
218- template = get_conv_template (self .vlm_model .template )
219- except Exception :
220- raise Exception (
221- 'InternLM2 conversation.py not be found. '
222- 'Please copy it from model path to llmc/models.'
223- )
224- template .system_message = self .vlm_model .system_message
225- template .append_message (template .roles [0 ], question )
226- template .append_message (template .roles [1 ], None )
227- query = template .get_prompt ()
228- for num_patches in num_patches_list :
229- image_tokens = (IMG_START_TOKEN +
230- IMG_CONTEXT_TOKEN * self .vlm_model .num_image_token * num_patches +
231- IMG_END_TOKEN )
232- query = query .replace ('<image>' , image_tokens , 1 )
233-
234- model_inputs = tokenizer (query , return_tensors = 'pt' )
235- input_ids = model_inputs ['input_ids' ]
236- attention_mask = model_inputs ['attention_mask' ]
237- eos_token_id = tokenizer .convert_tokens_to_ids (template .sep )
238- generation_config ['eos_token_id' ] = eos_token_id
239-
240- inputs = {
241- 'pixel_values' : pixel_values ,
242- 'input_ids' : input_ids ,
243- 'attention_mask' : attention_mask ,
244- ** generation_config
245- }
246- return inputs
0 commit comments