
Commit 671c271

support llava-lht(test:quantize-awq, MME-pretrain) (#379)
1 parent de25b60

File tree

  llmc/models/__init__.py
  llmc/models/llava_lht.py
  requirements/runtime.txt

3 files changed: +303 −2 lines

llmc/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
 from .internvl2 import InternVL2
 from .llama import Llama
 from .llava import Llava
+from .llava_lht import LlavaLHT
 from .minicpm import MiniCPM
 from .minicpmv import MiniCPMV
 from .mistral import Mistral
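
For context, this one-line import is what exposes the new model to llmc's registry: importing llmc.models runs the @MODEL_REGISTRY decorator in llava_lht.py. A minimal lookup sketch, assuming the registry supports dict-style access by class name (an assumption; only the decorator itself appears in this commit):

import llmc.models  # noqa: F401  importing the package runs the @MODEL_REGISTRY decorators
from llmc.utils.registry_factory import MODEL_REGISTRY

# assumed dict-style lookup by class name; llmc would normally resolve this from its config
model_cls = MODEL_REGISTRY['LlavaLHT']
print(model_cls)  # expected: <class 'llmc.models.llava_lht.LlavaLHT'>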

llmc/models/llava_lht.py

Lines changed: 300 additions & 0 deletions
@@ -0,0 +1,300 @@
import types
from datetime import timedelta
from typing import Optional, Union

import torch
from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
from accelerate.state import AcceleratorState
from lmms_eval.api.model import lmms
from lmms_eval.models.llava import Llava
from loguru import logger
from packaging import version
from transformers import AutoConfig, AutoTokenizer

from llmc.utils.registry_factory import MODEL_REGISTRY

from .llama import Llama

try:
    from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
                                 DEFAULT_IMAGE_PATCH_TOKEN)
    from llava.mm_utils import get_model_name_from_path
    from llava.model.builder import load_pretrained_model
    from llava.model.language_model.llava_llama import LlavaConfig
except Exception as e:
    logger.debug('LLaVA is not installed. Please install LLaVA to use this model.\nError: %s' % e)


@MODEL_REGISTRY
class LlavaLHT(Llama):
    def __init__(self, config, device_map=None, use_cache=False):
        super().__init__(config, device_map, use_cache)

    def build_tokenizer(self):
        pass

    def build_model(self):
        self.llava_llama_config = LlavaConfig.from_pretrained(
            self.model_path, trust_remote_code=True
        )
        self.vlm_model_config = AutoConfig.from_pretrained(
            self.model_path, trust_remote_code=True
        )
        if not self.use_cache:
            self.llava_llama_config.use_cache = False
            self.vlm_model_config.use_cache = False
        logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
        self.tokenizer, self.vlm_model, self.image_processor, context_len = load_pretrained_model(
            self.model_path,
            None,
            get_model_name_from_path(self.model_path),
            load_8bit=False,
            load_4bit=False,
            torch_dtype=self.torch_dtype,
            device='cpu',
            config=self.llava_llama_config,
        )
        # llava-lht forward does not support "cache_position"
        ori_forward = self.vlm_model.forward

        def safe_forward(*args, **kwargs):
            kwargs['use_cache'] = False
            kwargs.pop('cache_position', None)
            return ori_forward(*args, **kwargs)
        self.vlm_model.forward = safe_forward

        # llava-lht generate uses "inputs" instead of "input_ids"
        ori_generate = self.vlm_model.generate

        def safe_generate(*args, **kwargs):
            if 'input_ids' in kwargs:
                kwargs['inputs'] = kwargs.pop('input_ids')
            return ori_generate(*args, **kwargs)
        self.vlm_model.generate = safe_generate

        # "attention_mask" is passed via kwargs rather than as an explicit keyword argument.
        ori_prepare_inputs_for_generation = self.vlm_model.prepare_inputs_for_generation

        def safe_prepare_inputs_for_generation(
                self, input_ids, past_key_values=None,
                inputs_embeds=None, attention_mask=None, **kwargs):
            if attention_mask is not None:
                kwargs['attention_mask'] = attention_mask
            return ori_prepare_inputs_for_generation(
                input_ids, past_key_values, inputs_embeds, **kwargs)
        self.vlm_model.prepare_inputs_for_generation = types.MethodType(
            safe_prepare_inputs_for_generation, self.vlm_model
        )

        self.eval_name = 'LlavaLHTEval'
        self.mm_model = self.vlm_model
        logger.info(f'self.vlm_model : {self.vlm_model}')
        self.vision_model = self.vlm_model.get_vision_tower()
        self.vision_projector = self.vlm_model.model.mm_projector
        # Llava-lht merges the language model with the vision projector and vision model
        self.model = self.vlm_model
        self.model_config = self.vlm_model_config.text_config
        self.pruning_config = {
            'image_token_start_index': 5,
            'image_token_length': self.vlm_model_config.image_seq_length,
            'select_layer': self.vlm_model_config.vision_feature_layer,
            'select_feature': self.vlm_model_config.vision_feature_select_strategy,
            'image_token_index': self.vlm_model_config.image_token_index
        }
        self.processor = None

    def get_extra_rot_module_besides_embed_layers(self):
        return [self.vision_projector[2]]

    def find_blocks(self):
        if self.get_modality() == 'language':
            super().find_blocks()
        elif self.get_modality() == 'vision':
            self.blocks = self.vision_model.vision_tower.vision_model.encoder.layers
        else:
            raise Exception(f'Llava does not support {self.get_modality()} modality.')

    def get_layernorms_in_block(self, block):
        if self.get_modality() == 'language':
            return super().get_layernorms_in_block(block)
        elif self.get_modality() == 'vision':
            return {
                'layer_norm1': block.layer_norm1,
                'layer_norm2': block.layer_norm2,
            }
        else:
            raise Exception(f'Llava does not support {self.get_modality()} modality.')

    def get_subsets_in_block(self, block):
        if self.get_modality() == 'language':
            return super().get_subsets_in_block(block)
        elif self.get_modality() == 'vision':
            return [
                {
                    'layers': {
                        'self_attn.q_proj': block.self_attn.q_proj,
                        'self_attn.k_proj': block.self_attn.k_proj,
                        'self_attn.v_proj': block.self_attn.v_proj,
                    },
                    'prev_op': [block.layer_norm1],
                    'input': ['self_attn.q_proj'],
                    'inspect': block.self_attn,
                    'has_kwargs': True,
                },
                {
                    'layers': {'self_attn.out_proj': block.self_attn.out_proj},
                    'prev_op': [block.self_attn.v_proj],
                    'input': ['self_attn.out_proj'],
                    'inspect': block.self_attn.out_proj,
                    'has_kwargs': False,
                },
                {
                    'layers': {'mlp.fc1': block.mlp.fc1},
                    'prev_op': [block.layer_norm2],
                    'input': ['mlp.fc1'],
                    'inspect': block.mlp.fc1,
                    'has_kwargs': False,
                    'is_mlp': True,
                },
                {
                    'layers': {'mlp.fc2': block.mlp.fc2},
                    'prev_op': [block.mlp.fc1],
                    'input': ['mlp.fc2'],
                    'inspect': block.mlp.fc2,
                    'has_kwargs': False,
                    'is_mlp': True,
                    'do_trans': False
                },
            ]
        else:
            raise Exception(f'Llava does not support {self.get_modality()} modality.')


if version.parse(torch.__version__) >= version.parse('2.1.2'):
    best_fit_attn_implementation = 'sdpa'
else:
    best_fit_attn_implementation = 'eager'


@MODEL_REGISTRY
class LlavaLHTEval(Llava):
    def __init__(
        self,
        llmc_model,
        pretrained: str = 'liuhaotian/llava-v1.5-7b',
        truncation: Optional[bool] = True,
        device: Optional[str] = 'cuda',
        batch_size: Optional[Union[int, str]] = 1,
        model_name=None,
        attn_implementation=best_fit_attn_implementation,
        device_map: str = '',
        conv_template='vicuna_v1',
        use_cache: bool = False,
        tie_weights: bool = True,
        truncate_context=False,  # set it False for LLaVA-1.6
        customized_config=None,  # ends in json
        **kwargs,
    ) -> None:
        lmms.__init__(self)
        # Do not use kwargs for now
        assert kwargs == {}, f'Unexpected kwargs: {kwargs}'

        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
        self.accelerator = accelerator
        if accelerator.num_processes > 1:
            self._device = torch.device(f'cuda:{accelerator.local_process_index}')
            self.device_map = f'cuda:{accelerator.local_process_index}'
        elif accelerator.num_processes == 1 and device_map == 'auto':
            self._device = torch.device(device)
            self.device_map = device_map
        else:
            self._device = torch.device(f'cuda:{accelerator.local_process_index}')
            self.device_map = f'cuda:{accelerator.local_process_index}'

        llava_model_args = {
            'multimodal': True,
        }
        if customized_config is not None:
            llava_model_args['customized_config'] = customized_config
        if attn_implementation is not None:
            llava_model_args['attn_implementation'] = attn_implementation
        if 'use_flash_attention_2' in kwargs:
            llava_model_args['use_flash_attention_2'] = kwargs['use_flash_attention_2']
        model_name = model_name if model_name is not None else get_model_name_from_path(pretrained)
        self._model = llmc_model.cuda()
        self._config = self._model.config
        self._tokenizer = AutoTokenizer.from_pretrained(pretrained, use_fast=False)
        self._image_processor = None
        if 'llava' in model_name.lower():
            mm_use_im_start_end = getattr(self._config, 'mm_use_im_start_end', False)
            mm_use_im_patch_token = getattr(self._config, 'mm_use_im_patch_token', True)
            if mm_use_im_patch_token:
                self._tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            if mm_use_im_start_end:
                self._tokenizer.add_tokens(
                    [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN],
                    special_tokens=True
                )
            self._image_processor = self._model.get_vision_tower().image_processor
        if hasattr(self._config, 'max_sequence_length'):
            self._max_length = self._config.max_sequence_length
        else:
            self._max_length = 2048
        self.model.eval()
        if tie_weights:
            self.model.tie_weights()

        self.truncation = truncation
        self.batch_size_per_gpu = int(batch_size)
        self.conv_template = conv_template
        self.use_cache = use_cache
        self.truncate_context = truncate_context
        # assert self.batch_size_per_gpu == 1, (
        #     "Llava currently does not support batched generation. "
        #     "See: https://github.com/haotian-liu/LLaVA/issues/754. "
        #     "HF Llava also has this issue."
        # )
        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [
                DistributedType.FSDP,
                DistributedType.MULTI_GPU,
                DistributedType.DEEPSPEED], (
                'Unsupported distributed type provided. '
                'Only DDP and FSDP are supported.')
            # To use DistributedType.DEEPSPEED, run `accelerate config` first.
            # You must select zero stage 0 (equivalent to DDP) for model preparation to work.
            # Attempts to support zero stage 2 via kwargs failed.
            if accelerator.distributed_type == DistributedType.DEEPSPEED:
                kwargs = {
                    'train_micro_batch_size_per_gpu': self.batch_size_per_gpu,
                    'train_batch_size': self.batch_size_per_gpu * accelerator.num_processes,
                }
                AcceleratorState().deepspeed_plugin.deepspeed_config_process(
                    must_match=True, **kwargs
                )
                logger.info(
                    'Detected that you are using DistributedType.DEEPSPEED. '
                    'Make sure you run `accelerate config` and set zero stage to 0'
                )

            if (
                accelerator.distributed_type == DistributedType.FSDP
                or accelerator.distributed_type == DistributedType.DEEPSPEED
            ):
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                logger.info(f'Using {accelerator.num_processes} devices with data parallelism')
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        elif accelerator.num_processes == 1 and device_map == 'auto':
            logger.info(f'Using {accelerator.num_processes} devices with tensor parallelism')
            self._rank = 0
            self._world_size = 1
        else:
            logger.info(f'Using single device: {self._device}')
            self.model.to(self._device)
            self._rank = 0
            self._world_size = 1
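
The three safe_* wrappers in build_model share one pattern: keep a reference to the original bound method, drop or rename the arguments the llava-lht implementation does not accept, then delegate. A standalone sketch of that pattern, using a hypothetical Model class rather than the real LLaVA model:

class Model:
    # hypothetical stand-in for the wrapped vlm_model; not part of the commit
    def forward(self, input_ids, use_cache=True):
        return f'forward(input_ids={input_ids}, use_cache={use_cache})'


model = Model()
ori_forward = model.forward  # keep the original bound method in a closure


def safe_forward(*args, **kwargs):
    # strip arguments the wrapped implementation does not accept, then delegate
    kwargs.pop('cache_position', None)
    kwargs['use_cache'] = False
    return ori_forward(*args, **kwargs)


# plain attribute assignment is enough here because the closure already holds the
# bound method; types.MethodType is only needed when the replacement takes `self`
# explicitly, as safe_prepare_inputs_for_generation does in the file above
model.forward = safe_forward

print(model.forward(input_ids=[1, 2, 3], cache_position=7))
# forward(input_ids=[1, 2, 3], use_cache=False)

The same closure trick is what lets safe_generate rename input_ids to inputs without modifying the underlying model class.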

requirements/runtime.txt

Lines changed: 2 additions & 2 deletions
@@ -1,9 +1,10 @@
-torch>=2.1.0
+torch>=2.2.0
 torchvision
 timm
 pillow
 loguru
 transformers>=4.45.2
+lmms-eval
 huggingface-hub
 sentencepiece
 protobuf
@@ -31,6 +32,5 @@ qwen-vl-utils
 tiktoken
 librosa
 human_eval
-lmms-eval
 imageio
 diffusers
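
One consequence of the raised torch floor: llava_lht.py selects its attention backend from the installed torch version, and torch>=2.2.0 always satisfies the >=2.1.2 gate, so best_fit_attn_implementation resolves to 'sdpa' on a compliant install. A small check mirroring that logic (the print is illustrative, not part of the commit):

import torch
from packaging import version

# same gate as in llmc/models/llava_lht.py; with torch>=2.2.0 pinned in
# requirements/runtime.txt this always resolves to 'sdpa'
if version.parse(torch.__version__) >= version.parse('2.1.2'):
    best_fit_attn_implementation = 'sdpa'
else:
    best_fit_attn_implementation = 'eager'

print(torch.__version__, best_fit_attn_implementation)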
