@@ -1,3 +1,4 @@
+import os
 import types
 from datetime import timedelta
 from typing import Optional, Union
@@ -9,6 +10,7 @@
 from lmms_eval.models.llava import Llava as LLaVA
 from loguru import logger
 from packaging import version
+from PIL import Image
 from transformers import AutoConfig, AutoTokenizer
 
 from llmc.utils.registry_factory import MODEL_REGISTRY
@@ -17,8 +19,11 @@
 
 try:
     from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
-                                 DEFAULT_IMAGE_PATCH_TOKEN, IMAGE_TOKEN_INDEX)
-    from llava.mm_utils import get_model_name_from_path
+                                 DEFAULT_IMAGE_PATCH_TOKEN,
+                                 DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
+    from llava.conversation import SeparatorStyle, conv_templates
+    from llava.mm_utils import (get_model_name_from_path, process_images,
+                                tokenizer_image_token)
     from llava.model.builder import load_pretrained_model
     from llava.model.language_model.llava_llama import LlavaConfig
 except Exception as e:
@@ -45,7 +50,7 @@ def build_model(self):
         self.vlm_model_config.use_cache = True
         logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
 
-        self.tokenizer, self.vlm_model, image_processor, context_len = load_pretrained_model(
+        self.tokenizer, self.vlm_model, self.image_processor, context_len = load_pretrained_model(
             self.model_path,
             None,
             get_model_name_from_path(self.model_path),
@@ -137,6 +142,96 @@ def get_subsets_in_block(self, block):
         else:
             raise Exception(f'Llava do not support {self.get_modality()} modality.')
 
+    def eval_custom_samples_just_infer(
+        self,
+        img_qas,
+        eval_cfg
+    ):  # noqa
+
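+        # Note: list.copy() is shallow, so answer lists are attached to the same dicts that arrived in img_qas.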
+        custom_samples_ans = img_qas.copy()
+
+        self.vlm_model.cuda()
+
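+        # Small helpers to read image files from disk as RGB PIL images.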
+        def load_image(image_file):
+            image = Image.open(image_file).convert('RGB')
+            return image
+
+        def load_images(image_files):
+            out = []
+            for image_file in image_files:
+                image = load_image(image_file)
+                out.append(image)
+            return out
+
+        self.first_turn_question = True
+
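+        # Each sample pairs one or more images with a list of questions forming a multi-turn dialogue.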
+        for data_idx, questions in enumerate(img_qas):
+            self.first_turn_question = True
+
+            custom_samples_ans[data_idx]['answer'] = []
+
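+            # Resolve image paths under <eval_cfg.path>/images and preprocess them into an fp16 tensor batch.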
+            image_files = questions['image']
+            image_files = [os.path.join(eval_cfg.path, 'images', image_file) for image_file in image_files]  # noqa
+            images = load_images(image_files)
+            image_sizes = [x.size for x in images]
+            images_tensor = process_images(
+                images,
+                self.image_processor,
+                self.vlm_model.config
+            ).to(self.vlm_model.device, dtype=torch.float16)
+
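+            # No token history yet: the first turn of every sample starts from a fresh context.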
+            input_ids_old = None
+
+            for question_idx, question in enumerate(questions['question']):
+
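+                # Build the LLaVA-v1 conversation prompt; only the first turn carries the image token and the system prompt.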
+                conv_mode = 'llava_v1'
+                conv = conv_templates[conv_mode].copy()
+                if question_idx > 0:
+                    conv.system = ''
+                    qs = question
+                    self.first_turn_question = False
+                else:
+                    qs = DEFAULT_IMAGE_TOKEN + '\n' + question
+                conv.append_message(conv.roles[0], qs)
+                conv.append_message(conv.roles[1], None)
+                prompt = conv.get_prompt()
+
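+                # Tokenize the prompt, replacing the image placeholder with IMAGE_TOKEN_INDEX.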
+                input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()  # noqa
+                # print(f"input_ids 1: {input_ids}, {input_ids.shape}")
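+                # For later turns, prepend the accumulated history (previous prompts and answers).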
+                if input_ids_old is not None:
+                    input_ids = torch.cat((input_ids_old, input_ids), dim=1)
+                # print(f"input_ids 2: {input_ids}, {input_ids.shape}")
+
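+                # Greedy decoding (do_sample=False, num_beams=1) keeps the evaluation deterministic.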
+                with torch.inference_mode():
+                    output_ids = self.vlm_model.generate(
+                        input_ids,
+                        attention_mask=input_ids.new_ones(input_ids.shape, dtype=torch.bool),
+                        images=images_tensor,
+                        image_sizes=image_sizes,
+                        do_sample=False,
+                        top_p=None,
+                        num_beams=1,
+                        max_new_tokens=eval_cfg.max_new_tokens,
+                        use_cache=True,
+                    )
+
+                # print(f"output_ids: {output_ids}, {output_ids.shape}")
+
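+                # Decode the generated ids back into text.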
+                outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+
+                print('--------------------------------')
+                print(f'data_idx: {data_idx}')
+                print(f'question_idx: {question_idx}')
+                print(f'question: {question}')
+                print(f'outputs: {outputs}')
+                print('--------------------------------')
+
+                custom_samples_ans[data_idx]['answer'].append(outputs[0])
+
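+                # Carry this turn's prompt and answer forward as context for the next question.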
+                input_ids_old = torch.cat((input_ids, output_ids), dim=1)
+
+        return custom_samples_ans
+
 
 if version.parse(torch.__version__) >= version.parse('2.1.2'):
     best_fit_attn_implementation = 'sdpa'