Commit 14bb456: update internvl2 (#193)
Parent: 0ea65a8

llmc/models/internvl2.py

Lines changed: 16 additions & 75 deletions
@@ -129,21 +129,25 @@ def build_model(self):
         self.vlm_model.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
 
     def batch_process(self, img_qas):
-        if len(img_qas) == 1:
-            return self.single_process(img_qas[0])
         tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
         questions = []
         pixel_values_list = []
         num_patches_list = []
         for idx in range(len(img_qas)):
             img_path = img_qas[idx]['img']
-            pixel_values = load_image(img_path, max_num=12).to(
-                next(self.vlm_model.parameters()).dtype
-            )
-            pixel_values_list.append(pixel_values)
-            num_patches_list.append(pixel_values.size(0))
-            questions.append(f"<image>\n{img_qas[idx]['question']}")
-        pixel_values = torch.cat(pixel_values_list, dim=0)
+            _num_patches_i = []
+            if img_path is not None:
+                if not isinstance(img_path, list):
+                    img_path = [img_path]
+                for img_idx in range(len(img_path)):
+                    pixel_values = load_image(img_path[img_idx], max_num=12).to(
+                        next(self.vlm_model.parameters()).dtype
+                    )
+                    pixel_values_list.append(pixel_values)
+                    _num_patches_i.append(pixel_values.size(0))
+            num_patches_list.append(_num_patches_i)
+            questions.append(img_qas[idx]['question'])
+        pixel_values = torch.cat(pixel_values_list, dim=0) if len(pixel_values_list) > 0 else None
         generation_config = dict()
 
         IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
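
As a reading aid, here is a rough sketch of the per-sample input shapes the reworked loop accepts. The paths and questions below are invented for illustration; note that the '<image>' placeholder must now come from the caller, since the loop no longer prepends it:

# Hypothetical sample input for batch_process (illustrative, not from the repo).
# Each entry's 'img' may be a single path, a list of paths, or None:
img_qas = [
    {'img': 'examples/cat.jpg', 'question': '<image>\nWhat animal is this?'},
    {'img': ['examples/a.jpg', 'examples/b.jpg'],
     'question': '<image>\n<image>\nWhat differs between these two images?'},
    {'img': None, 'question': 'What is post-training quantization?'},
]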
@@ -163,12 +167,10 @@ def batch_process(self, img_qas):
             template.append_message(template.roles[0], question)
             template.append_message(template.roles[1], None)
             query = template.get_prompt()
-            image_tokens = (IMG_START_TOKEN +
-                            IMG_CONTEXT_TOKEN * self.vlm_model.num_image_token * num_patches +
-                            IMG_END_TOKEN)
-            query = query.replace('<image>', image_tokens, 1)
+            for _num_patches_i in num_patches:
+                image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.vlm_model.num_image_token * _num_patches_i + IMG_END_TOKEN  # noqa
+                query = query.replace('<image>', image_tokens, 1)
             queries.append(query)
-
         tokenizer.padding_side = 'left'
         model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
         input_ids = model_inputs['input_ids']
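
The new inner loop expands one '<image>' placeholder per image instead of assuming exactly one image per sample. A minimal standalone sketch of that expansion, with num_image_token and the patch counts assumed for illustration:

# Minimal sketch of the placeholder expansion (num_image_token and the
# patch counts are assumed values, not taken from the model config).
IMG_START_TOKEN, IMG_CONTEXT_TOKEN, IMG_END_TOKEN = '<img>', '<IMG_CONTEXT>', '</img>'
num_image_token = 256
query = '<image>\n<image>\nWhat differs between these two images?'
for num_patches in [7, 5]:  # one patch count per image in this sample
    image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * num_image_token * num_patches + IMG_END_TOKEN
    query = query.replace('<image>', image_tokens, 1)  # fills placeholders left to right

Because str.replace with count 1 always hits the first remaining placeholder, the i-th '<image>' in the prompt is paired with the i-th loaded image. Left padding (tokenizer.padding_side = 'left') then right-aligns the prompts so batched decoder-only generation starts from real tokens rather than padding.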
@@ -183,64 +185,3 @@ def batch_process(self, img_qas):
             **generation_config
         }
         return inputs
-
-    def single_process(self, img_qa):
-        tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
-        num_patches_list = None
-        pixel_values_list = []
-        if img_qa['img'] is not None:
-            if isinstance(img_qa['img'], list):
-                num_patches_list = []
-                for img_idx in range(len(img_qa['img'])):
-                    pixel_values = load_image(img_qa['img'][img_idx], max_num=12).to(
-                        next(self.vlm_model.parameters()).dtype
-                    )
-                    pixel_values_list.append(pixel_values)
-                    num_patches_list.append(pixel_values.size(0))
-                pixel_values = torch.cat(pixel_values_list, dim=0)
-            else:
-                pixel_values = load_image(img_qa['img'], max_num=12).to(
-                    next(self.vlm_model.parameters()).dtype
-                )
-        else:
-            pixel_values = None
-        question = img_qa['question']
-        if pixel_values is not None and '<image>' not in question:
-            question = '<image>\n' + question
-        if num_patches_list is None:
-            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
-        generation_config = dict()
-
-        IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
-        IMG_START_TOKEN = '<img>'
-        IMG_END_TOKEN = '</img>'
-        try:
-            template = get_conv_template(self.vlm_model.template)
-        except Exception:
-            raise Exception(
-                'InternLM2 conversation.py not be found. '
-                'Please copy it from model path to llmc/models.'
-            )
-        template.system_message = self.vlm_model.system_message
-        template.append_message(template.roles[0], question)
-        template.append_message(template.roles[1], None)
-        query = template.get_prompt()
-        for num_patches in num_patches_list:
-            image_tokens = (IMG_START_TOKEN +
-                            IMG_CONTEXT_TOKEN * self.vlm_model.num_image_token * num_patches +
-                            IMG_END_TOKEN)
-            query = query.replace('<image>', image_tokens, 1)
-
-        model_inputs = tokenizer(query, return_tensors='pt')
-        input_ids = model_inputs['input_ids']
-        attention_mask = model_inputs['attention_mask']
-        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
-        generation_config['eos_token_id'] = eos_token_id
-
-        inputs = {
-            'pixel_values': pixel_values,
-            'input_ids': input_ids,
-            'attention_mask': attention_mask,
-            **generation_config
-        }
-        return inputs
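
With single_process removed, the batched path now covers single-sample calls as well. A hedged sketch of how the returned dict might be consumed, assuming the wrapper class in this file is InternVL2 and that the keys map straight onto the underlying VLM's generate call (neither is shown in this diff):

# Hypothetical downstream use (names assumed, not part of this diff):
model = InternVL2(...)  # the llmc model wrapper defined in this file
inputs = model.batch_process(img_qas)
# inputs bundles pixel_values, input_ids, attention_mask and eos_token_id,
# so it can be unpacked into the underlying model's generate():
output_ids = model.vlm_model.generate(**inputs)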
