
Commit 4683d0c

llava has been renamed to llava-hf, and llava-lht is now renamed to llava. (#381)
1 parent: 6f3755c

File tree: 4 files changed, +422 -422 lines

llmc/models/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@
 from .internvl2 import InternVL2
 from .llama import Llama
 from .llava import Llava
-from .llava_lht import LlavaLHT
+from .llava_hf import LlavaHf
 from .minicpm import MiniCPM
 from .minicpmv import MiniCPMV
 from .mistral import Mistral
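After this commit the registry key for the official-repo-backed implementation is `Llava`, while the Transformers port is exposed as `LlavaHf`. A minimal sketch of how such a lookup could be wired up; the `MODEL_CLASSES` dict and `build_model` helper below are hypothetical illustrations, not llmc's actual entry point:

# Hypothetical illustration only: llmc resolves model types through its own
# MODEL_REGISTRY; this sketch just shows the effect of the rename on lookups.
from llmc.models import Llava, LlavaHf  # both exported by __init__.py above

MODEL_CLASSES = {'Llava': Llava, 'LlavaHf': LlavaHf}

def build_model(model_type, config, device_map=None, use_cache=False):
    # 'Llava' now maps to the class backed by the official LLaVA repo,
    # 'LlavaHf' to the Hugging Face Transformers port.
    model_cls = MODEL_CLASSES[model_type]
    return model_cls(config, device_map=device_map, use_cache=use_cache)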

llmc/models/llava.py

Lines changed: 158 additions & 117 deletions

@@ -1,119 +1,117 @@
-from typing import List, Optional, Tuple, Union
+import types
+from datetime import timedelta
+from typing import Optional, Union

 import torch
-from accelerate import Accelerator, DistributedType
+from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
 from accelerate.state import AcceleratorState
 from lmms_eval.api.model import lmms
-from lmms_eval.models.llava_hf import LlavaHf
+from lmms_eval.models.llava import Llava as LLaVA
 from loguru import logger
-from PIL import Image
-from transformers import (AutoConfig, AutoProcessor,
-                          LlavaForConditionalGeneration)
+from packaging import version
+from transformers import AutoConfig, AutoTokenizer

 from llmc.utils.registry_factory import MODEL_REGISTRY

 from .llama import Llama

+try:
+    from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
+                                 DEFAULT_IMAGE_PATCH_TOKEN)
+    from llava.mm_utils import get_model_name_from_path
+    from llava.model.builder import load_pretrained_model
+    from llava.model.language_model.llava_llama import LlavaConfig
+except Exception as e:
+    logger.debug('LLaVA is not installed. Please install LLaVA to use this model.\nError: %s' % e)
+

 @MODEL_REGISTRY
 class Llava(Llama):
     def __init__(self, config, device_map=None, use_cache=False):
         super().__init__(config, device_map, use_cache)

+    def build_tokenizer(self):
+        pass
+
     def build_model(self):
+        self.llava_config = LlavaConfig.from_pretrained(
+            self.model_path, trust_remote_code=True
+        )
         self.vlm_model_config = AutoConfig.from_pretrained(
             self.model_path, trust_remote_code=True
         )
         if not self.use_cache:
-            self.vlm_model_config.text_config.use_cache = False
+            self.llava_config.use_cache = False
+            self.vlm_model_config.use_cache = False
         logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
-        self.vlm_model = LlavaForConditionalGeneration.from_pretrained(
+        self.tokenizer, self.vlm_model, image_processor, context_len = load_pretrained_model(
             self.model_path,
-            config=self.vlm_model_config,
+            None,
+            get_model_name_from_path(self.model_path),
+            load_8bit=False,
+            load_4bit=False,
+            device='cpu',
             torch_dtype=self.torch_dtype,
-            low_cpu_mem_usage=True,
+            config=self.llava_config,
         )
-        self.eval_name = 'LlavaHfEval'
+
+        # llava forward not support "cache_position"
+        ori_forward = self.vlm_model.forward
+
+        def safe_forward(*args, **kwargs):
+            kwargs['use_cache'] = False
+            kwargs.pop('cache_position', None)
+            return ori_forward(*args, **kwargs)
+        self.vlm_model.forward = safe_forward
+
+        # llava generate use "inputs" instead of "input_ids"
+        ori_generate = self.vlm_model.generate
+
+        def safe_generate(*args, **kwargs):
+            if 'input_ids' in kwargs:
+                kwargs['inputs'] = kwargs.pop('input_ids')
+            return ori_generate(*args, **kwargs)
+        self.vlm_model.generate = safe_generate
+
+        # "attention_mask" is passed via kwargs rather than as an explicit keyword argument.
+        ori_prepare_inputs_for_generation = self.vlm_model.prepare_inputs_for_generation
+
+        def safe_prepare_inputs_for_generation(
+                self, input_ids, past_key_values=None,
+                inputs_embeds=None, attention_mask=None, **kwargs):
+            if attention_mask is not None:
+                kwargs['attention_mask'] = attention_mask
+            return ori_prepare_inputs_for_generation(
+                input_ids, past_key_values, inputs_embeds, **kwargs)
+        self.vlm_model.prepare_inputs_for_generation = types.MethodType(
+            safe_prepare_inputs_for_generation, self.vlm_model
+        )
+
+        self.eval_name = 'LlavaEval'
         self.mm_model = self.vlm_model
         logger.info(f'self.vlm_model : {self.vlm_model}')
-        self.vision_model = self.vlm_model.vision_tower
-        self.vision_projector = self.vlm_model.multi_modal_projector
-        self.model = self.vlm_model.language_model
+        self.vision_model = self.vlm_model.get_vision_tower()
+        self.vision_projector = self.vlm_model.model.mm_projector
+        # Llava merges the language model with the vision projector and vision model
+        self.model = self.vlm_model
         self.model_config = self.vlm_model_config.text_config
         self.pruning_config = {
-            'is_video_model': False,
             'image_token_start_index': 5,
             'image_token_length': self.vlm_model_config.image_seq_length,
             'select_layer': self.vlm_model_config.vision_feature_layer,
             'select_feature': self.vlm_model_config.vision_feature_select_strategy,
-            'image_token_index': self.vlm_model_config.image_token_index,
+            'image_token_index': self.vlm_model_config.image_token_index
         }
-
-        self.processor = AutoProcessor.from_pretrained(self.model_path)
+        self.processor = None

     def get_extra_rot_module_besides_embed_layers(self):
-        return [self.vision_projector.linear_2]
-
-    def batch_process(
-        self,
-        img_qas,
-        calib_or_eval='eval',
-        apply_chat_template=True,
-        return_inputs=True,
-    ):  # noqa
-        assert calib_or_eval == 'calib' or calib_or_eval == 'eval'
-        assert apply_chat_template
-        messages = []
-        images = []
-        answers = []
-        for idx in range(len(img_qas)):
-            img_path = img_qas[idx]['image']
-            if img_path is not None:
-                image = Image.open(img_path)
-                message = [
-                    {
-                        'role': 'user',
-                        'content': [
-                            {'type': 'image'},
-                            {'type': 'text', 'text': img_qas[idx]['question']},
-                        ],
-                    }
-                ]
-                images.append(image)
-            else:
-                message = [
-                    {
-                        'role': 'user',
-                        'content': [{'type': 'text', 'text': img_qas[idx]['question']}],
-                    }
-                ]
-            messages.append(message)
-            answers.append(img_qas[idx]['answer'])
-        texts = [
-            self.processor.apply_chat_template(messages[n], add_generation_prompt=True)
-            for n in range(len(messages))
-        ]
-        if calib_or_eval == 'calib' and self.config['calib'].get('add_answer', False):
-            texts = [texts[n] + ' ' + answers[n] for n in range(len(texts))]
-        if calib_or_eval == 'calib':
-            logger.info(f'Calib data is:\n{texts}')
-        if not return_inputs:
-            return texts
-        inputs = self.processor(
-            text=texts,
-            images=images if len(images) else None,
-            padding=True,
-            return_tensors='pt',
-        ).to(
-            next(self.vlm_model.parameters()).dtype
-        )  # noqa
-        return inputs
+        return [self.vision_projector[2]]

     def find_blocks(self):
         if self.get_modality() == 'language':
             super().find_blocks()
         elif self.get_modality() == 'vision':
-            self.blocks = self.vision_model.vision_model.encoder.layers
+            self.blocks = self.vision_model.vision_tower.vision_model.encoder.layers
         else:
             raise Exception(f'Llava do not support {self.get_modality()} modality.')

@@ -166,98 +164,141 @@ def get_subsets_in_block(self, block):
                     'inspect': block.mlp.fc2,
                     'has_kwargs': False,
                     'is_mlp': True,
-                    'do_trans': False,
+                    'do_trans': False
                 },
             ]
         else:
             raise Exception(f'Llava do not support {self.get_modality()} modality.')


+if version.parse(torch.__version__) >= version.parse('2.1.2'):
+    best_fit_attn_implementation = 'sdpa'
+else:
+    best_fit_attn_implementation = 'eager'
+
+
 @MODEL_REGISTRY
-class LlavaHfEval(LlavaHf):
+class LlavaEval(LLaVA):
     def __init__(
         self,
         llmc_model,
-        pretrained: str = 'llava-hf/llava-1.5-7b-hf',
-        revision: str = 'main',
-        device: str = 'cuda',
-        dtype: Optional[Union[str, torch.dtype]] = 'auto',
-        batch_size: int = 1,
-        trust_remote_code: Optional[bool] = False,
-        attn_implementation: Optional[str] = None,
+        pretrained: str = 'liuhaotian/llava-v1.5-7b',
+        truncation: Optional[bool] = True,
+        device: Optional[str] = 'cuda',
+        batch_size: Optional[Union[int, str]] = 1,
+        model_name=None,
+        attn_implementation=best_fit_attn_implementation,
         device_map: str = '',
-        chat_template: Optional[str] = None,
+        conv_template='vicuna_v1',
         use_cache: bool = False,
-        max_frames_num: Optional[int] = 32,
+        tie_weights: bool = True,
+        truncate_context=False,  # set it False for LLaVA-1.6 no matter truncate
+        customized_config=None,  # ends in json
         **kwargs,
     ) -> None:
-
         lmms.__init__(self)
         # Do not use kwargs for now
         assert kwargs == {}, f'Unexpected kwargs: {kwargs}'

-        accelerator = Accelerator()
-        if accelerator.num_processes > 1 and device_map == '':
+        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
+        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
+        self.accelerator = accelerator
+        if accelerator.num_processes > 1:
             self._device = torch.device(f'cuda:{accelerator.local_process_index}')
             self.device_map = f'cuda:{accelerator.local_process_index}'
-        else:
+        elif accelerator.num_processes == 1 and device_map == 'auto':
             self._device = torch.device(device)
             self.device_map = device_map
-        if isinstance(dtype, str) and dtype != 'auto':
-            dtype = getattr(torch, dtype)
+        else:
+            self._device = torch.device(f'cuda:{accelerator.local_process_index}')
+            self.device_map = f'cuda:{accelerator.local_process_index}'
+
+        llava_model_args = {
+            'multimodal': True,
+        }
+        if customized_config is not None:
+            llava_model_args['customized_config'] = customized_config
+        if attn_implementation is not None:
+            llava_model_args['attn_implementation'] = attn_implementation
+        if 'use_flash_attention_2' in kwargs:
+            llava_model_args['use_flash_attention_2'] = kwargs['use_flash_attention_2']
+        model_name = model_name if model_name is not None else get_model_name_from_path(pretrained)

         self._model = llmc_model.cuda()
-        self.pretrained = pretrained
-        self._image_processor = AutoProcessor.from_pretrained(
-            pretrained, revision=revision, trust_remote_code=trust_remote_code
-        )
-        # Pad from left for batched generation:
-        # https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava#usage-tips
-        self._image_processor.tokenizer.padding_side = 'left'
-        self._tokenizer = self._image_processor.tokenizer
         self._config = self._model.config
+        self._tokenizer = AutoTokenizer.from_pretrained(pretrained, use_fast=False)
+        self._image_processor = None
+        if 'llava' in model_name.lower():
+            mm_use_im_start_end = getattr(self._config, 'mm_use_im_start_end', False)
+            mm_use_im_patch_token = getattr(self._config, 'mm_use_im_patch_token', True)
+            if mm_use_im_patch_token:
+                self._tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+            if mm_use_im_start_end:
+                self._tokenizer.add_tokens(
+                    [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN],
+                    special_tokens=True
+                )
+            self._image_processor = self._model.get_vision_tower().image_processor
+        if hasattr(self._config, 'max_sequence_length'):
+            self._max_length = self._config.max_sequence_length
+        else:
+            self._max_length = 2048
+
+        self.model.eval()
+        if tie_weights:
+            self.model.tie_weights()
+
+        self.truncation = truncation
         self.batch_size_per_gpu = int(batch_size)
-        self.chat_template = chat_template
+        self.conv_template = conv_template
         self.use_cache = use_cache
-        if accelerator.num_processes > 1 and device_map == '':
+        self.truncate_context = truncate_context
+        # assert self.batch_size_per_gpu == 1, (
+        #     "Llava currently does not support batched generation. "
+        #     "See: https://github.com/haotian-liu/LLaVA/issues/754. "
+        #     "HF Llava also has this issue."
+        # )
+        if accelerator.num_processes > 1:
+            assert accelerator.distributed_type in [
+                DistributedType.FSDP,
+                DistributedType.MULTI_GPU,
+                DistributedType.DEEPSPEED], (
+                'Unsupported distributed type provided. '
+                'Only DDP and FSDP are supported.')
+            # To use DistributedType.DEEPSPEED, run `accelerate config` first.
+            # You must select zero stage 0 (equivalent to DDP) for model preparation to work.
+            # Attempts to support zero stage 2 via kwargs failed.
             if accelerator.distributed_type == DistributedType.DEEPSPEED:
                 kwargs = {
                     'train_micro_batch_size_per_gpu': self.batch_size_per_gpu,
-                    'train_batch_size': self.batch_size_per_gpu
-                    * accelerator.num_processes,
+                    'train_batch_size': self.batch_size_per_gpu * accelerator.num_processes,
                 }
                 AcceleratorState().deepspeed_plugin.deepspeed_config_process(
                     must_match=True, **kwargs
                 )
                 logger.info(
-                    'Detected that you are using DistributedType.DEEPSPEED. \
-                    Make sure you run `accelerate config` and set zero stage to 0'
+                    'Detected that you are using DistributedType.DEEPSPEED. '
+                    'Make sure you run `accelerate config` and set zero stage to 0'
                 )
+
             if (
                 accelerator.distributed_type == DistributedType.FSDP
                 or accelerator.distributed_type == DistributedType.DEEPSPEED
             ):
                 self._model = accelerator.prepare(self.model)
             else:
-                self._model = accelerator.prepare_model(
-                    self.model, evaluation_mode=True
-                )
+                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
             self.accelerator = accelerator
             if self.accelerator.is_local_main_process:
-                logger.info(
-                    f'Using {accelerator.num_processes} devices with data parallelism'
-                )
+                logger.info(f'Using {accelerator.num_processes} devices with data parallelism')
             self._rank = self.accelerator.local_process_index
             self._world_size = self.accelerator.num_processes
         elif accelerator.num_processes == 1 and device_map == 'auto':
-            logger.info(
-                f'Using {accelerator.num_processes} devices with pipeline parallelism'
-            )
+            logger.info(f'Using {accelerator.num_processes} devices with tensor parallelism')
             self._rank = 0
             self._word_size = 1
         else:
             logger.info(f'Using single device: {self._device}')
             self.model.to(self._device)
             self._rank = 0
-            self._word_size = 1
-        self.accelerator = accelerator
+            self._world_size = 1
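The module-level `best_fit_attn_implementation` introduced in this hunk selects the attention backend from the installed torch version. A standalone restatement of that check (same logic; the function name is a hypothetical wrapper, not part of the commit):

import torch
from packaging import version


def pick_attn_implementation():
    # torch >= 2.1.2 gets the scaled-dot-product-attention backend ('sdpa');
    # older versions fall back to the 'eager' attention implementation.
    if version.parse(torch.__version__) >= version.parse('2.1.2'):
        return 'sdpa'
    return 'eager'


print(pick_attn_implementation())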
