2 changes: 1 addition & 1 deletion configs/sparsification/methods/FastV/fastv.yml
@@ -17,7 +17,7 @@ sparse:
special:
method: FastV
pruning_loc: 3
rate: 0.5
rate: 0.778
save:
save_trans: False
save_fake: False
2 changes: 1 addition & 1 deletion configs/sparsification/methods/FasterVLM/fastervlm.yml
@@ -16,7 +16,7 @@ sparse:
method: TokenReduction
special:
method: FasterVLM
rate: 0.75
rate: 0.778
save:
save_trans: False
save_fake: False
11 changes: 6 additions & 5 deletions configs/sparsification/methods/VisionZip/visionzip.yml
@@ -13,11 +13,12 @@ eval:
bs: 1
inference_per_block: False
sparse:
method: TokenReduction
special:
method: VisionZip
dominant: 191
contextual: 30
vision:
method: TokenReduction
special:
method: VisionZip
dominant: 191 # visual_tokens = dominant_tokens + 1 (cls_token)
contextual: 30
save:
save_trans: False
save_fake: False
122 changes: 86 additions & 36 deletions llmc/compression/token_reduction/fastervlm.py
@@ -24,28 +24,39 @@ def add_sparse_config(self):
special_config['select_feature'] = self.model.pruning_config['select_feature']
special_config['image_token_index'] = self.model.pruning_config['image_token_index']

self.model.model.parameters = special_config
special_config['image_attentions_list'] = []

self.pruning_paras = special_config

def register_reduction_modules(self):

def update_output_attentions_hook(module, args, kwargs):
kwargs['output_attentions'] = True
return args, kwargs

def store_attention_hook(m, x, image_forward_outs, pruning_pars):
image_attentions = image_forward_outs.attentions[pruning_pars['select_layer']]
if pruning_pars['select_feature'] == 'default': # patch
image_attentions = image_attentions[:, :, 0, 1:]
elif pruning_pars['select_feature'] == 'full':
image_attentions = image_attentions
def clear_attentions_hook(m, x, pruning_paras):
pruning_paras['image_attentions_list'].clear()

def store_attention_hook(m, x, image_forward_outs, pruning_paras):
image_attentions = image_forward_outs.attentions[pruning_paras['select_layer']]
if pruning_paras['select_feature'] in ('default', 'patch'):
image_attention = image_attentions[:, :, 0, 1:]
elif pruning_paras['select_feature'] in ('full', 'cls_patch'):
image_attention = image_attentions
else:
raise ValueError(f'Unexpected select feature: {self.select_feature}')

high

The variable self.select_feature is not defined in this class context. Use pruning_paras['select_feature'] instead.

Suggested change
raise ValueError(f'Unexpected select feature: {self.select_feature}')
raise ValueError(f"Unexpected select feature: {pruning_paras['select_feature']}")

pruning_pars['image_attentions'] = image_attentions
pruning_paras['image_attentions_list'].append(image_attention.to(x[0].dtype))

def update_attentions_hook(m, x, outs, pruning_paras):
if len(pruning_paras['image_attentions_list']) == 1:
pruning_paras['image_attentions'] = pruning_paras['image_attentions_list'][0]
else:
pruning_paras['image_attentions'] = pruning_paras['image_attentions_list']
Comment on lines +53 to +54

critical

If len(pruning_paras['image_attentions_list']) is greater than 1, pruning_paras['image_attentions'] will be a list of tensors. The pruning_hook at line 59 expects a single tensor, which will raise an AttributeError. Stack and average the attentions here.

Suggested change
else:
pruning_paras['image_attentions'] = pruning_paras['image_attentions_list']
else:
# If multiple attention tensors are collected, stack and average them.
pruning_paras['image_attentions'] = torch.stack(pruning_paras['image_attentions_list']).mean(dim=0)
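A minimal sketch of the stack-and-average step suggested above, using assumed tensor shapes (a (1, num_heads, 576) attention layout, as implied by the comments elsewhere in this diff):

import torch

# Two hypothetical attention maps of shape (B, num_heads, num_patches) = (1, 8, 576).
attn_list = [torch.rand(1, 8, 576) for _ in range(2)]

# torch.stack gives (len(attn_list), 1, 8, 576); mean over dim=0 restores (1, 8, 576),
# so downstream code that expects a single tensor keeps working.
merged = torch.stack(attn_list).mean(dim=0)
assert merged.shape == attn_list[0].shape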


def pruning_hook(module, args, kwargs, pruning_pars):
def pruning_hook(module, args, kwargs, pruning_paras):

image_features = args[0]
image_attentions = pruning_pars['image_attentions']
image_attentions = pruning_paras['image_attentions']

# image_attentions = image_attentions.max(dim=1)[0] # (B, N) = (1, 576)
image_attentions = image_attentions.mean(dim=1) # (B, N) = (1, 576)
@@ -66,22 +77,22 @@ def pruning_hook(module, args, kwargs, pruning_pars):
index_mask = torch.zeros(B, N, dtype=torch.bool, device=image_features.device) # (B, N)
index_mask.scatter_(1, token_indices, True) # (B, N)

pruning_pars['index_mask'] = index_mask
pruning_pars['image_attentions'] = image_attentions
pruning_paras['index_mask'] = index_mask
pruning_paras['image_attentions'] = image_attentions

return (image_features,), kwargs

def get_image_mask_hook(module, args, kwargs, pruning_pars):
pruning_pars['image_mask'] = (
kwargs['input_ids'] == pruning_pars['image_token_index']
def get_image_mask_hook(module, args, kwargs, pruning_paras):
pruning_paras['image_mask'] = (
kwargs['input_ids'] == pruning_paras['image_token_index']
) # (B, len)

def prepare_inputs_for_llm_hook(module, args, kwargs, pruning_pars):
def prepare_inputs_for_llm_hook(module, args, kwargs, pruning_paras):

# Only batch size 1 is currently supported.
inputs_embeds = kwargs['inputs_embeds']
image_mask = pruning_pars['image_mask'][0]
index_mask = pruning_pars['index_mask'][0]
image_mask = pruning_paras['image_mask'][0]
index_mask = pruning_paras['index_mask'][0]

B, L = inputs_embeds.shape[:2]
device = inputs_embeds.device
@@ -109,28 +120,67 @@ def prepare_inputs_for_llm_hook(module, args, kwargs, pruning_pars):

return args, kwargs

self.model.vision_model.register_forward_pre_hook(
update_output_attentions_hook,
with_kwargs=True
)
def prepare_inputs_hook(module, inputs, outputs, pruning_paras):

self.model.vision_model.register_forward_hook(
functools.partial(store_attention_hook, pruning_pars=self.model.model.parameters),
)
image_features = outputs
index_masks = pruning_paras['index_mask']
# image_attentions = pruning_paras['image_attentions']
new_image_features = []
for image_feature, index_mask in zip(image_features, index_masks):
image_feature = image_feature[index_mask]
new_image_features.append(image_feature)
image_features = torch.stack(new_image_features, dim=0)

outputs = image_features
pruning_paras['image_features_shape'] = image_features[0].shape[0]

return outputs

if self.model.__class__.__name__ == 'LlavaHf':
self.model.vision_model.register_forward_pre_hook(
update_output_attentions_hook,
with_kwargs=True
)

self.model.vision_model.register_forward_hook(
functools.partial(store_attention_hook, pruning_paras=self.pruning_paras),
)
elif self.model.__class__.__name__ == 'Llava':
self.model.vision_model.register_forward_pre_hook(
functools.partial(clear_attentions_hook, pruning_paras=self.pruning_paras),
)

self.model.vision_model.register_forward_hook(
functools.partial(update_attentions_hook, pruning_paras=self.pruning_paras),
)

self.model.vision_model.vision_tower.register_forward_pre_hook(
update_output_attentions_hook,
with_kwargs=True
)

self.model.vision_model.vision_tower.register_forward_hook(
functools.partial(store_attention_hook, pruning_paras=self.pruning_paras),
)

self.model.vision_projector.register_forward_pre_hook(
functools.partial(pruning_hook, pruning_pars=self.model.model.parameters),
functools.partial(pruning_hook, pruning_paras=self.pruning_paras),
with_kwargs=True
)

self.model.vlm_model.register_forward_pre_hook(
functools.partial(get_image_mask_hook, pruning_pars=self.model.model.parameters),
with_kwargs=True
)
if self.model.__class__.__name__ == 'LlavaHf':
self.model.vlm_model.register_forward_pre_hook(
functools.partial(get_image_mask_hook, pruning_paras=self.pruning_paras),
with_kwargs=True
)

self.model.model.register_forward_pre_hook(
functools.partial(
prepare_inputs_for_llm_hook, pruning_pars=self.model.model.parameters
),
with_kwargs=True
)
self.model.model.register_forward_pre_hook(
functools.partial(
prepare_inputs_for_llm_hook, pruning_paras=self.pruning_paras
),
with_kwargs=True
)
elif self.model.__class__.__name__ == 'Llava':
self.model.vision_projector.register_forward_hook(
functools.partial(prepare_inputs_hook, pruning_paras=self.pruning_paras),
)
Comment on lines +139 to +186

medium

Using self.model.__class__.__name__ for conditional logic is brittle. Use isinstance with the specific model classes (e.g., isinstance(self.model, LlavaHf)) for a more robust approach.
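A brief sketch of the isinstance-based dispatch suggested here; the import path and helper functions below are assumptions for illustration, not the repo's actual layout:

# Hypothetical import path; the real location of these classes may differ.
from llmc.models import Llava, LlavaHf


def register_hooks_for(model, pruning_paras):
    # Dispatch on concrete classes rather than comparing class-name strings.
    if isinstance(model, LlavaHf):
        register_llava_hf_hooks(model, pruning_paras)  # hypothetical helper
    elif isinstance(model, Llava):
        register_llava_hooks(model, pruning_paras)  # hypothetical helper
    else:
        raise NotImplementedError(f'Unsupported model: {type(model).__name__}')

One caveat: isinstance respects inheritance, so if one of these classes subclasses the other, the more specific check should come first.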

84 changes: 58 additions & 26 deletions llmc/compression/token_reduction/fastv.py
@@ -1,4 +1,5 @@
import functools
from types import MethodType

import torch

@@ -8,6 +9,8 @@
from .token_reduction_module import TokenReductionModule
from .utils import prefill_wrapper

IMAGE_TOKEN_INDEX = -200


@TOKEN_REDUCTION_REGISTRY.register('FastV')
class FastV(TokenReductionModule):
Expand All @@ -23,41 +26,61 @@ def add_sparse_config(self):
self.model.pruning_config['image_token_length']
self.special_config['attn_scores'] = None

self.model.model.parameters = self.special_config
self.pruning_paras = self.special_config

def register_reduction_modules(self):

@prefill_wrapper
def input_hook(module, input_args, pruning_pars):
def input_hook(module, input_args, pruning_paras):
input_ids = input_args[0]
image_token_idxs = (input_ids[0] ==
pruning_pars['vision_token_index']).nonzero(as_tuple=True)[0]
pruning_pars['image_token_start_index'] = image_token_idxs[0].item()
pruning_paras['vision_token_index']).nonzero(as_tuple=True)[0]
pruning_paras['image_token_start_index'] = image_token_idxs[0].item()

return input_args

def make_hook_prepare_inputs_labels_for_multimodal(pruning_paras):
def hook_prepare_inputs_labels_for_multimodal(
self,
input_ids,
position_ids,
attention_mask,
past_key_values,
labels,
images,
image_sizes
):
if 'image_token_start_index' not in pruning_paras:
token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX
pruning_paras['image_token_start_index'] = torch.where(token_indices)[0].item()
return self._original_prepare_inputs_labels_for_multimodal(
input_ids, position_ids, attention_mask,
past_key_values, labels, images, image_sizes
)
return hook_prepare_inputs_labels_for_multimodal

def update_output_attentions_hook(module, args, kwargs):
kwargs['output_attentions'] = True
return args, kwargs

def store_attention_hook(m, x, layer_outputs, pruning_pars):
def store_attention_hook(m, x, layer_outputs, pruning_paras):
layer_attention = layer_outputs[1]
pruning_pars['attn_scores'] = layer_attention
pruning_paras['attn_scores'] = layer_attention

@prefill_wrapper
def fastv_pruning_hook(module, args, kwargs, pruning_pars):
def fastv_pruning_hook(module, args, kwargs, pruning_paras):

rate = pruning_pars['rate']
image_token_start_index = pruning_pars['image_token_start_index']
image_token_length = pruning_pars['image_token_length']
rate = pruning_paras['rate']
image_token_start_index = pruning_paras['image_token_start_index']
image_token_length = pruning_paras['image_token_length']

hidden_states = args[0]
causal_mask = kwargs['attention_mask']
cache_position = kwargs['cache_position']

device = hidden_states.device
# last_layer_attention = layer_outputs[1]
last_layer_attention = pruning_pars['attn_scores']
last_layer_attention = pruning_paras['attn_scores']
# compute average attention over different head
last_layer_attention_avg = torch.mean(last_layer_attention, dim=1)[0]
# generate new attention mask based on the average attention,
@@ -98,42 +121,51 @@ def fastv_pruning_hook(module, args, kwargs, pruning_pars):
kwargs['cache_position'] = cache_position[:new_seq_length]
kwargs['position_ids'] = position_ids
kwargs['position_embeddings'] = None
pruning_pars['attention_mask'] = causal_mask
pruning_pars['cache_position'] = cache_position[:new_seq_length]
pruning_pars['position_ids'] = position_ids
pruning_pars['position_embeddings'] = None
pruning_paras['attention_mask'] = causal_mask
pruning_paras['cache_position'] = cache_position[:new_seq_length]
pruning_paras['position_ids'] = position_ids
pruning_paras['position_embeddings'] = None

return (hidden_states,), kwargs

@prefill_wrapper
def read_parameter_hook(module, args, kwargs, pruning_pars):
kwargs['attention_mask'] = pruning_pars['attention_mask']
kwargs['cache_position'] = pruning_pars['cache_position']
kwargs['position_ids'] = pruning_pars['position_ids']
kwargs['position_embeddings'] = pruning_pars['position_embeddings']
def read_parameter_hook(module, args, kwargs, pruning_paras):
kwargs['attention_mask'] = pruning_paras['attention_mask']
kwargs['cache_position'] = pruning_paras['cache_position']
kwargs['position_ids'] = pruning_paras['position_ids']
kwargs['position_embeddings'] = pruning_paras['position_embeddings']

return args, kwargs

self.model.embed_tokens.register_forward_pre_hook(
functools.partial(input_hook, pruning_pars=self.model.model.parameters)
)
if self.model.__class__.__name__ == 'LlavaHf':
self.model.embed_tokens.register_forward_pre_hook(
functools.partial(input_hook, pruning_paras=self.pruning_paras)
)
elif self.model.__class__.__name__ == 'Llava':
hook_fn = make_hook_prepare_inputs_labels_for_multimodal(self.pruning_paras)
self.model.vlm_model._original_prepare_inputs_labels_for_multimodal = (
self.model.vlm_model.prepare_inputs_labels_for_multimodal
)
self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
hook_fn, self.model.vlm_model
)
Comment on lines +145 to +151

high

Monkey-patching methods at runtime can make the code difficult to understand and maintain. Consider a more explicit mechanism, like subclassing the model or using a wrapper class. If monkey-patching is the only viable option, add a comment explaining why it's necessary.
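A minimal sketch of the wrapper-class alternative mentioned above, assuming the method signature shown in this diff; class and attribute names are illustrative only:

import torch

IMAGE_TOKEN_INDEX = -200  # matches the module-level constant added in this PR


class PrepareInputsPatch:
    """Explicit, reversible replacement for prepare_inputs_labels_for_multimodal."""

    def __init__(self, vlm_model, pruning_paras):
        self.vlm_model = vlm_model
        self.pruning_paras = pruning_paras
        self.original = vlm_model.prepare_inputs_labels_for_multimodal

    def __call__(self, input_ids, position_ids, attention_mask,
                 past_key_values, labels, images, image_sizes):
        # Record where the image tokens start, then delegate to the original method.
        if 'image_token_start_index' not in self.pruning_paras:
            token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX
            self.pruning_paras['image_token_start_index'] = torch.where(token_indices)[0].item()
        return self.original(input_ids, position_ids, attention_mask,
                             past_key_values, labels, images, image_sizes)

    def apply(self):
        self.vlm_model.prepare_inputs_labels_for_multimodal = self

    def restore(self):
        self.vlm_model.prepare_inputs_labels_for_multimodal = self.original

Usage would then be roughly patch = PrepareInputsPatch(self.model.vlm_model, self.pruning_paras) followed by patch.apply(), with patch.restore() available to undo the change.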


self.blocks[self.pruning_loc - 1].register_forward_pre_hook(
update_output_attentions_hook,
with_kwargs=True
)

self.blocks[self.pruning_loc - 1].register_forward_hook(
functools.partial(store_attention_hook, pruning_pars=self.model.model.parameters),
functools.partial(store_attention_hook, pruning_paras=self.pruning_paras),
)

self.blocks[self.pruning_loc].register_forward_pre_hook(
functools.partial(fastv_pruning_hook, pruning_pars=self.model.model.parameters),
functools.partial(fastv_pruning_hook, pruning_paras=self.pruning_paras),
with_kwargs=True
)

for idx in range(self.pruning_loc + 1, len(self.blocks)):
self.blocks[idx].register_forward_pre_hook(
functools.partial(read_parameter_hook, pruning_pars=self.model.model.parameters),
functools.partial(read_parameter_hook, pruning_paras=self.pruning_paras),
with_kwargs=True
)
1 change: 1 addition & 0 deletions llmc/compression/token_reduction/utils.py
@@ -93,6 +93,7 @@ def apply_info(model, dominant_num, contextual_num):
for module in model.modules():
if isinstance(module, CLIPEncoderLayer):
module.self_attn.k_proj._info = model._info
module.self_attn.k_proj.metric = None


def add_post_hook_to_get_2dPool(model, post_hook_fn, pruning_paras):