
Commit de25b60

support video llava (#378)
1 parent 5449ca3 commit de25b60

File tree: 4 files changed (+241 −34 lines changed)

llmc/compression/token_reduction/pyramiddrop.py
llmc/models/__init__.py
llmc/models/llava.py
llmc/models/videollava.py

llmc/compression/token_reduction/pyramiddrop.py

Lines changed: 32 additions & 8 deletions
@@ -30,9 +30,24 @@ def add_sparse_config(self):
             'tokenizer_padding_side',
             'right',
         )
-        special_config['image_token_index'] = self.model.pruning_config[
-            'image_token_index'
-        ]
+        special_config['is_video_model'] = self.model.pruning_config['is_video_model']
+
+        # vision_token can be image or video
+        if special_config['is_video_model']:
+            special_config['vision_token_index'] = self.model.pruning_config[
+                'video_token_index'
+            ]
+            special_config['vision_token_length'] = self.model.pruning_config[
+                'video_token_length'
+            ]
+        else:
+            special_config['vision_token_index'] = self.model.pruning_config[
+                'image_token_index'
+            ]
+            special_config['vision_token_length'] = self.model.pruning_config[
+                'image_token_length'
+            ]
+
         self.model.model.parameters = special_config

     def register_reduction_modules(self):
@@ -56,6 +71,10 @@ def pruning_hook(module, args, kwargs, pruning_pars, cur_num, layer_idx):
             image_token_posi = pruning_pars['image_token_posi']
             image_token_ratio_list = pruning_pars['image_token_ratio_list']

+            # for decoding stage
+            if features.shape[1] == 1:
+                return args, kwargs
+
             if position_ids is None:
                 position_ids = torch.arange(
                     0, features.shape[1], dtype=torch.long, device=features.device
@@ -297,26 +316,31 @@ def pruning_hook(module, args, kwargs, pruning_pars, cur_num, layer_idx):
             return (new_input_embeds,), kwargs

         def input_hook(module, input_args, pruning_pars):
+            # for the decoding stage
+            if input_args[0].shape[1] == 1:
+                return input_args
             input_ids = input_args[0]
             pre_prompt_length_list = []
             image_token_posi = []
-            image_tokens = []
-            IMAGE_TOKEN_INDEX = pruning_pars['image_token_index']
+            vision_tokens = []
+            VISION_TOKEN_INDEX = pruning_pars['vision_token_index']

             # find the position of the first image token
             for seq in input_ids:
-                image_token_idxs = (seq == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0]
-                image_tokens.append(image_token_idxs.shape[0])
+                image_token_idxs = (seq == VISION_TOKEN_INDEX).nonzero(as_tuple=True)[0]
+                vision_tokens.append(pruning_pars['vision_token_length'])
                 image_token_posi.append(image_token_idxs[0].item())
                 pre_prompt_length_list.append(seq.shape[0] - image_token_idxs.shape[0])

             pruning_pars['prompt_len'] = pre_prompt_length_list
             pruning_pars['image_token_posi'] = image_token_posi
-            pruning_pars['image_tokens'] = image_tokens
+            pruning_pars['image_tokens'] = vision_tokens

             return input_args

         def read_parameter_hook(module, args, kwargs, pruning_pars):
+            if args[0].shape[1] == 1:
+                return args, kwargs
             kwargs['attention_mask'] = pruning_pars['attention_mask']
             # kwargs['cache_position'] = pruning_pars['cache_position']
             kwargs['position_ids'] = pruning_pars['position_ids']

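The hunks above replace the fixed image_token_index with a generic vision token resolved from the model's pruning_config, and add an early return when the hooks see a single-token (decoding-stage) input. Below is a minimal sketch, not part of the commit, of the selection logic the add_sparse_config branch implements; the token indices and lengths are illustrative placeholders, not values read from a real model config.

# Sketch only: mirrors the image/video branch added to add_sparse_config.
def resolve_vision_token(pruning_config: dict) -> dict:
    if pruning_config['is_video_model']:
        return {
            'vision_token_index': pruning_config['video_token_index'],
            'vision_token_length': pruning_config['video_token_length'],
        }
    return {
        'vision_token_index': pruning_config['image_token_index'],
        'vision_token_length': pruning_config['image_token_length'],
    }


print(resolve_vision_token({
    'is_video_model': True,
    'image_token_index': 32000,   # placeholder values for illustration
    'video_token_index': 32001,
    'image_token_length': 576,
    'video_token_length': 2048,
}))
# -> {'vision_token_index': 32001, 'vision_token_length': 2048}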
llmc/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@
 from .smollm import SmolLM
 from .stablelm import StableLm
 from .starcoder import Starcoder
+from .videollava import VideoLLaVA
 from .vila import Vila
 from .vit import Vit
 from .wan_i2v import WanI2V

llmc/models/llava.py

Lines changed: 43 additions & 26 deletions
@@ -41,19 +41,26 @@ def build_model(self):
         self.model = self.vlm_model.language_model
         self.model_config = self.vlm_model_config.text_config
         self.pruning_config = {
+            'is_video_model': False,
             'image_token_start_index': 5,
-            'image_token_length': 576,
+            'image_token_length': self.vlm_model_config.image_seq_length,
             'select_layer': self.vlm_model_config.vision_feature_layer,
             'select_feature': self.vlm_model_config.vision_feature_select_strategy,
-            'image_token_index': self.vlm_model_config.image_token_index
+            'image_token_index': self.vlm_model_config.image_token_index,
         }

         self.processor = AutoProcessor.from_pretrained(self.model_path)

     def get_extra_rot_module_besides_embed_layers(self):
         return [self.vision_projector.linear_2]

-    def batch_process(self, img_qas, calib_or_eval='eval', apply_chat_template=True, return_inputs=True):  # noqa
+    def batch_process(
+        self,
+        img_qas,
+        calib_or_eval='eval',
+        apply_chat_template=True,
+        return_inputs=True,
+    ):  # noqa
         assert calib_or_eval == 'calib' or calib_or_eval == 'eval'
         assert apply_chat_template
         messages = []
@@ -68,18 +75,16 @@ def batch_process(self, img_qas, calib_or_eval='eval', apply_chat_template=True,
                         'role': 'user',
                         'content': [
                             {'type': 'image'},
-                            {'type': 'text', 'text': img_qas[idx]['question']}
-                        ]
+                            {'type': 'text', 'text': img_qas[idx]['question']},
+                        ],
                     }
                 ]
                 images.append(image)
             else:
                 message = [
                     {
                         'role': 'user',
-                        'content': [
-                            {'type': 'text', 'text': img_qas[idx]['question']}
-                        ]
+                        'content': [{'type': 'text', 'text': img_qas[idx]['question']}],
                     }
                 ]
             messages.append(message)
@@ -89,10 +94,7 @@ def batch_process(self, img_qas, calib_or_eval='eval', apply_chat_template=True,
             for n in range(len(messages))
         ]
         if calib_or_eval == 'calib' and self.config['calib'].get('add_answer', False):
-            texts = [
-                texts[n] + ' ' + answers[n]
-                for n in range(len(texts))
-            ]
+            texts = [texts[n] + ' ' + answers[n] for n in range(len(texts))]
         if calib_or_eval == 'calib':
             logger.info(f'Calib data is:\n{texts}')
         if not return_inputs:
@@ -101,8 +103,10 @@ def batch_process(self, img_qas, calib_or_eval='eval', apply_chat_template=True,
             text=texts,
             images=images if len(images) else None,
             padding=True,
-            return_tensors='pt'
-        ).to(next(self.vlm_model.parameters()).dtype)  # noqa
+            return_tensors='pt',
+        ).to(
+            next(self.vlm_model.parameters()).dtype
+        )  # noqa
         return inputs

     def find_blocks(self):
@@ -162,7 +166,7 @@ def get_subsets_in_block(self, block):
                     'inspect': block.mlp.fc2,
                     'has_kwargs': False,
                     'is_mlp': True,
-                    'do_trans': False
+                    'do_trans': False,
                 },
             ]
         else:
@@ -204,8 +208,9 @@ def __init__(

         self._model = llmc_model.cuda()
         self.pretrained = pretrained
-        self._image_processor = AutoProcessor.from_pretrained(pretrained, revision=revision,
-                                                               trust_remote_code=trust_remote_code)
+        self._image_processor = AutoProcessor.from_pretrained(
+            pretrained, revision=revision, trust_remote_code=trust_remote_code
+        )
         # Pad from left for batched generation:
         # https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava#usage-tips
         self._image_processor.tokenizer.padding_side = 'left'
@@ -218,24 +223,36 @@ def __init__(
             if accelerator.distributed_type == DistributedType.DEEPSPEED:
                 kwargs = {
                     'train_micro_batch_size_per_gpu': self.batch_size_per_gpu,
-                    'train_batch_size': self.batch_size_per_gpu * accelerator.num_processes,
+                    'train_batch_size': self.batch_size_per_gpu
+                    * accelerator.num_processes,
                 }
                 AcceleratorState().deepspeed_plugin.deepspeed_config_process(
-                    must_match=True, **kwargs)
-                logger.info('Detected that you are using DistributedType.DEEPSPEED. \
-                    Make sure you run `accelerate config` and set zero stage to 0')
-            if accelerator.distributed_type == DistributedType.FSDP or \
-                    accelerator.distributed_type == DistributedType.DEEPSPEED:
+                    must_match=True, **kwargs
+                )
+                logger.info(
+                    'Detected that you are using DistributedType.DEEPSPEED. \
+                    Make sure you run `accelerate config` and set zero stage to 0'
+                )
+            if (
+                accelerator.distributed_type == DistributedType.FSDP
+                or accelerator.distributed_type == DistributedType.DEEPSPEED
+            ):
                 self._model = accelerator.prepare(self.model)
             else:
-                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
+                self._model = accelerator.prepare_model(
+                    self.model, evaluation_mode=True
+                )
             self.accelerator = accelerator
             if self.accelerator.is_local_main_process:
-                logger.info(f'Using {accelerator.num_processes} devices with data parallelism')
+                logger.info(
+                    f'Using {accelerator.num_processes} devices with data parallelism'
+                )
             self._rank = self.accelerator.local_process_index
             self._world_size = self.accelerator.num_processes
         elif accelerator.num_processes == 1 and device_map == 'auto':
-            logger.info(f'Using {accelerator.num_processes} devices with pipeline parallelism')
+            logger.info(
+                f'Using {accelerator.num_processes} devices with pipeline parallelism'
+            )
             self._rank = 0
             self._word_size = 1
         else:

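Most of the llava.py diff is formatting, but two substantive changes ride along: the pruning config gains 'is_video_model': False, and image_token_length is now read from the config's image_seq_length instead of the hard-coded 576. For a LLaVA-1.5-style vision tower (CLIP ViT-L/14 at 336 px, a well-known setup rather than anything stated in this commit) the two agree, as this small check shows.

# Quick sanity check, assuming a 336 px input and 14 px patches as in LLaVA-1.5:
image_size, patch_size = 336, 14
image_seq_length = (image_size // patch_size) ** 2  # 24 * 24 patch tokens
assert image_seq_length == 576
print(image_seq_length)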
llmc/models/videollava.py

Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
+from datetime import timedelta
+from typing import List, Optional, Tuple, Union
+
+import torch
+from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
+from accelerate.state import AcceleratorState
+from lmms_eval.api.model import lmms
+from lmms_eval.models.video_llava import VideoLLaVA as VL
+from loguru import logger
+from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
+                          GenerationConfig, VideoLlavaForConditionalGeneration,
+                          VideoLlavaProcessor)
+
+from llmc.utils.registry_factory import MODEL_REGISTRY
+
+from .llama import Llama
+
+
+@MODEL_REGISTRY
+class VideoLLaVA(Llama):
+    def __init__(self, config, device_map=None, use_cache=False):
+        super().__init__(config, device_map, use_cache)
+
+    def build_model(self):
+        self.vlm_model_config = AutoConfig.from_pretrained(
+            self.model_path, trust_remote_code=True
+        )
+        if not self.use_cache:
+            self.vlm_model_config.text_config.use_cache = False
+        logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
+        self.vlm_model = VideoLlavaForConditionalGeneration.from_pretrained(
+            self.model_path,
+            config=self.vlm_model_config,
+            torch_dtype=self.torch_dtype,
+            low_cpu_mem_usage=True,
+        )
+        self.eval_name = 'VideoLLaVAHfEval'
+        self.mm_model = self.vlm_model
+        logger.info(f'self.vlm_model : {self.vlm_model}')
+        self.video_tower = self.vlm_model.video_tower
+        self.image_tower = self.vlm_model.image_tower
+        self.vision_projector = self.vlm_model.multi_modal_projector
+        self.model = self.vlm_model.language_model
+        self.model_config = self.vlm_model_config.text_config
+        self.pruning_config = {
+            'is_video_model': True,
+            'image_token_length': self.vlm_model_config.image_seq_length,
+            'video_token_length': self.vlm_model_config.video_seq_length,
+            'select_layer': self.vlm_model_config.vision_feature_layer,
+            'select_feature': self.vlm_model_config.vision_feature_select_strategy,
+            'image_token_index': self.vlm_model_config.image_token_index,
+            'video_token_index': self.vlm_model_config.video_token_index,
+        }
+
+
+@MODEL_REGISTRY
+class VideoLLaVAHfEval(VL):
+    def __init__(
+        self,
+        llmc_model,
+        pretrained: str = 'LanguageBind/Video-LLaVA-7B-hf',
+        truncation: Optional[bool] = True,
+        device: Optional[str] = 'cuda:0',
+        dtype: Optional[Union[str, torch.dtype]] = 'auto',
+        batch_size: Optional[Union[int, str]] = 1,
+        trust_remote_code: Optional[bool] = False,
+        revision=None,
+        attn_implementation=(
+            'sdpa' if torch.__version__ > '2.1.2' else 'eager'
+        ),
+        # inference implementation for attention, can be "sdpa", "eager", "flash_attention_2".
+        # Seems FA2 is not effective during inference:
+        # https://discuss.huggingface.co/t/flash-attention-has-no-effect-on-inference/73453/5
+        device_map='cuda:0',
+        conv_template='llava_v1',
+        use_cache=True,
+        truncate_context=False,
+        num_frames: int = 8,
+        # whether to truncate the context in generation,
+        # set it False for LLaVA-1.6
+        **kwargs,
+    ) -> None:
+        lmms.__init__(self)
+        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
+        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
+        if accelerator.num_processes > 1:
+            self._device = torch.device(f'cuda:{accelerator.local_process_index}')
+            self.device_map = f'cuda:{accelerator.local_process_index}'
+        elif accelerator.num_processes == 1 and device_map == 'auto':
+            self._device = torch.device(device)
+            self.device_map = device_map
+        else:
+            self._device = torch.device(f'cuda:{accelerator.local_process_index}')
+            self.device_map = f'cuda:{accelerator.local_process_index}'
+
+        self.pretrained = pretrained
+        self._model = llmc_model.cuda()
+        self._processor = VideoLlavaProcessor.from_pretrained(pretrained)
+        self.prompt = 'USER: <video>{}? ASSISTANT:'
+        self.num_frames = num_frames
+        assert (
+            num_frames == 8
+        ), 'num_frames must be 8'
+        # self.model_name = get_model_name_from_path(pretrained)
+        # self._tokenizer, self._model, self.processor,
+        # self._max_length = load_pretrained_model(pretrained,
+        # None, self.model_name, device_map=self.device_map)
+        # self.video_processor = self.processor["video"]
+        self._config = self._model.config
+        self.model.eval()
+        self.model.tie_weights()
+        self.truncation = truncation
+        self.batch_size_per_gpu = int(batch_size)
+        self.conv_template = conv_template
+        self.use_cache = use_cache
+        self.truncate_context = truncate_context
+        # assert self.batch_size_per_gpu == 1,
+        # "Llava currently does not support batched generation.
+        # See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue."
+        if accelerator.num_processes > 1:
+            assert accelerator.distributed_type in [
+                DistributedType.FSDP,
+                DistributedType.MULTI_GPU,
+                DistributedType.DEEPSPEED,
+            ], 'Unsupported distributed type provided. Only DDP and FSDP are supported.'
+            if accelerator.distributed_type == DistributedType.DEEPSPEED:
+                kwargs = {
+                    'train_micro_batch_size_per_gpu': self.batch_size_per_gpu,
+                    'train_batch_size': self.batch_size_per_gpu
+                    * accelerator.num_processes,
+                }
+                AcceleratorState().deepspeed_plugin.deepspeed_config_process(
+                    must_match=True, **kwargs
+                )
+                logger.info(
+                    'Detected that you are using DistributedType.DEEPSPEED. ' +
+                    'Make sure you run `accelerate config` and set zero stage to 0'
+                )
+            if (
+                accelerator.distributed_type == DistributedType.FSDP
+                or accelerator.distributed_type == DistributedType.DEEPSPEED
+            ):
+                self._model = accelerator.prepare(self.model)
+            else:
+                self._model = accelerator.prepare_model(
+                    self.model, evaluation_mode=True
+                )
+            self.accelerator = accelerator
+            if self.accelerator.is_local_main_process:
+                logger.info(
+                    f'Using {accelerator.num_processes} devices with data parallelism'
+                )
+            self._rank = self.accelerator.local_process_index
+            self._world_size = self.accelerator.num_processes
+        elif accelerator.num_processes == 1 and device_map == 'auto':
+            logger.info(
+                f'Using {accelerator.num_processes} devices with tensor parallelism'
+            )
+            self._rank = 0
+            self._word_size = 1
+        else:
+            logger.info(f'Using single device: {self._device}')
+            self.model.to(self._device)
+            self._rank = 0
+            self._world_size = 1

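The new VideoLLaVAHfEval wrapper reuses the Hugging Face Video-LLaVA generation path and the 'USER: <video>{}? ASSISTANT:' prompt template. Below is a hedged, self-contained sketch of that path for orientation; the random frames stand in for a real decoded clip, and the fp16 cast and generation arguments follow common HF LLaVA usage rather than anything in this commit.

# Minimal sketch of HF Video-LLaVA generation (assumptions: fp16 weights on one GPU,
# random frames in place of a real video clip).
import numpy as np
import torch
from transformers import (VideoLlavaForConditionalGeneration,
                          VideoLlavaProcessor)

pretrained = 'LanguageBind/Video-LLaVA-7B-hf'
processor = VideoLlavaProcessor.from_pretrained(pretrained)
model = VideoLlavaForConditionalGeneration.from_pretrained(
    pretrained, torch_dtype=torch.float16, device_map='cuda:0'
)

# 8 frames, matching the wrapper's `num_frames must be 8` assertion.
clip = np.random.randint(0, 256, size=(8, 224, 224, 3), dtype=np.uint8)
prompt = 'USER: <video>{}? ASSISTANT:'.format('What is happening in the video')

inputs = processor(text=prompt, videos=clip, return_tensors='pt').to(
    model.device, torch.float16
)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=64, do_sample=False)
print(processor.batch_decode(out, skip_special_tokens=True)[0])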