
Commit d9cd424 (1 parent: f376d59)

support vila and minicpmv (#311)

Signed-off-by: pengchao.hu <pengchao.hu@sophgo.com>
File tree: 3 files changed, +708 lines, -0 lines

llmc/models/__init__.py (2 additions, 0 deletions)

@@ -11,6 +11,7 @@
 from .llama import Llama
 from .llava import Llava
 from .minicpm import MiniCPM
+from .minicpmv import MiniCPMV
 from .mistral import Mistral
 from .mixtral import Mixtral
 from .mllama import Mllama
@@ -25,4 +26,5 @@
 from .smollm import SmolLM
 from .stablelm import StableLm
 from .starcoder import Starcoder
+from .vila import Vila
 from .vit import Vit
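The two added imports are what make the new models visible: importing minicpmv and vila executes their @MODEL_REGISTRY decorators, so each class can later be looked up by name from a config. A minimal, self-contained sketch of that registration pattern follows; the Registry class here is illustrative, not the actual llmc.utils.registry_factory implementation.

class Registry:
    """Toy name -> class registry illustrating the decorator pattern."""

    def __init__(self):
        self._classes = {}

    def __call__(self, cls):
        # Used as a bare class decorator, e.g. @MODEL_REGISTRY.
        self._classes[cls.__name__] = cls
        return cls

    def __getitem__(self, name):
        return self._classes[name]


MODEL_REGISTRY = Registry()


@MODEL_REGISTRY
class MiniCPMV:  # stand-in for the real model class
    pass


# A model name taken from a config can now be resolved to its class:
print(MODEL_REGISTRY['MiniCPMV'])  # <class '__main__.MiniCPMV'>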

llmc/models/minicpmv.py (new file, 269 additions, 0 deletions)

@@ -0,0 +1,269 @@
from typing import List, Optional, Tuple, Union

import torch
from accelerate import Accelerator, DistributedType
from accelerate.state import AcceleratorState
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from loguru import logger
from PIL import Image
from tqdm import tqdm
from transformers import (AutoConfig, AutoModelForCausalLM, AutoProcessor,
                          AutoTokenizer)

from llmc.utils.registry_factory import MODEL_REGISTRY

from .minicpm import MiniCPM

@MODEL_REGISTRY
class MiniCPMV(MiniCPM):

    def __init__(self, config, device_map=None, use_cache=False):
        super().__init__(config, device_map, use_cache)

    def build_model(self):
        self.eval_name = 'MiniCPMVEval'
        self.vlm_model_config = AutoConfig.from_pretrained(
            self.model_path, trust_remote_code=True)
        logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
        self.vlm_model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            config=self.vlm_model_config,
            trust_remote_code=True,
            torch_dtype='auto',
            low_cpu_mem_usage=True,
        )
        self.mm_model = self.vlm_model
        self.vlm_model_config = self.vlm_model.config
        if not self.use_cache:
            if hasattr(self.vlm_model_config, 'use_cache'):
                self.vlm_model_config.use_cache = False
        logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
        self.mm_model = self.vlm_model
        logger.info(f'self.vlm_model : {self.vlm_model}')
        self.vision_model = self.vlm_model.vpm
        self.model = self.vlm_model.llm
        self.model_config = self.vlm_model_config
        self.processor = AutoProcessor.from_pretrained(self.model_path,
                                                       trust_remote_code=True)
        self.max_slice_nums = self.processor.image_processor.max_slice_nums
        self.max_length = 4096

    def batch_process(self,
                      img_qas,
                      calib_or_eval='eval',
                      apply_chat_template=True,
                      return_inputs=True):  # noqa
        assert calib_or_eval == 'calib' or calib_or_eval == 'eval'
        assert apply_chat_template
        add_answer = calib_or_eval == 'calib' and self.config['calib'].get(
            'add_answer', False)
        image_lists = []
        prompt_lists = []
        for idx in range(len(img_qas)):
            img_path = img_qas[idx]['image']
            question = img_qas[idx]['question']
            answer = img_qas[idx]['answer']
            image_lists.append([Image.open(img_path).convert('RGB')])
            if not add_answer:
                msg = [{
                    'role': 'user',
                    'content': '(<image>./</image>)\n' + question
                }]
            else:
                msg = [{
                    'role': 'user',
                    'content': '(<image>./</image>)\n' + question
                }, {
                    'role': 'assistant',
                    'content': answer
                }]
            prompt = self.processor.tokenizer.apply_chat_template(
                msg, tokenize=False, add_generation_prompt=True)
            prompt_lists.append(prompt)
        if not return_inputs:
            return prompt_lists
        inputs = self.processor(
            prompt_lists,
            image_lists,
            max_slice_num=self.max_slice_nums,
            use_image_id=self.model_config.use_image_id,
            return_tensors='pt',
            max_length=self.max_length).to(self.vlm_model.device).to(
                next(self.vlm_model.parameters()).dtype)
        inputs.pop('image_sizes')
        inputs['tokenizer'] = self.processor.tokenizer
        return inputs

    def find_blocks(self):
        assert self.get_modality() == 'language'
        super().find_blocks()

    def get_layernorms_in_block(self, block):
        assert self.get_modality() == 'language'
        return super().get_layernorms_in_block(block)

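For reference, a hedged sketch of the calibration input that the batch_process method above expects; the image path and question/answer strings are placeholders, and the commented call assumes a MiniCPMV instance whose build_model has already run.

# Hypothetical calibration sample; the keys match what batch_process reads.
img_qas = [{
    'image': '/data/calib/0001.jpg',        # placeholder image path
    'question': 'What is shown in the image?',
    'answer': 'A cat sitting on a chair.',  # only read when calib add_answer is set
}]
# With a built MiniCPMV instance (not constructed here), the call would be:
#     inputs = model.batch_process(img_qas, calib_or_eval='calib')
# `inputs` then holds the processed prompts and pixel values plus the tokenizer
# entry added above, ready for the wrapped MiniCPM-V forward pass.
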
@MODEL_REGISTRY
class MiniCPMVEval(lmms):
    """MiniCPM_V Model."""

    def __init__(
        self,
        llmc_model,
        pretrained: str = 'openbmb/MiniCPM-V',
        device: Optional[str] = 'cuda',
        dtype: Optional[Union[str, torch.dtype]] = torch.bfloat16,
        batch_size: Optional[Union[int, str]] = 1,
        trust_remote_code: Optional[bool] = True,
        **kwargs,
    ) -> None:
        lmms.__init__(self)
        assert batch_size == 1, f'Batch size should be 1 for MiniCPMV, but got {batch_size}.'

        accelerator = Accelerator()
        if accelerator.num_processes > 1:
            self._device = torch.device(
                f'cuda:{accelerator.local_process_index}')
        else:
            self._device = device
        self._model = llmc_model.eval().cuda()
        self._tokenizer = AutoTokenizer.from_pretrained(
            pretrained, trust_remote_code=trust_remote_code)
        self._config = self._model.config
        self._max_length = 4096
        self.batch_size_per_gpu = int(batch_size)
        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [
                DistributedType.FSDP, DistributedType.MULTI_GPU,
                DistributedType.DEEPSPEED
            ], 'Unsupported distributed type provided. Only DDP and FSDP are supported.'
            # To use DistributedType.DEEPSPEED, you have to run `accelerate config`
            # before using the model.
            # You also have to select zero stage 0 (equivalent to DDP) so that
            # preparing the model works; setting different kwargs to make the
            # default zero stage 2 work did not succeed.
            if accelerator.distributed_type == DistributedType.DEEPSPEED:
                kwargs = {
                    'train_micro_batch_size_per_gpu':
                    self.batch_size_per_gpu,
                    'train_batch_size':
                    self.batch_size_per_gpu * accelerator.num_processes,
                }
                AcceleratorState().deepspeed_plugin.deepspeed_config_process(
                    must_match=True, **kwargs)
                logger.info(
                    'Detected that you are using DistributedType.DEEPSPEED. Make sure you run '
                    '`accelerate config` and set zero stage to 0'
                )
            if accelerator.distributed_type == DistributedType.FSDP or \
                    accelerator.distributed_type == DistributedType.DEEPSPEED:
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model,
                                                        evaluation_mode=True)
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                logger.info(
                    f'Using {accelerator.num_processes} devices with data parallelism'
                )
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        else:
            self.model.to(self._device)
            self._rank = 0
            self._world_size = 1

    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, 'accelerator'):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def eot_token_id(self):
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        return self._max_length

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def tok_encode(self,
                   string: str,
                   left_truncate_len=None,
                   add_special_tokens=None) -> List[int]:
        add_special_tokens = False if add_special_tokens is None else add_special_tokens
        encoding = self.tokenizer.encode(string,
                                         add_special_tokens=add_special_tokens)
        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]
        return encoding

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def loglikelihood(self,
                      requests: List[Instance]) -> List[Tuple[float, bool]]:
        # TODO
        assert False, 'We have not implemented this function for MiniCPM_V yet'

    def flatten(self, input):
        new_list = []
        for i in input:
            for j in i:
                new_list.append(j)
        return new_list

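The flatten helper simply concatenates the per-document visual lists into one flat list; a toy equivalent with stand-in values:

nested = [['img_a'], ['img_b', 'img_c']]  # stand-in per-document visuals
flat = [j for i in nested for j in i]     # same effect as flatten(nested)
print(flat)                               # ['img_a', 'img_b', 'img_c']
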
    def generate_until(self, requests: List[Instance]) -> List[str]:
        res = []
        pbar = tqdm(total=len(requests),
                    disable=(self.rank != 0),
                    desc='Model Responding')

        for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [
                reg.args for reg in requests
        ]:
            # encode, pad, and truncate contexts for this batch
            visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
            visuals = self.flatten(visuals)
            assert len(visuals) == 1
            msgs = [{'role': 'user', 'content': [visuals[0], contexts]}]
            outputs = self.model.chat(image=None,
                                      msgs=msgs,
                                      tokenizer=self.tokenizer)
            res.append(outputs)
            pbar.update(1)
        return res

    def generate_until_multi_round(self, requests) -> List[str]:
        raise NotImplementedError('TODO: Implement multi-round generation')
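
Finally, a hedged sketch of how the wrapper is typically driven: lmms-eval constructs MiniCPMVEval around the quantized model and then calls generate_until with its Instance objects. `quantized_model` below is a placeholder for the model produced by the llmc pipeline, so the call is shown as a comment rather than executed.

# `quantized_model` stands in for the MiniCPM-V model after llmc quantization.
# The harness would build the wrapper roughly like this:
#     eval_model = MiniCPMVEval(llmc_model=quantized_model,
#                               pretrained='openbmb/MiniCPM-V',
#                               batch_size=1)  # batch_size must be 1 (asserted in __init__)
# and then call eval_model.generate_until(requests) with lmms-eval Instance
# objects whose args unpack to (contexts, gen_kwargs, doc_to_visual, doc_id,
# task, split), as consumed in the loop above.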
