Skip to content

Commit ee61b0f

Browse files
update llava model (#208)
1 parent cd1668d commit ee61b0f

File tree

1 file changed

+4
-26
lines changed

1 file changed

+4
-26
lines changed

llmc/models/llava.py

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,9 @@ def build_model(self):
3131
self.model = self.vlm_model.language_model
3232
self.model_config = self.vlm_model_config.text_config
3333

34+
self.processor = AutoProcessor.from_pretrained(self.model_path)
35+
3436
def batch_process(self, img_qas):
35-
if len(img_qas) == 1:
36-
return self.single_process(img_qas[0])
37-
processor = AutoProcessor.from_pretrained(self.model_path)
3837
messages = []
3938
images = []
4039
for idx in range(len(img_qas)):
@@ -52,38 +51,17 @@ def batch_process(self, img_qas):
5251
messages.append(message)
5352
images.append(image)
5453
texts = [
55-
processor.apply_chat_template(msg, add_generation_prompt=True)
54+
self.processor.apply_chat_template(msg, add_generation_prompt=True)
5655
for msg in messages
5756
]
58-
inputs = processor(
57+
inputs = self.processor(
5958
text=texts,
6059
images=images,
6160
padding=True,
6261
return_tensors='pt'
6362
).to(next(self.vlm_model.parameters()).dtype) # noqa
6463
return inputs
6564

66-
def single_process(self, img_qas):
67-
processor = AutoProcessor.from_pretrained(self.model_path)
68-
img_path = img_qas['img']
69-
image = Image.open(img_path) if img_path is not None else None
70-
message = [
71-
{
72-
'role': 'user',
73-
'content': [{'type': 'text', 'text': img_qas['question']}]
74-
}
75-
]
76-
if img_path is not None:
77-
message[0]['content'].insert(0, {'type': 'image'})
78-
text = processor.apply_chat_template(message, add_generation_prompt=True)
79-
inputs = processor(
80-
text=text,
81-
images=image,
82-
padding=True,
83-
return_tensors='pt'
84-
).to(next(self.vlm_model.parameters()).dtype) # noqa
85-
return inputs
86-
8765
def find_blocks(self, modality='language'):
8866
if modality == 'language':
8967
self.blocks = self.model.model.layers

0 commit comments

Comments
 (0)