diff --git a/llmc/compression/token_reduction/pyramiddrop.py b/llmc/compression/token_reduction/pyramiddrop.py
index 9be07ad99..680e2bc2d 100644
--- a/llmc/compression/token_reduction/pyramiddrop.py
+++ b/llmc/compression/token_reduction/pyramiddrop.py
@@ -30,9 +30,24 @@ def add_sparse_config(self):
             'tokenizer_padding_side', 'right',
         )
-        special_config['image_token_index'] = self.model.pruning_config[
-            'image_token_index'
-        ]
+        special_config['is_video_model'] = self.model.pruning_config['is_video_model']
+
+        # vision_token can be image or video
+        if special_config['is_video_model']:
+            special_config['vision_token_index'] = self.model.pruning_config[
+                'video_token_index'
+            ]
+            special_config['vision_token_length'] = self.model.pruning_config[
+                'video_token_length'
+            ]
+        else:
+            special_config['vision_token_index'] = self.model.pruning_config[
+                'image_token_index'
+            ]
+            special_config['vision_token_length'] = self.model.pruning_config[
+                'image_token_length'
+            ]
+
         self.model.model.parameters = special_config
 
     def register_reduction_modules(self):
@@ -56,6 +71,10 @@ def pruning_hook(module, args, kwargs, pruning_pars, cur_num, layer_idx):
             image_token_posi = pruning_pars['image_token_posi']
             image_token_ratio_list = pruning_pars['image_token_ratio_list']
 
+            # for decoding stage
+            if features.shape[1] == 1:
+                return args, kwargs
+
             if position_ids is None:
                 position_ids = torch.arange(
                     0, features.shape[1], dtype=torch.long, device=features.device
@@ -297,26 +316,31 @@ def pruning_hook(module, args, kwargs, pruning_pars, cur_num, layer_idx):
             return (new_input_embeds,), kwargs
 
         def input_hook(module, input_args, pruning_pars):
+            # for the decoding stage
+            if input_args[0].shape[1] == 1:
+                return input_args
             input_ids = input_args[0]
             pre_prompt_length_list = []
             image_token_posi = []
-            image_tokens = []
-            IMAGE_TOKEN_INDEX = pruning_pars['image_token_index']
+            vision_tokens = []
+            VISION_TOKEN_INDEX = pruning_pars['vision_token_index']
 
             # find the position of the first image token
             for seq in input_ids:
-                image_token_idxs = (seq == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0]
-                image_tokens.append(image_token_idxs.shape[0])
+                image_token_idxs = (seq == VISION_TOKEN_INDEX).nonzero(as_tuple=True)[0]
+                vision_tokens.append(pruning_pars['vision_token_length'])
                 image_token_posi.append(image_token_idxs[0].item())
                 pre_prompt_length_list.append(seq.shape[0] - image_token_idxs.shape[0])
 
             pruning_pars['prompt_len'] = pre_prompt_length_list
             pruning_pars['image_token_posi'] = image_token_posi
-            pruning_pars['image_tokens'] = image_tokens
+            pruning_pars['image_tokens'] = vision_tokens
 
             return input_args
 
         def read_parameter_hook(module, args, kwargs, pruning_pars):
+            if args[0].shape[1] == 1:
+                return args, kwargs
             kwargs['attention_mask'] = pruning_pars['attention_mask']
             # kwargs['cache_position'] = pruning_pars['cache_position']
             kwargs['position_ids'] = pruning_pars['position_ids']
diff --git a/llmc/models/__init__.py b/llmc/models/__init__.py
index d57751a5e..85d6ffcc5 100755
--- a/llmc/models/__init__.py
+++ b/llmc/models/__init__.py
@@ -26,6 +26,7 @@
 from .smollm import SmolLM
 from .stablelm import StableLm
 from .starcoder import Starcoder
+from .videollava import VideoLLaVA
 from .vila import Vila
 from .vit import Vit
 from .wan_i2v import WanI2V
diff --git a/llmc/models/llava.py b/llmc/models/llava.py
index 926dfbcab..9d488f8f7 100644
--- a/llmc/models/llava.py
+++ b/llmc/models/llava.py
@@ -41,11 +41,12 @@ def build_model(self):
         self.model = self.vlm_model.language_model
         self.model_config = self.vlm_model_config.text_config
         self.pruning_config = {
+            'is_video_model': False,
             'image_token_start_index': 5,
-            'image_token_length': 576,
+            'image_token_length': self.vlm_model_config.image_seq_length,
             'select_layer': self.vlm_model_config.vision_feature_layer,
             'select_feature': self.vlm_model_config.vision_feature_select_strategy,
-            'image_token_index': self.vlm_model_config.image_token_index
+            'image_token_index': self.vlm_model_config.image_token_index,
         }
         self.processor = AutoProcessor.from_pretrained(self.model_path)
 
@@ -53,7 +54,13 @@ def build_model(self):
     def get_extra_rot_module_besides_embed_layers(self):
         return [self.vision_projector.linear_2]
 
-    def batch_process(self, img_qas, calib_or_eval='eval', apply_chat_template=True, return_inputs=True):  # noqa
+    def batch_process(
+        self,
+        img_qas,
+        calib_or_eval='eval',
+        apply_chat_template=True,
+        return_inputs=True,
+    ):  # noqa
         assert calib_or_eval == 'calib' or calib_or_eval == 'eval'
         assert apply_chat_template
         messages = []
@@ -68,8 +75,8 @@ def batch_process(self, img_qas, calib_or_eval='eval', apply_chat_template=True,
                         'role': 'user',
                         'content': [
                             {'type': 'image'},
-                            {'type': 'text', 'text': img_qas[idx]['question']}
-                        ]
+                            {'type': 'text', 'text': img_qas[idx]['question']},
+                        ],
                     }
                 ]
                 images.append(image)
