Commit ec771c0

llava_onevision (#386)
1 parent 4d7cc78 commit ec771c0

4 files changed: +243 −2 lines changed

llmc/compression/sparsification/base_blockwise_sparsification.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def set_sparsity_config(self):
         if 'sparsity' in self.sparsity_config['weight']:
             self.sparsity = self.sparsity_config['weight']['sparsity']
             self.W_mask = None
-        elif 'n_prune_layers' in self.sparsity_config:
+        elif 'n_prune_layers' in self.sparsity_config['weight']:
             self.n_prune_layers = self.sparsity_config['weight']['n_prune_layers']

     def set_kv_sparse_config(self):
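
For context, a minimal sketch (not part of this commit) of the config shape that the fixed lookup expects: with this change, both 'sparsity' and 'n_prune_layers' are read from the 'weight' section rather than from the top level. The values below are placeholders.

# Hypothetical sparsity_config dict matching the keys read in
# set_sparsity_config() above; only one of the two keys would
# normally be present, since the code uses an if/elif.
sparsity_config = {
    'weight': {
        # either a global weight-sparsity ratio ...
        'sparsity': 0.5,
        # ... or, after the fix, a layer-pruning count, also under 'weight'
        'n_prune_layers': 4,
    },
}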

llmc/compression/token_reduction/tome.py

Lines changed: 8 additions & 1 deletion
@@ -4,8 +4,15 @@

 import torch
 import torch.nn.functional as F
+from loguru import logger
 from transformers.models.clip.modeling_clip import CLIPEncoderLayer
-from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLVisionBlock
+
+try:
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import \
+        Qwen2VLVisionBlock
+except ModuleNotFoundError:
+    logger.info('Qwen2VLVisionBlock not found, if need, please upgrade transformers first.')
+    Qwen2VLVisionBlock = None

 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
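
A minimal usage sketch (an assumption, not taken from tome.py) of how code that depends on this optional import can cope with the None fallback set in the except branch:

# Hypothetical consumer of the optional import above: if the installed
# transformers version does not provide Qwen2VLVisionBlock, the name is
# bound to None and must be excluded from isinstance checks.
def is_vision_block(module):
    supported = (CLIPEncoderLayer,)
    if Qwen2VLVisionBlock is not None:
        supported = supported + (Qwen2VLVisionBlock,)
    return isinstance(module, supported)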

llmc/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
 from .llama import Llama
 from .llava import Llava
 from .llava_hf import LlavaHf
+from .llava_onevision import Llava_OneVision
 from .minicpm import MiniCPM
 from .minicpmv import MiniCPMV
 from .mistral import Mistral

llmc/models/llava_onevision.py

Lines changed: 233 additions & 0 deletions
@@ -0,0 +1,233 @@
from datetime import timedelta
from typing import List, Optional, Tuple, Union

import torch
from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
from accelerate.state import AcceleratorState
from lmms_eval.api.model import lmms
from lmms_eval.models.llava_onevision import Llava_OneVision as LLaVA_OV
from loguru import logger
from packaging import version
from transformers import AutoConfig

from llmc.utils.registry_factory import MODEL_REGISTRY

from .llama import Llama

try:
    from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
                                 DEFAULT_IMAGE_TOKEN, IGNORE_INDEX,
                                 IMAGE_TOKEN_INDEX)
    from llava.conversation import SeparatorStyle, conv_templates
    from llava.mm_utils import (KeywordsStoppingCriteria,
                                get_model_name_from_path, process_images,
                                tokenizer_image_token)
    from llava.model.builder import load_pretrained_model
except ImportError as e:
    logger.debug(
        f'LLaVA is not installed. Please install LLaVA to use this model.\nError: {e}'
    )

# Determine best attention implementation
if version.parse(torch.__version__) >= version.parse('2.1.2'):
    best_fit_attn_implementation = 'sdpa'
else:
    best_fit_attn_implementation = 'eager'


@MODEL_REGISTRY
class Llava_OneVision(Llama):
    def __init__(self, config, device_map=None, use_cache=False):
        super().__init__(config, device_map, use_cache)

    def build_model(self):
        self.vlm_model_config = AutoConfig.from_pretrained(
            self.model_path, trust_remote_code=True
        )
        if not self.use_cache:
            self.vlm_model_config.text_config.use_cache = False
        logger.info(f'self.vlm_model_config : {self.vlm_model_config}')

        llava_model_args = {
            'multimodal': True,
        }
        llava_model_args['attn_implementation'] = best_fit_attn_implementation

        model_name = 'llava_qwen'

        overwrite_config = {}
        overwrite_config['mm_spatial_pool_stride'] = 2
        overwrite_config['mm_spatial_pool_mode'] = 'bilinear'

        llava_model_args['overwrite_config'] = overwrite_config
        try:
            # Try to load the model with the multimodal argument
            self.tokenizer, self.vlm_model, image_processor, max_length = (
                load_pretrained_model(
                    self.model_path,
                    None,
                    model_name,
                    device_map=self.device_map,
                    **llava_model_args,
                )
            )
        except TypeError:
            # for older versions of LLaVA that don't have multimodal argument
            llava_model_args.pop('multimodal', None)
            self.tokenizer, self.vlm_model, image_processor, max_length = (
                load_pretrained_model(
                    self.model_path,
                    None,
                    model_name,
                    device_map=self.device_map,
                    **llava_model_args,
                )
            )

        self.vlm_model.image_processor = image_processor
        self.vlm_model.max_length = max_length
        self.vlm_model.tokenizer = self.tokenizer

        self.eval_name = 'Llava_OneVision_Eval'
        self.mm_model = self.vlm_model
        logger.info(f'self.vlm_model : {self.vlm_model}')
        self.vision_model = self.vlm_model.get_vision_tower()
        self.vision_projector = self.vlm_model.model.mm_projector
        self.model = self.vlm_model
        self.model_config = self.vlm_model_config.text_config
        self.pruning_config = {
            'is_video_model': False,
            'image_token_length': self.vlm_model_config.image_seq_length,
            'select_layer': self.vlm_model_config.vision_feature_layer,
            'select_feature': self.vlm_model_config.vision_feature_select_strategy,
            'image_token_index': self.vlm_model_config.image_token_index,
        }

        self.processor = None


@MODEL_REGISTRY
class Llava_OneVision_Eval(LLaVA_OV):
    """Llava Model."""

    def __init__(
        self,
        llmc_model,
        pretrained: str = 'liuhaotian/llava-v1.5-7b',
        truncation: Optional[bool] = True,
        device: Optional[str] = 'cuda:0',
        batch_size: Optional[Union[int, str]] = 1,
        model_name: Optional[str] = None,
        attn_implementation: Optional[str] = best_fit_attn_implementation,
        device_map: Optional[str] = 'cuda:0',
        conv_template: Optional[str] = 'qwen_1_5',
        use_cache: Optional[bool] = True,
        truncate_context: Optional[
            bool
        ] = False,  # whether to truncate the context in generation, set it False for LLaVA-1.6
        customized_config: Optional[str] = None,  # ends in json
        max_frames_num: Optional[int] = 32,
        mm_spatial_pool_stride: Optional[int] = 2,
        mm_spatial_pool_mode: Optional[str] = 'bilinear',
        token_strategy: Optional[
            str
        ] = 'single',  # could be "single" or "multiple", "multiple"
        # denotes adding multiple <image> tokens for each frame
        video_decode_backend: str = 'decord',
        **kwargs,
    ) -> None:
        lmms.__init__(self)
        # Do not use kwargs for now
        assert kwargs == {}, f'Unexpected kwargs: {kwargs}'

        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
        if accelerator.num_processes > 1:
            self._device = torch.device(f'cuda:{accelerator.local_process_index}')
            self.device_map = f'cuda:{accelerator.local_process_index}'
        elif accelerator.num_processes == 1 and device_map == 'auto':
            self._device = torch.device(device)
            self.device_map = device_map
        else:
            self._device = torch.device(f'cuda:{accelerator.local_process_index}')
            self.device_map = f'cuda:{accelerator.local_process_index}'

        self.pretrained = pretrained
        self.token_strategy = token_strategy
        self.max_frames_num = max_frames_num
        self.mm_spatial_pool_stride = mm_spatial_pool_stride
        self.mm_spatial_pool_mode = mm_spatial_pool_mode
        self.video_decode_backend = video_decode_backend

        # cfg_pretrained = AutoConfig.from_pretrained(self.pretrained)

        self._model = llmc_model.cuda()
        self._tokenizer, self._image_processor, self._max_length = (
            llmc_model.tokenizer,
            llmc_model.image_processor,
            llmc_model.max_length,
        )

        del llmc_model.tokenizer
        del llmc_model.image_processor
        del llmc_model.max_length

        self._config = self._model.config
        self.model.eval()
        self.truncation = truncation
        self.batch_size_per_gpu = int(batch_size)
        self.conv_template = conv_template
        self.use_cache = use_cache
        self.truncate_context = truncate_context
        assert (
            self.batch_size_per_gpu == 1
        ), 'Llava currently does not support batched generation.'

        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [
                DistributedType.FSDP,
                DistributedType.MULTI_GPU,
                DistributedType.DEEPSPEED,
            ], 'Unsupported distributed type provided. Only DDP and FSDP are supported.'
            if accelerator.distributed_type == DistributedType.DEEPSPEED:
                kwargs = {
                    'train_micro_batch_size_per_gpu': self.batch_size_per_gpu,
                    'train_batch_size': self.batch_size_per_gpu
                    * accelerator.num_processes,
                }
                AcceleratorState().deepspeed_plugin.deepspeed_config_process(
                    must_match=True, **kwargs
                )
                logger.info(
                    'Detected that you are using DistributedType.DEEPSPEED.'
                )

            if (
                accelerator.distributed_type == DistributedType.FSDP
                or accelerator.distributed_type == DistributedType.DEEPSPEED
            ):
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(
                    self.model, evaluation_mode=True
                )
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                logger.info(
                    f'Using {accelerator.num_processes} devices with data parallelism'
                )
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes

        elif accelerator.num_processes == 1 and device_map == 'auto':
            logger.info(
                f'Using {accelerator.num_processes} devices with tensor parallelism'
            )
            self._rank = 0
            self._world_size = 1

        else:
            logger.info(f'Using single device: {self._device}')
            self.model.to(self._device)
            self._rank = 0
            self._world_size = 1
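
As a rough usage sketch (not part of this commit), the new wrapper is constructed like the other registered llmc model classes. The config object and its fields are assumptions here: whatever the Llama base class reads (for example the checkpoint path used as self.model_path in build_model()) must be supplied, and whether build_model() is invoked by the base __init__ or must be called explicitly depends on that base class.

# Hypothetical construction of the new wrapper; `model_cfg` is a stand-in
# for the llmc model config (e.g. the parsed 'model' section of a YAML),
# not an API defined by this commit.
from llmc.models import Llava_OneVision

model_cfg = ...  # placeholder: must provide the fields the Llama base class expects
model = Llava_OneVision(model_cfg, device_map='cuda:0', use_cache=False)

After build_model() has run, model.eval_name is 'Llava_OneVision_Eval', the lmms-eval wrapper defined above, which takes the compressed model as its llmc_model argument.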
