diff --git a/llmc/compression/token_reduction/fastv.py b/llmc/compression/token_reduction/fastv.py
index 6c92e880f..e288ad6c3 100644
--- a/llmc/compression/token_reduction/fastv.py
+++ b/llmc/compression/token_reduction/fastv.py
@@ -6,6 +6,7 @@
 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

 from .token_reduction_module import TokenReductionModule
+from .utils import prefill_wrapper


 @TOKEN_REDUCTION_REGISTRY.register('FastV')
@@ -16,18 +17,25 @@ def __init__(self, config, model, blocks):
         self.register_reduction_modules()

     def add_sparse_config(self):
-        special_config = self.config.get('special', {})
-        self.pruning_loc = special_config['pruning_loc']
-        special_config['image_token_start_index'] = \
-            self.model.pruning_config['image_token_start_index']
-        special_config['image_token_length'] = \
+
+        self.pruning_loc = self.special_config['pruning_loc']
+        self.special_config['image_token_length'] = \
             self.model.pruning_config['image_token_length']
-        special_config['attn_scores'] = None
+        self.special_config['attn_scores'] = None

-        self.model.model.parameters = special_config
+        self.model.model.parameters = self.special_config

     def register_reduction_modules(self):

+        @prefill_wrapper
+        def input_hook(module, input_args, pruning_pars):
+            input_ids = input_args[0]
+            image_token_idxs = (input_ids[0] ==
+                                pruning_pars['vision_token_index']).nonzero(as_tuple=True)[0]
+            pruning_pars['image_token_start_index'] = image_token_idxs[0].item()
+
+            return input_args
+
         def update_output_attentions_hook(module, args, kwargs):
             kwargs['output_attentions'] = True
             return args, kwargs
@@ -36,6 +44,7 @@ def store_attention_hook(m, x, layer_outputs, pruning_pars):
             layer_attention = layer_outputs[1]
             pruning_pars['attn_scores'] = layer_attention

+        @prefill_wrapper
         def fastv_pruning_hook(module, args, kwargs, pruning_pars):

             rate = pruning_pars['rate']
@@ -96,6 +105,7 @@ def fastv_pruning_hook(module, args, kwargs, pruning_pars):

             return (hidden_states,), kwargs

+        @prefill_wrapper
         def read_parameter_hook(module, args, kwargs, pruning_pars):
             kwargs['attention_mask'] = pruning_pars['attention_mask']
             kwargs['cache_position'] = pruning_pars['cache_position']
@@ -104,6 +114,10 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):

             return args, kwargs

+        self.model.embed_tokens.register_forward_pre_hook(
+            functools.partial(input_hook, pruning_pars=self.model.model.parameters)
+        )
+
         self.blocks[self.pruning_loc - 1].register_forward_pre_hook(
             update_output_attentions_hook,
             with_kwargs=True
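Note: a minimal sketch of the index computation the new `input_hook` above performs, which replaces the previously hard-coded `image_token_start_index` (see the `llava.py` change at the end of this diff). The token id `32000` and the `input_ids` values are illustrative placeholders, not values taken from this patch.

```python
import torch

# Hypothetical example: assume image positions are marked with token id 32000.
pruning_pars = {'vision_token_index': 32000}
input_ids = torch.tensor([[1, 3148, 1001, 32000, 32000, 32000, 29889]])

# Same logic as the new input_hook: locate where the vision tokens start.
image_token_idxs = (input_ids[0] ==
                    pruning_pars['vision_token_index']).nonzero(as_tuple=True)[0]
pruning_pars['image_token_start_index'] = image_token_idxs[0].item()

print(pruning_pars['image_token_start_index'])  # -> 3
```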
diff --git a/llmc/compression/token_reduction/pyramiddrop.py b/llmc/compression/token_reduction/pyramiddrop.py
index 680e2bc2d..04ece8e3d 100644
--- a/llmc/compression/token_reduction/pyramiddrop.py
+++ b/llmc/compression/token_reduction/pyramiddrop.py
@@ -10,6 +10,7 @@
 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

 from .token_reduction_module import TokenReductionModule
+from .utils import prefill_wrapper


 @TOKEN_REDUCTION_REGISTRY.register('PyramidDrop')
@@ -20,38 +21,21 @@ def __init__(self, config, model, blocks):
         self.register_reduction_modules()

     def add_sparse_config(self):
-        special_config = self.config.get('special', {})
-        self.pruning_loc = special_config['layer_list']
-        image_token_ratio_list = special_config['image_token_ratio_list']
+
+        self.pruning_loc = self.special_config['layer_list']
+        image_token_ratio_list = self.special_config['image_token_ratio_list']
         image_token_ratio_list.insert(0, 1.0)
-        special_config['image_token_ratio_list'] = image_token_ratio_list
-        special_config['tokenizer_padding_side'] = getattr(
+        self.special_config['image_token_ratio_list'] = image_token_ratio_list
+        self.special_config['tokenizer_padding_side'] = getattr(
             self.model.vlm_model.language_model.model.config,
             'tokenizer_padding_side',
             'right',
         )
-        special_config['is_video_model'] = self.model.pruning_config['is_video_model']
-
-        # vision_token can be image or video
-        if special_config['is_video_model']:
-            special_config['vision_token_index'] = self.model.pruning_config[
-                'video_token_index'
-            ]
-            special_config['vision_token_length'] = self.model.pruning_config[
-                'video_token_length'
-            ]
-        else:
-            special_config['vision_token_index'] = self.model.pruning_config[
-                'image_token_index'
-            ]
-            special_config['vision_token_length'] = self.model.pruning_config[
-                'image_token_length'
-            ]
-
-        self.model.model.parameters = special_config
-    def register_reduction_modules(self):
+        self.model.model.parameters = self.special_config

+    def register_reduction_modules(self):
+        @prefill_wrapper
         def pruning_hook(module, args, kwargs, pruning_pars, cur_num, layer_idx):

             if layer_idx == self.pruning_loc[0]:
@@ -315,10 +299,9 @@ def pruning_hook(module, args, kwargs, pruning_pars, cur_num, layer_idx):

             return (new_input_embeds,), kwargs

+        @prefill_wrapper
         def input_hook(module, input_args, pruning_pars):
-            # for the decoding stage
-            if input_args[0].shape[1] == 1:
-                return input_args
+
             input_ids = input_args[0]
             pre_prompt_length_list = []
             image_token_posi = []
@@ -338,9 +321,8 @@ def input_hook(module, input_args, pruning_pars):

             return input_args

+        @prefill_wrapper
         def read_parameter_hook(module, args, kwargs, pruning_pars):
-            if args[0].shape[1] == 1:
-                return args, kwargs
             kwargs['attention_mask'] = pruning_pars['attention_mask']
             # kwargs['cache_position'] = pruning_pars['cache_position']
             kwargs['position_ids'] = pruning_pars['position_ids']
diff --git a/llmc/compression/token_reduction/sparsevlm.py b/llmc/compression/token_reduction/sparsevlm.py
index d24468848..92b9659c6 100755
--- a/llmc/compression/token_reduction/sparsevlm.py
+++ b/llmc/compression/token_reduction/sparsevlm.py
@@ -6,6 +6,7 @@
 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

 from .token_reduction_module import TokenReductionModule
+from .utils import prefill_wrapper, prefill_wrapper_model


 @TOKEN_REDUCTION_REGISTRY.register('SparseVLM')
@@ -29,7 +30,7 @@ def add_sparse_config(self):
         self.model.model.parameters = special_config

     def register_reduction_modules(self):
-
+        @prefill_wrapper
         def input_hook(module, input_args, pruning_pars):
             input_ids = input_args[0]
             pre_prompt_length_list = []
@@ -51,6 +52,7 @@ def input_hook(module, input_args, pruning_pars):

             return input_args

+        @prefill_wrapper_model
         def register_module_pars(module, args, kwargs, pruning_pars):
             pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
             inputs_embeds = kwargs['inputs_embeds']
@@ -92,6 +94,7 @@ def update_output_attentions_hook(module, args, kwargs, pruning_pars, layer_idx)
             kwargs['position_embeddings'] = pruning_pars['position_embeddings']
             return args, kwargs

+        @prefill_wrapper
         def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):
             attn_logits = layer_outputs[1]
@@ -195,6 +198,7 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer

             return new_output

+        @prefill_wrapper
         def read_parameter_hook(module, args, kwargs, pruning_pars):
             kwargs['position_ids'] = pruning_pars['position_ids']
             kwargs['cache_position'] = pruning_pars['cache_position']
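Note: the inline decode-stage guards deleted above (`if args[0].shape[1] == 1: ...`) are centralized in the `prefill_wrapper` decorator added to `utils.py` below. A minimal, self-contained sketch of that behavior; the `torch.nn.Identity` layer and `toy_hook` are stand-ins, not part of the patch:

```python
import torch

from llmc.compression.token_reduction.utils import prefill_wrapper


@prefill_wrapper
def toy_hook(module, input_args):
    # Only reached during prefill; the wrapper short-circuits decode steps.
    print('prefill, seq_len =', input_args[0].shape[1])
    return input_args


layer = torch.nn.Identity()
layer.register_forward_pre_hook(toy_hook)

layer(torch.randn(1, 8, 16))  # prefill (seq_len == 8): hook body runs
layer(torch.randn(1, 1, 16))  # decode (seq_len == 1): wrapper returns None
```

Returning `None` from a forward pre-hook leaves the module inputs unmodified, so the wrapper is behaviorally equivalent to the early `return input_args` / `return args, kwargs` paths it replaces.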
diff --git a/llmc/compression/token_reduction/token_reduction_module.py b/llmc/compression/token_reduction/token_reduction_module.py
index 7391d6418..0645ef538 100644
--- a/llmc/compression/token_reduction/token_reduction_module.py
+++ b/llmc/compression/token_reduction/token_reduction_module.py
@@ -4,6 +4,26 @@ def __init__(self, config, model, blocks):
         self.config = config
         self.model = model
         self.blocks = blocks
+        self.set_sparse_config()
+
+    def set_sparse_config(self):
+        self.special_config = self.config.get('special', {})
+        self.special_config['is_video_model'] = self.model.pruning_config['is_video_model']
+        # vision_token can be image or video
+        if self.special_config['is_video_model']:
+            self.special_config['vision_token_index'] = self.model.pruning_config[
+                'video_token_index'
+            ]
+            self.special_config['vision_token_length'] = self.model.pruning_config[
+                'video_token_length'
+            ]
+        else:
+            self.special_config['vision_token_index'] = self.model.pruning_config[
+                'image_token_index'
+            ]
+            self.special_config['vision_token_length'] = self.model.pruning_config[
+                'image_token_length'
+            ]

     def register_reduction_modules(self):
         pass
diff --git a/llmc/compression/token_reduction/utils.py b/llmc/compression/token_reduction/utils.py
index a293c61de..8a341ce79 100755
--- a/llmc/compression/token_reduction/utils.py
+++ b/llmc/compression/token_reduction/utils.py
@@ -1,3 +1,4 @@
+from functools import wraps
 from typing import Any, List, Optional, Tuple, Union

 import torch
@@ -5,6 +6,30 @@
 from transformers.models.clip.modeling_clip import CLIPEncoderLayer


+def prefill_wrapper(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        # for the decoding stage
+        if len(args) > 1:
+            input_args = args[1]
+            if hasattr(input_args[0], 'shape') and input_args[0].shape[1] == 1:
+                return None
+        return func(*args, **kwargs)
+    return wrapper
+
+
+def prefill_wrapper_model(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        # for the decoding stage
+        if len(args) > 1:
+            input_args = args[2]['inputs_embeds']
+            if hasattr(input_args, 'shape') and input_args.shape[1] == 1:
+                return None
+        return func(*args, **kwargs)
+    return wrapper
+
+
 def parse_r(num_layers: int, r: Union[List[int], Tuple[int, float], int]) -> List[int]:
     """Copy from the TOME.

     https://github.com/facebookresearch/ToMe.
diff --git a/llmc/models/llava.py b/llmc/models/llava.py
index 450db09a0..3256b65b6 100644
--- a/llmc/models/llava.py
+++ b/llmc/models/llava.py
@@ -96,7 +96,7 @@ def safe_prepare_inputs_for_generation(
         self.model = self.vlm_model
         self.model_config = self.vlm_model_config.text_config
         self.pruning_config = {
-            'image_token_start_index': 5,
+            'is_video_model': False,
             'image_token_length': self.vlm_model_config.image_seq_length,
             'select_layer': self.vlm_model_config.vision_feature_layer,
             'select_feature': self.vlm_model_config.vision_feature_select_strategy,
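Note: a companion sketch for `prefill_wrapper_model`, which guards kwargs-style pre-hooks such as `register_module_pars`, where the sequence length is read from `kwargs['inputs_embeds']` rather than from the positional args. `ToyModel` and the tensor shapes are illustrative assumptions, not part of the patch:

```python
import torch

from llmc.compression.token_reduction.utils import prefill_wrapper_model


class ToyModel(torch.nn.Module):
    # Stand-in for a language model whose forward takes inputs_embeds as a kwarg.
    def forward(self, inputs_embeds=None):
        return inputs_embeds


@prefill_wrapper_model
def model_hook(module, args, kwargs):
    # Inside the wrapper, args[2] is this kwargs dict.
    print('prefill, seq_len =', kwargs['inputs_embeds'].shape[1])
    return args, kwargs


model = ToyModel()
model.register_forward_pre_hook(model_hook, with_kwargs=True)

model(inputs_embeds=torch.randn(1, 8, 16))  # prefill: hook body runs
model(inputs_embeds=torch.randn(1, 1, 16))  # decode: wrapper returns None
```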