@@ -77,9 +84,7 @@ def batch_process(self, img_qas, calib_or_eval='eval', apply_chat_template=True,
                 message = [
                     {
                         'role': 'user',
-                        'content': [
-                            {'type': 'text', 'text': img_qas[idx]['question']}
-                        ]
+                        'content': [{'type': 'text', 'text': img_qas[idx]['question']}],
                     }
                 ]
             messages.append(message)
@@ -89,10 +94,7 @@ def batch_process(self, img_qas, calib_or_eval='eval', apply_chat_template=True,
             for n in range(len(messages))
         ]
         if calib_or_eval == 'calib' and self.config['calib'].get('add_answer', False):
-            texts = [
-                texts[n] + ' ' + answers[n]
-                for n in range(len(texts))
-            ]
+            texts = [texts[n] + ' ' + answers[n] for n in range(len(texts))]
         if calib_or_eval == 'calib':
             logger.info(f'Calib data is:\n{texts}')
         if not return_inputs:
@@ -101,8 +103,10 @@ def batch_process(self, img_qas, calib_or_eval='eval', apply_chat_template=True,
             text=texts,
             images=images if len(images) else None,
             padding=True,
-            return_tensors='pt'
-        ).to(next(self.vlm_model.parameters()).dtype)  # noqa
+            return_tensors='pt',
+        ).to(
+            next(self.vlm_model.parameters()).dtype
+        )  # noqa
         return inputs
 
     def find_blocks(self):
@@ -162,7 +166,7 @@ def get_subsets_in_block(self, block):
                     'inspect': block.mlp.fc2,
                     'has_kwargs': False,
                     'is_mlp': True,
-                    'do_trans': False
+                    'do_trans': False,
                 },
             ]
         else:
