diff --git a/llmc/models/internvl2.py b/llmc/models/internvl2.py
index 7eefd95fb..3defb4d52 100644
--- a/llmc/models/internvl2.py
+++ b/llmc/models/internvl2.py
@@ -129,21 +129,25 @@ def build_model(self):
         self.vlm_model.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
 
     def batch_process(self, img_qas):
-        if len(img_qas) == 1:
-            return self.single_process(img_qas[0])
         tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
         questions = []
         pixel_values_list = []
         num_patches_list = []
         for idx in range(len(img_qas)):
             img_path = img_qas[idx]['img']
-            pixel_values = load_image(img_path, max_num=12).to(
-                next(self.vlm_model.parameters()).dtype
-            )
-            pixel_values_list.append(pixel_values)
-            num_patches_list.append(pixel_values.size(0))
-            questions.append(f"<image>\n{img_qas[idx]['question']}")
-        pixel_values = torch.cat(pixel_values_list, dim=0)
+            _num_patches_i = []
+            if img_path is not None:
+                if not isinstance(img_path, list):
+                    img_path = [img_path]
+                for img_idx in range(len(img_path)):
+                    pixel_values = load_image(img_path[img_idx], max_num=12).to(
+                        next(self.vlm_model.parameters()).dtype
+                    )
+                    pixel_values_list.append(pixel_values)
+                    _num_patches_i.append(pixel_values.size(0))
+            num_patches_list.append(_num_patches_i)
+            questions.append(img_qas[idx]['question'])
+        pixel_values = torch.cat(pixel_values_list, dim=0) if len(pixel_values_list) > 0 else None
         generation_config = dict()
 
         IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
@@ -163,12 +167,10 @@ def batch_process(self, img_qas):
             template.append_message(template.roles[0], question)
             template.append_message(template.roles[1], None)
             query = template.get_prompt()
-            image_tokens = (IMG_START_TOKEN +
-                            IMG_CONTEXT_TOKEN * self.vlm_model.num_image_token * num_patches +
-                            IMG_END_TOKEN)
-            query = query.replace('<image>', image_tokens, 1)
+            for _num_patches_i in num_patches:
+                image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.vlm_model.num_image_token * _num_patches_i + IMG_END_TOKEN  # noqa
+                query = query.replace('<image>', image_tokens, 1)
             queries.append(query)
-
         tokenizer.padding_side = 'left'
         model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
         input_ids = model_inputs['input_ids']
@@ -183,64 +185,3 @@ def batch_process(self, img_qas):
             **generation_config
         }
         return inputs
-
-    def single_process(self, img_qa):
-        tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
-        num_patches_list = None
-        pixel_values_list = []
-        if img_qa['img'] is not None:
-            if isinstance(img_qa['img'], list):
-                num_patches_list = []
-                for img_idx in range(len(img_qa['img'])):
-                    pixel_values = load_image(img_qa['img'][img_idx], max_num=12).to(
-                        next(self.vlm_model.parameters()).dtype
-                    )
-                    pixel_values_list.append(pixel_values)
-                    num_patches_list.append(pixel_values.size(0))
-                pixel_values = torch.cat(pixel_values_list, dim=0)
-            else:
-                pixel_values = load_image(img_qa['img'], max_num=12).to(
-                    next(self.vlm_model.parameters()).dtype
-                )
-        else:
-            pixel_values = None
-        question = img_qa['question']
-        if pixel_values is not None and '<image>' not in question:
-            question = '<image>\n' + question
-        if num_patches_list is None:
-            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
-        generation_config = dict()
-
-        IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
-        IMG_START_TOKEN = '<img>'
-        IMG_END_TOKEN = '</img>'
-        try:
-            template = get_conv_template(self.vlm_model.template)
-        except Exception:
-            raise Exception(
-                'InternLM2 conversation.py not be found. '
-                'Please copy it from model path to llmc/models.'
-            )
-        template.system_message = self.vlm_model.system_message
-        template.append_message(template.roles[0], question)
-        template.append_message(template.roles[1], None)
-        query = template.get_prompt()
-        for num_patches in num_patches_list:
-            image_tokens = (IMG_START_TOKEN +
-                            IMG_CONTEXT_TOKEN * self.vlm_model.num_image_token * num_patches +
-                            IMG_END_TOKEN)
-            query = query.replace('<image>', image_tokens, 1)
-
-        model_inputs = tokenizer(query, return_tensors='pt')
-        input_ids = model_inputs['input_ids']
-        attention_mask = model_inputs['attention_mask']
-        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
-        generation_config['eos_token_id'] = eos_token_id
-
-        inputs = {
-            'pixel_values': pixel_values,
-            'input_ids': input_ids,
-            'attention_mask': attention_mask,
-            **generation_config
-        }
-        return inputs