40 changes: 32 additions & 8 deletions llmc/compression/token_reduction/pyramiddrop.py
@@ -30,9 +30,24 @@ def add_sparse_config(self):
'tokenizer_padding_side',
'right',
)
special_config['image_token_index'] = self.model.pruning_config[
'image_token_index'
]
special_config['is_video_model'] = self.model.pruning_config['is_video_model']

# vision_token can be image or video
if special_config['is_video_model']:
special_config['vision_token_index'] = self.model.pruning_config[
'video_token_index'
]
special_config['vision_token_length'] = self.model.pruning_config[
'video_token_length'
]
else:
special_config['vision_token_index'] = self.model.pruning_config[
'image_token_index'
]
special_config['vision_token_length'] = self.model.pruning_config[
'image_token_length'
]

self.model.model.parameters = special_config

def register_reduction_modules(self):
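For reference, a compact standalone sketch of the vision-token selection added in add_sparse_config above; the pruning_config literal below uses invented token ids and lengths purely for illustration.

pruning_config = {
    'is_video_model': True,
    'image_token_index': 32000, 'image_token_length': 576,    # illustrative values only
    'video_token_index': 32001, 'video_token_length': 2048,   # illustrative values only
}

special_config = {}
kind = 'video' if pruning_config['is_video_model'] else 'image'
special_config['vision_token_index'] = pruning_config[f'{kind}_token_index']
special_config['vision_token_length'] = pruning_config[f'{kind}_token_length']
print(special_config)  # {'vision_token_index': 32001, 'vision_token_length': 2048}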
@@ -56,6 +71,10 @@ def pruning_hook(module, args, kwargs, pruning_pars, cur_num, layer_idx):
image_token_posi = pruning_pars['image_token_posi']
image_token_ratio_list = pruning_pars['image_token_ratio_list']

# for decoding stage
if features.shape[1] == 1:
return args, kwargs

if position_ids is None:
position_ids = torch.arange(
0, features.shape[1], dtype=torch.long, device=features.device
@@ -297,26 +316,31 @@ def pruning_hook(module, args, kwargs, pruning_pars, cur_num, layer_idx):
return (new_input_embeds,), kwargs

def input_hook(module, input_args, pruning_pars):
# for the decoding stage
if input_args[0].shape[1] == 1:
return input_args
input_ids = input_args[0]
pre_prompt_length_list = []
image_token_posi = []
image_tokens = []
IMAGE_TOKEN_INDEX = pruning_pars['image_token_index']
vision_tokens = []
VISION_TOKEN_INDEX = pruning_pars['vision_token_index']

# find the position of the first vision token
for seq in input_ids:
image_token_idxs = (seq == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0]
image_tokens.append(image_token_idxs.shape[0])
image_token_idxs = (seq == VISION_TOKEN_INDEX).nonzero(as_tuple=True)[0]
vision_tokens.append(pruning_pars['vision_token_length'])
image_token_posi.append(image_token_idxs[0].item())
pre_prompt_length_list.append(seq.shape[0] - image_token_idxs.shape[0])

pruning_pars['prompt_len'] = pre_prompt_length_list
pruning_pars['image_token_posi'] = image_token_posi
pruning_pars['image_tokens'] = image_tokens
pruning_pars['image_tokens'] = vision_tokens

return input_args

def read_parameter_hook(module, args, kwargs, pruning_pars):
if args[0].shape[1] == 1:
return args, kwargs
kwargs['attention_mask'] = pruning_pars['attention_mask']
# kwargs['cache_position'] = pruning_pars['cache_position']
kwargs['position_ids'] = pruning_pars['position_ids']
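For reference, a minimal sketch of the lookup that input_hook performs on each sequence; the placeholder-token id and the input_ids tensor are invented for illustration.

import torch

VISION_TOKEN_INDEX = 32000  # hypothetical id of the <image>/<video> placeholder token
input_ids = torch.tensor([[1, 306, 32000, 32000, 32000, 825, 338]])

for seq in input_ids:
    # indices of every vision placeholder token in this sequence
    vision_idxs = (seq == VISION_TOKEN_INDEX).nonzero(as_tuple=True)[0]
    first_vision_pos = vision_idxs[0].item()                 # 2
    pre_prompt_length = seq.shape[0] - vision_idxs.shape[0]  # 4 text tokens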
1 change: 1 addition & 0 deletions llmc/models/__init__.py
@@ -26,6 +26,7 @@
from .smollm import SmolLM
from .stablelm import StableLm
from .starcoder import Starcoder
from .videollava import VideoLLaVA
from .vila import Vila
from .vit import Vit
from .wan_i2v import WanI2V
69 changes: 43 additions & 26 deletions llmc/models/llava.py
@@ -41,19 +41,26 @@ def build_model(self):
self.model = self.vlm_model.language_model
self.model_config = self.vlm_model_config.text_config
self.pruning_config = {
'is_video_model': False,
'image_token_start_index': 5,
'image_token_length': 576,
'image_token_length': self.vlm_model_config.image_seq_length,
'select_layer': self.vlm_model_config.vision_feature_layer,
'select_feature': self.vlm_model_config.vision_feature_select_strategy,
'image_token_index': self.vlm_model_config.image_token_index
'image_token_index': self.vlm_model_config.image_token_index,
}

self.processor = AutoProcessor.from_pretrained(self.model_path)

def get_extra_rot_module_besides_embed_layers(self):
return [self.vision_projector.linear_2]

def batch_process(self, img_qas, calib_or_eval='eval', apply_chat_template=True, return_inputs=True): # noqa
def batch_process(
self,
img_qas,
calib_or_eval='eval',
apply_chat_template=True,
return_inputs=True,
): # noqa
assert calib_or_eval == 'calib' or calib_or_eval == 'eval'
assert apply_chat_template
messages = []
@@ -68,18 +75,16 @@ def batch_process(self, img_qas, calib_or_eval='eval', apply_chat_template=True,
'role': 'user',
'content': [
{'type': 'image'},
{'type': 'text', 'text': img_qas[idx]['question']}
]
{'type': 'text', 'text': img_qas[idx]['question']},
],
}
]
images.append(image)
else:
message = [
{
'role': 'user',
'content': [
{'type': 'text', 'text': img_qas[idx]['question']}
]
'content': [{'type': 'text', 'text': img_qas[idx]['question']}],
}
]
messages.append(message)
@@ -89,10 +94,7 @@ def batch_process(self, img_qas, calib_or_eval='eval', apply_chat_template=True,
for n in range(len(messages))
]
if calib_or_eval == 'calib' and self.config['calib'].get('add_answer', False):
texts = [
texts[n] + ' ' + answers[n]
for n in range(len(texts))
]
texts = [texts[n] + ' ' + answers[n] for n in range(len(texts))]
if calib_or_eval == 'calib':
logger.info(f'Calib data is:\n{texts}')
if not return_inputs:
@@ -101,8 +103,10 @@ def batch_process(self, img_qas, calib_or_eval='eval', apply_chat_template=True,
text=texts,
images=images if len(images) else None,
padding=True,
return_tensors='pt'
).to(next(self.vlm_model.parameters()).dtype) # noqa
return_tensors='pt',
).to(
next(self.vlm_model.parameters()).dtype
) # noqa
return inputs

def find_blocks(self):
@@ -162,7 +166,7 @@ def get_subsets_in_block(self, block):
'inspect': block.mlp.fc2,
'has_kwargs': False,
'is_mlp': True,
'do_trans': False
'do_trans': False,
},
]
else:
@@ -204,8 +208,9 @@ def __init__(

self._model = llmc_model.cuda()
self.pretrained = pretrained
self._image_processor = AutoProcessor.from_pretrained(pretrained, revision=revision,
trust_remote_code=trust_remote_code)
self._image_processor = AutoProcessor.from_pretrained(
pretrained, revision=revision, trust_remote_code=trust_remote_code
)
# Pad from left for batched generation:
# https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava#usage-tips
self._image_processor.tokenizer.padding_side = 'left'
@@ -218,24 +223,36 @@ def __init__(
if accelerator.distributed_type == DistributedType.DEEPSPEED:
kwargs = {
'train_micro_batch_size_per_gpu': self.batch_size_per_gpu,
'train_batch_size': self.batch_size_per_gpu * accelerator.num_processes,
'train_batch_size': self.batch_size_per_gpu
* accelerator.num_processes,
}
AcceleratorState().deepspeed_plugin.deepspeed_config_process(
must_match=True, **kwargs)
logger.info('Detected that you are using DistributedType.DEEPSPEED. \
Make sure you run `accelerate config` and set zero stage to 0')
if accelerator.distributed_type == DistributedType.FSDP or \
accelerator.distributed_type == DistributedType.DEEPSPEED:
must_match=True, **kwargs
)
logger.info(
'Detected that you are using DistributedType.DEEPSPEED. \
Make sure you run `accelerate config` and set zero stage to 0'
)
if (
accelerator.distributed_type == DistributedType.FSDP
or accelerator.distributed_type == DistributedType.DEEPSPEED
):
self._model = accelerator.prepare(self.model)
else:
self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
self._model = accelerator.prepare_model(
self.model, evaluation_mode=True
)
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
logger.info(f'Using {accelerator.num_processes} devices with data parallelism')
logger.info(
f'Using {accelerator.num_processes} devices with data parallelism'
)
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
elif accelerator.num_processes == 1 and device_map == 'auto':
logger.info(f'Using {accelerator.num_processes} devices with pipeline parallelism')
logger.info(
f'Using {accelerator.num_processes} devices with pipeline parallelism'
)
self._rank = 0
self._word_size = 1
else:
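For reference, a hedged sketch of where the reworked pruning_config fields come from on the image-only LLaVA path; the checkpoint id is an assumption, and any llava-hf checkpoint whose config exposes image_seq_length and image_token_index would do.

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained('llava-hf/llava-1.5-7b-hf')  # assumed checkpoint
pruning_config = {
    'is_video_model': False,
    'image_token_length': cfg.image_seq_length,  # 576 for the 336px CLIP ViT-L/14 tower
    'image_token_index': cfg.image_token_index,  # id of the <image> placeholder token
}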
165 changes: 165 additions & 0 deletions llmc/models/videollava.py
@@ -0,0 +1,165 @@
from datetime import timedelta
from typing import List, Optional, Tuple, Union

import torch
from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
from accelerate.state import AcceleratorState
from lmms_eval.api.model import lmms
from lmms_eval.models.video_llava import VideoLLaVA as VL
from loguru import logger
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
GenerationConfig, VideoLlavaForConditionalGeneration,
VideoLlavaProcessor)

from llmc.utils.registry_factory import MODEL_REGISTRY

from .llama import Llama


@MODEL_REGISTRY
class VideoLLaVA(Llama):
def __init__(self, config, device_map=None, use_cache=False):
super().__init__(config, device_map, use_cache)

def build_model(self):
self.vlm_model_config = AutoConfig.from_pretrained(
self.model_path, trust_remote_code=True
)
if not self.use_cache:
self.vlm_model_config.text_config.use_cache = False
logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
self.vlm_model = VideoLlavaForConditionalGeneration.from_pretrained(
self.model_path,
config=self.vlm_model_config,
torch_dtype=self.torch_dtype,
low_cpu_mem_usage=True,
)
self.eval_name = 'VideoLLaVAHfEval'
self.mm_model = self.vlm_model
logger.info(f'self.vlm_model : {self.vlm_model}')
self.video_tower = self.vlm_model.video_tower
self.image_tower = self.vlm_model.image_tower
self.vision_projector = self.vlm_model.multi_modal_projector
self.model = self.vlm_model.language_model
self.model_config = self.vlm_model_config.text_config
self.pruning_config = {
'is_video_model': True,
'image_token_length': self.vlm_model_config.image_seq_length,
'video_token_length': self.vlm_model_config.video_seq_length,
'select_layer': self.vlm_model_config.vision_feature_layer,
'select_feature': self.vlm_model_config.vision_feature_select_strategy,
'image_token_index': self.vlm_model_config.image_token_index,
'video_token_index': self.vlm_model_config.video_token_index,
}


@MODEL_REGISTRY
class VideoLLaVAHfEval(VL):
def __init__(
self,
llmc_model,
pretrained: str = 'LanguageBind/Video-LLaVA-7B-hf',
truncation: Optional[bool] = True,
device: Optional[str] = 'cuda:0',
dtype: Optional[Union[str, torch.dtype]] = 'auto',
batch_size: Optional[Union[int, str]] = 1,
trust_remote_code: Optional[bool] = False,
revision=None,
attn_implementation=(
'sdpa' if torch.__version__ > '2.1.2' else 'eager'
),
# inference implementation for attention, can be "sdpa", "eager", "flash_attention_2".
# Seems FA2 is not effective during inference:
# https://discuss.huggingface.co/t/flash-attention-has-no-effect-on-inference/73453/5
device_map='cuda:0',
conv_template='llava_v1',
use_cache=True,
truncate_context=False,
num_frames: int = 8,
# whether to truncate the context in generation,
# set it False for LLaVA-1.6
**kwargs,
) -> None:
lmms.__init__(self)
accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
if accelerator.num_processes > 1:
self._device = torch.device(f'cuda:{accelerator.local_process_index}')
self.device_map = f'cuda:{accelerator.local_process_index}'
elif accelerator.num_processes == 1 and device_map == 'auto':
self._device = torch.device(device)
self.device_map = device_map
else:
self._device = torch.device(f'cuda:{accelerator.local_process_index}')
self.device_map = f'cuda:{accelerator.local_process_index}'

self.pretrained = pretrained
self._model = llmc_model.cuda()
self._processor = VideoLlavaProcessor.from_pretrained(pretrained)
self.prompt = 'USER: <video>{}? ASSISTANT:'
self.num_frames = num_frames
assert (
num_frames == 8
), 'num_frames must be 8'
# self.model_name = get_model_name_from_path(pretrained)
# self._tokenizer, self._model, self.processor,
# self._max_length = load_pretrained_model(pretrained,
# None, self.model_name, device_map=self.device_map)
# self.video_processor = self.processor["video"]
self._config = self._model.config
self.model.eval()
self.model.tie_weights()
self.truncation = truncation
self.batch_size_per_gpu = int(batch_size)
self.conv_template = conv_template
self.use_cache = use_cache
self.truncate_context = truncate_context
# assert self.batch_size_per_gpu == 1,
# "Llava currently does not support batched generation.
# See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue."
if accelerator.num_processes > 1:
assert accelerator.distributed_type in [
DistributedType.FSDP,
DistributedType.MULTI_GPU,
DistributedType.DEEPSPEED,
], 'Unsupported distributed type provided. Only DDP and FSDP are supported.'
if accelerator.distributed_type == DistributedType.DEEPSPEED:
kwargs = {
'train_micro_batch_size_per_gpu': self.batch_size_per_gpu,
'train_batch_size': self.batch_size_per_gpu
* accelerator.num_processes,
}
AcceleratorState().deepspeed_plugin.deepspeed_config_process(
must_match=True, **kwargs
)
logger.info(
'Detected that you are using DistributedType.DEEPSPEED. ' +
'Make sure you run `accelerate config` and set zero stage to 0'
)
if (
accelerator.distributed_type == DistributedType.FSDP
or accelerator.distributed_type == DistributedType.DEEPSPEED
):
self._model = accelerator.prepare(self.model)
else:
self._model = accelerator.prepare_model(
self.model, evaluation_mode=True
)
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
logger.info(
f'Using {accelerator.num_processes} devices with data parallelism'
)
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
elif accelerator.num_processes == 1 and device_map == 'auto':
logger.info(
f'Using {accelerator.num_processes} devices with tensor parallelism'
)
self._rank = 0
self._word_size = 1
else:
logger.info(f'Using single device: {self._device}')
self.model.to(self._device)
self._rank = 0
self._world_size = 1
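For reference, a hedged sketch of how the evaluator's prompt template and VideoLlavaProcessor fit together; the random frames stand in for a real decoded clip, and the checkpoint id simply mirrors the class default above.

import numpy as np
from transformers import VideoLlavaProcessor

processor = VideoLlavaProcessor.from_pretrained('LanguageBind/Video-LLaVA-7B-hf')
# 8 uniformly sampled frames, matching the num_frames == 8 assertion above
frames = np.random.randint(0, 255, size=(8, 224, 224, 3), dtype=np.uint8)
text = 'USER: <video>{}? ASSISTANT:'.format('What is happening in the video')
inputs = processor(text=text, videos=frames, return_tensors='pt')
print(sorted(inputs.keys()))  # input_ids, attention_mask, and the video pixel tensor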