@@ -204,8 +208,9 @@ def __init__(
 
         self._model = llmc_model.cuda()
         self.pretrained = pretrained
-        self._image_processor = AutoProcessor.from_pretrained(pretrained, revision=revision,
-                                                              trust_remote_code=trust_remote_code)
+        self._image_processor = AutoProcessor.from_pretrained(
+            pretrained, revision=revision, trust_remote_code=trust_remote_code
+        )
         # Pad from left for batched generation:
         # https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava#usage-tips
         self._image_processor.tokenizer.padding_side = 'left'
@@ -218,24 +223,36 @@ def __init__(
             if accelerator.distributed_type == DistributedType.DEEPSPEED:
                 kwargs = {
                     'train_micro_batch_size_per_gpu': self.batch_size_per_gpu,
-                    'train_batch_size': self.batch_size_per_gpu * accelerator.num_processes,
+                    'train_batch_size': self.batch_size_per_gpu
+                    * accelerator.num_processes,
                 }
                 AcceleratorState().deepspeed_plugin.deepspeed_config_process(
-                    must_match=True, **kwargs)
-                logger.info('Detected that you are using DistributedType.DEEPSPEED. \
-                    Make sure you run `accelerate config` and set zero stage to 0')
-            if accelerator.distributed_type == DistributedType.FSDP or \
-                    accelerator.distributed_type == DistributedType.DEEPSPEED:
+                    must_match=True, **kwargs
+                )
+                logger.info(
+                    'Detected that you are using DistributedType.DEEPSPEED. \
+                    Make sure you run `accelerate config` and set zero stage to 0'
+                )
+            if (
+                accelerator.distributed_type == DistributedType.FSDP
+                or accelerator.distributed_type == DistributedType.DEEPSPEED
+            ):
                 self._model = accelerator.prepare(self.model)
             else:
-                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
+                self._model = accelerator.prepare_model(
+                    self.model, evaluation_mode=True
+                )
             self.accelerator = accelerator
             if self.accelerator.is_local_main_process:
-                logger.info(f'Using {accelerator.num_processes} devices with data parallelism')
+                logger.info(
+                    f'Using {accelerator.num_processes} devices with data parallelism'
+                )
             self._rank = self.accelerator.local_process_index
             self._world_size = self.accelerator.num_processes
         elif accelerator.num_processes == 1 and device_map == 'auto':
-            logger.info(f'Using {accelerator.num_processes} devices with pipeline parallelism')
+            logger.info(
+                f'Using {accelerator.num_processes} devices with pipeline parallelism'
+            )
             self._rank = 0
             self._word_size = 1
         else:
diff --git a/llmc/models/videollava.py b/llmc/models/videollava.py
new file mode 100644
index 000000000..8d7931164
--- /dev/null
+++ b/llmc/models/videollava.py
@@ -0,0 +1,165 @@
+from datetime import timedelta
+from typing import List, Optional, Tuple, Union
+
+import torch
+from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
+from accelerate.state import AcceleratorState
+from lmms_eval.api.model import lmms
+from lmms_eval.models.video_llava import VideoLLaVA as VL
+from loguru import logger
+from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
+                          GenerationConfig, VideoLlavaForConditionalGeneration,
+                          VideoLlavaProcessor)
+
+from llmc.utils.registry_factory import MODEL_REGISTRY
+
+from .llama import Llama
+
+
+@MODEL_REGISTRY
+class VideoLLaVA(Llama):
+    def __init__(self, config, device_map=None, use_cache=False):
+        super().__init__(config, device_map, use_cache)
+
+    def build_model(self):
+        self.vlm_model_config = AutoConfig.from_pretrained(
+            self.model_path, trust_remote_code=True
+        )
+        if not self.use_cache:
+            self.vlm_model_config.text_config.use_cache = False
+        logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
+        self.vlm_model = VideoLlavaForConditionalGeneration.from_pretrained(
+            self.model_path,
+            config=self.vlm_model_config,
+            torch_dtype=self.torch_dtype,
+            low_cpu_mem_usage=True,
+        )
+        self.eval_name = 'VideoLLaVAHfEval'
+        self.mm_model = self.vlm_model
+        logger.info(f'self.vlm_model : {self.vlm_model}')
+        self.video_tower = self.vlm_model.video_tower
+        self.image_tower = self.vlm_model.image_tower
+        self.vision_projector = self.vlm_model.multi_modal_projector
+        self.model = self.vlm_model.language_model
+        self.model_config = self.vlm_model_config.text_config
+        self.pruning_config = {
+            'is_video_model': True,
+            'image_token_length': self.vlm_model_config.image_seq_length,
+            'video_token_length': self.vlm_model_config.video_seq_length,
+            'select_layer': self.vlm_model_config.vision_feature_layer,
+            'select_feature': self.vlm_model_config.vision_feature_select_strategy,
+            'image_token_index': self.vlm_model_config.image_token_index,
+            'video_token_index': self.vlm_model_config.video_token_index,
+        }
+
+
+@MODEL_REGISTRY
+class VideoLLaVAHfEval(VL):
+    def __init__(
+        self,
+        llmc_model,
+        pretrained: str = 'LanguageBind/Video-LLaVA-7B-hf',
+        truncation: Optional[bool] = True,
+        device: Optional[str] = 'cuda:0',
+        dtype: Optional[Union[str, torch.dtype]] = 'auto',
+        batch_size: Optional[Union[int, str]] = 1,
+        trust_remote_code: Optional[bool] = False,
+        revision=None,
+        attn_implementation=(
+            'sdpa' if torch.__version__ > '2.1.2' else 'eager'
+        ),
+        # inference implementation for attention, can be "sdpa", "eager", "flash_attention_2".
+        # Seems FA2 is not effective during inference:
+        # https://discuss.huggingface.co/t/flash-attention-has-no-effect-on-inference/73453/5
+        device_map='cuda:0',
+        conv_template='llava_v1',
+        use_cache=True,
+        truncate_context=False,
+        num_frames: int = 8,
+        # whether to truncate the context in generation,
+        # set it False for LLaVA-1.6
+        **kwargs,
+    ) -> None:
+        lmms.__init__(self)
+        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
+        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
+        if accelerator.num_processes > 1:
+            self._device = torch.device(f'cuda:{accelerator.local_process_index}')
+            self.device_map = f'cuda:{accelerator.local_process_index}'
+        elif accelerator.num_processes == 1 and device_map == 'auto':
+            self._device = torch.device(device)
+            self.device_map = device_map
+        else:
+            self._device = torch.device(f'cuda:{accelerator.local_process_index}')
+            self.device_map = f'cuda:{accelerator.local_process_index}'
+
+        self.pretrained = pretrained
+        self._model = llmc_model.cuda()
+        self._processor = VideoLlavaProcessor.from_pretrained(pretrained)
+        self.prompt = 'USER: