28 changes: 21 additions & 7 deletions llmc/compression/token_reduction/fastv.py
@@ -6,6 +6,7 @@
 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

 from .token_reduction_module import TokenReductionModule
+from .utils import prefill_wrapper


 @TOKEN_REDUCTION_REGISTRY.register('FastV')
@@ -16,18 +17,25 @@ def __init__(self, config, model, blocks):
         self.register_reduction_modules()

     def add_sparse_config(self):
-        special_config = self.config.get('special', {})
-        self.pruning_loc = special_config['pruning_loc']
-        special_config['image_token_start_index'] = \
-            self.model.pruning_config['image_token_start_index']
-        special_config['image_token_length'] = \
+
+        self.pruning_loc = self.special_config['pruning_loc']
+        self.special_config['image_token_length'] = \
             self.model.pruning_config['image_token_length']
-        special_config['attn_scores'] = None
+        self.special_config['attn_scores'] = None

-        self.model.model.parameters = special_config
+        self.model.model.parameters = self.special_config

     def register_reduction_modules(self):

+        @prefill_wrapper
+        def input_hook(module, input_args, pruning_pars):
+            input_ids = input_args[0]
+            image_token_idxs = (input_ids[0] ==
+                                pruning_pars['vision_token_index']).nonzero(as_tuple=True)[0]
+            pruning_pars['image_token_start_index'] = image_token_idxs[0].item()
+
+            return input_args
+
         def update_output_attentions_hook(module, args, kwargs):
             kwargs['output_attentions'] = True
             return args, kwargs
@@ -36,6 +44,7 @@ def store_attention_hook(m, x, layer_outputs, pruning_pars):
             layer_attention = layer_outputs[1]
             pruning_pars['attn_scores'] = layer_attention

+        @prefill_wrapper
         def fastv_pruning_hook(module, args, kwargs, pruning_pars):

             rate = pruning_pars['rate']
@@ -96,6 +105,7 @@ def fastv_pruning_hook(module, args, kwargs, pruning_pars):

             return (hidden_states,), kwargs

+        @prefill_wrapper
         def read_parameter_hook(module, args, kwargs, pruning_pars):
             kwargs['attention_mask'] = pruning_pars['attention_mask']
             kwargs['cache_position'] = pruning_pars['cache_position']
@@ -104,6 +114,10 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):

             return args, kwargs

+        self.model.embed_tokens.register_forward_pre_hook(
+            functools.partial(input_hook, pruning_pars=self.model.model.parameters)
+        )
+
         self.blocks[self.pruning_loc - 1].register_forward_pre_hook(
             update_output_attentions_hook,
             with_kwargs=True
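Note on the FastV change above: the hardcoded `image_token_start_index` (previously pinned to 5 in `llava.py`, see the last file in this diff) is now computed at prefill time by the new `input_hook`, which scans the prompt for the model's vision token id. A minimal sketch of that lookup, with a made-up token id and prompt:

```python
import torch

# Illustrative values only: the real id comes from model.pruning_config
# and reaches the hook as pruning_pars['vision_token_index'].
vision_token_index = 32000
input_ids = torch.tensor([[1, 3148, 1001, 32000, 32000, 32000, 13]])

# Same lookup as input_hook: indices of all vision tokens in the prompt;
# the first one becomes 'image_token_start_index'.
image_token_idxs = (input_ids[0] == vision_token_index).nonzero(as_tuple=True)[0]
print(image_token_idxs[0].item())  # -> 3
```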
42 changes: 12 additions & 30 deletions llmc/compression/token_reduction/pyramiddrop.py
@@ -10,6 +10,7 @@
 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

 from .token_reduction_module import TokenReductionModule
+from .utils import prefill_wrapper


 @TOKEN_REDUCTION_REGISTRY.register('PyramidDrop')
@@ -20,38 +21,21 @@ def __init__(self, config, model, blocks):
         self.register_reduction_modules()

     def add_sparse_config(self):
-        special_config = self.config.get('special', {})
-        self.pruning_loc = special_config['layer_list']
-        image_token_ratio_list = special_config['image_token_ratio_list']
+
+        self.pruning_loc = self.special_config['layer_list']
+        image_token_ratio_list = self.special_config['image_token_ratio_list']
         image_token_ratio_list.insert(0, 1.0)
-        special_config['image_token_ratio_list'] = image_token_ratio_list
-        special_config['tokenizer_padding_side'] = getattr(
+        self.special_config['image_token_ratio_list'] = image_token_ratio_list
+        self.special_config['tokenizer_padding_side'] = getattr(
             self.model.vlm_model.language_model.model.config,
             'tokenizer_padding_side',
             'right',
         )
-        special_config['is_video_model'] = self.model.pruning_config['is_video_model']
-
-        # vision_token can be image or video
-        if special_config['is_video_model']:
-            special_config['vision_token_index'] = self.model.pruning_config[
-                'video_token_index'
-            ]
-            special_config['vision_token_length'] = self.model.pruning_config[
-                'video_token_length'
-            ]
-        else:
-            special_config['vision_token_index'] = self.model.pruning_config[
-                'image_token_index'
-            ]
-            special_config['vision_token_length'] = self.model.pruning_config[
-                'image_token_length'
-            ]
-
-        self.model.model.parameters = special_config
+        self.model.model.parameters = self.special_config

     def register_reduction_modules(self):
+        @prefill_wrapper
         def pruning_hook(module, args, kwargs, pruning_pars, cur_num, layer_idx):

             if layer_idx == self.pruning_loc[0]:
@@ -315,10 +299,9 @@ def pruning_hook(module, args, kwargs, pruning_pars, cur_num, layer_idx):

             return (new_input_embeds,), kwargs

+        @prefill_wrapper
         def input_hook(module, input_args, pruning_pars):
-            # for the decoding stage
-            if input_args[0].shape[1] == 1:
-                return input_args
-
             input_ids = input_args[0]
             pre_prompt_length_list = []
             image_token_posi = []
@@ -338,9 +321,8 @@ def input_hook(module, input_args, pruning_pars):

             return input_args

+        @prefill_wrapper
         def read_parameter_hook(module, args, kwargs, pruning_pars):
-            if args[0].shape[1] == 1:
-                return args, kwargs
             kwargs['attention_mask'] = pruning_pars['attention_mask']
             # kwargs['cache_position'] = pruning_pars['cache_position']
             kwargs['position_ids'] = pruning_pars['position_ids']
6 changes: 5 additions & 1 deletion llmc/compression/token_reduction/sparsevlm.py
@@ -6,6 +6,7 @@
 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

 from .token_reduction_module import TokenReductionModule
+from .utils import prefill_wrapper, prefill_wrapper_model


 @TOKEN_REDUCTION_REGISTRY.register('SparseVLM')
@@ -29,7 +30,7 @@ def add_sparse_config(self):
         self.model.model.parameters = special_config

     def register_reduction_modules(self):
-
+        @prefill_wrapper
         def input_hook(module, input_args, pruning_pars):
             input_ids = input_args[0]
             pre_prompt_length_list = []
@@ -51,6 +52,7 @@ def input_hook(module, input_args, pruning_pars):

             return input_args

+        @prefill_wrapper_model
         def register_module_pars(module, args, kwargs, pruning_pars):
             pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
             inputs_embeds = kwargs['inputs_embeds']
@@ -92,6 +94,7 @@ def update_output_attentions_hook(module, args, kwargs, pruning_pars, layer_idx):
             kwargs['position_embeddings'] = pruning_pars['position_embeddings']
             return args, kwargs

+        @prefill_wrapper
         def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):

             attn_logits = layer_outputs[1]
@@ -195,6 +198,7 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):

             return new_output

+        @prefill_wrapper
         def read_parameter_hook(module, args, kwargs, pruning_pars):
             kwargs['position_ids'] = pruning_pars['position_ids']
             kwargs['cache_position'] = pruning_pars['cache_position']
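SparseVLM is the one module here that also needs the model-level variant: `register_module_pars` appears to be hooked onto the language model itself with `with_kwargs=True`, so the sequence arrives as `kwargs['inputs_embeds']` rather than as a positional tensor, which is exactly what `prefill_wrapper_model` inspects. A hedged sketch of the difference (hook body and shapes are illustrative, not code from this PR):

```python
import torch

from llmc.compression.token_reduction.utils import prefill_wrapper_model


@prefill_wrapper_model
def register_module_pars(module, args, kwargs, pruning_pars):
    # Illustrative body; the real hook derives pruning state from the embeds.
    return args, kwargs


prefill_kwargs = {'inputs_embeds': torch.zeros(1, 16, 8)}  # prompt prefill
decode_kwargs = {'inputs_embeds': torch.zeros(1, 1, 8)}    # one-token decoding

print(register_module_pars(None, (), prefill_kwargs, {}))  # runs normally
print(register_module_pars(None, (), decode_kwargs, {}))   # None: hook skipped
```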
20 changes: 20 additions & 0 deletions llmc/compression/token_reduction/token_reduction_module.py
@@ -4,6 +4,26 @@ def __init__(self, config, model, blocks):
         self.config = config
         self.model = model
         self.blocks = blocks
+        self.set_sparse_config()
+
+    def set_sparse_config(self):
+        self.special_config = self.config.get('special', {})
+        self.special_config['is_video_model'] = self.model.pruning_config['is_video_model']
+        # vision_token can be image or video
+        if self.special_config['is_video_model']:
+            self.special_config['vision_token_index'] = self.model.pruning_config[
+                'video_token_index'
+            ]
+            self.special_config['vision_token_length'] = self.model.pruning_config[
+                'video_token_length'
+            ]
+        else:
+            self.special_config['vision_token_index'] = self.model.pruning_config[
+                'image_token_index'
+            ]
+            self.special_config['vision_token_length'] = self.model.pruning_config[
+                'image_token_length'
+            ]

     def register_reduction_modules(self):
         pass
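With this base-class hook in place, each subclass reads the resolved vision-token metadata from `self.special_config` instead of re-deriving the image/video branch. A hedged sketch of the resulting subclass pattern (`MySparsifier` and its use of a `pruning_loc` key are illustrative, mirroring FastV above):

```python
from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

from .token_reduction_module import TokenReductionModule


@TOKEN_REDUCTION_REGISTRY.register('MySparsifier')  # hypothetical module
class MySparsifier(TokenReductionModule):
    def add_sparse_config(self):
        # vision_token_index, vision_token_length and is_video_model are
        # already populated by TokenReductionModule.set_sparse_config().
        self.pruning_loc = self.special_config['pruning_loc']
        self.model.model.parameters = self.special_config
```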
25 changes: 25 additions & 0 deletions llmc/compression/token_reduction/utils.py
@@ -1,10 +1,35 @@
+from functools import wraps
 from typing import Any, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
 from transformers.models.clip.modeling_clip import CLIPEncoderLayer


+def prefill_wrapper(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        # for the decoding stage
+        if len(args) > 1:
+            input_args = args[1]
+            if hasattr(input_args[0], 'shape') and input_args[0].shape[1] == 1:
+                return None
+        return func(*args, **kwargs)
+    return wrapper
+
+
+def prefill_wrapper_model(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        # for the decoding stage
+        if len(args) > 1:
+            input_args = args[2]['inputs_embeds']
+            if hasattr(input_args, 'shape') and input_args.shape[1] == 1:
+                return None
+        return func(*args, **kwargs)
+    return wrapper
+
+
 def parse_r(num_layers: int, r: Union[List[int], Tuple[int, float], int]) -> List[int]:
     """Copy from the TOME. https://github.com/facebookresearch/ToMe.
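Taken together, the two decorators centralize the prefill-vs-decoding check that the hooks above previously performed inline: a wrapped hook runs during prefill (sequence dimension greater than 1) and returns `None` once generation reaches one-token-at-a-time decoding; returning `None` from a PyTorch forward pre-hook leaves the module's inputs untouched. A quick usage sketch (hook body and tensors are illustrative, and the import path assumes this PR's layout):

```python
import torch

from llmc.compression.token_reduction.utils import prefill_wrapper


@prefill_wrapper
def input_hook(module, input_args, pruning_pars):
    # Illustrative body; real hooks record pruning state here.
    pruning_pars['prefill_seen'] = True
    return input_args


pars = {}
prefill_ids = torch.ones(1, 42, dtype=torch.long)  # prompt: seq_len > 1
decode_ids = torch.ones(1, 1, dtype=torch.long)    # decoding: seq_len == 1

print(input_hook(None, (prefill_ids,), pars))  # runs -> returns input_args
print(input_hook(None, (decode_ids,), pars))   # skipped -> returns None
```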
2 changes: 1 addition & 1 deletion llmc/models/llava.py
@@ -96,7 +96,7 @@ def safe_prepare_inputs_for_generation(
         self.model = self.vlm_model
         self.model_config = self.vlm_model_config.text_config
         self.pruning_config = {
-            'image_token_start_index': 5,
+            'is_video_model': False,
             'image_token_length': self.vlm_model_config.image_seq_length,
             'select_layer': self.vlm_model_config.vision_feature_layer,
             'select_feature': self.vlm_model_config.vision_feature_select_strategy,