3 changes: 1 addition & 2 deletions configs/sparsification/methods/DART/dart.yml
@@ -16,9 +16,8 @@ sparse:
method: TokenReduction
special:
method: DART
pruning_loc: 2
pruning_loc: 5
reduction_ratio: 0.778
max_num_trunction: 128
pivot_image_token: 4
pivot_text_token : 4
save:
2 changes: 1 addition & 1 deletion configs/sparsification/methods/FastV/fastv.yml
@@ -17,7 +17,7 @@ sparse:
special:
method: FastV
pruning_loc: 3
rate: 0.778
rate: 0.778 # prune_rate
save:
save_trans: False
save_fake: False
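For orientation, `rate` is the fraction of visual tokens FastV drops at `pruning_loc` (the `# prune_rate` comment added above makes that explicit). A minimal sketch of the implied budget, assuming a 576-token visual sequence (LLaVA-1.5's default, which this config does not state) and truncation toward zero; the actual implementation may round differently:

```python
# Rough token budget implied by the FastV config above (sketch, not the repo's exact code).
pruning_loc = 3      # prune after the 3rd decoder layer
rate = 0.778         # fraction of visual tokens to drop

vision_token_length = 576            # assumed LLaVA-1.5 visual sequence length
kept = int(vision_token_length * (1 - rate))
print(kept)                          # 127 visual tokens survive past layer 3
```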
2 changes: 1 addition & 1 deletion configs/sparsification/methods/VisionZip/visionzip.yml
@@ -16,7 +16,7 @@ sparse:
vision:
method: TokenReduction
special:
method: VisionZip
method: VisionZip # retain
dominant: 191 # visual_tokens = dominan_tokens + 1(cls_token)
contextual: 30
save:
126 changes: 31 additions & 95 deletions llmc/compression/token_reduction/dart.py
@@ -1,7 +1,5 @@
import functools
import math
from functools import wraps
from types import MethodType

import torch

@@ -19,95 +17,43 @@ def __init__(self, config, model, blocks):
self.register_reduction_modules()

def add_sparse_config(self):

self.pruning_loc = self.special_config['pruning_loc']
self.special_config['image_token_length'] = \
self.model.pruning_config['image_token_length']
self.special_config['IMAGE_TOKEN_INDEX'] = \
self.model.pruning_config['IMAGE_TOKEN_INDEX']

self.pruning_paras = self.special_config

def register_reduction_modules(self):

def input_hook_llava(fn, pruning_paras):
@wraps(fn)
def wrapper(self, *args, **kwargs):
if len(args) == 0:
return fn(*args, **kwargs)
input_args = args[0]
if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
return fn(*args, **kwargs)

input_ids = args[0]
attention_mask = args[2]
token_indices = (
input_ids[0][attention_mask[0]] == pruning_paras['IMAGE_TOKEN_INDEX']
)
pruning_paras['image_token_start_index'] = torch.where(token_indices)[0][0].item()
@prefill_wrapper
def vtoken_length_hook(module, input_args, pruning_paras):

outputs = fn(*args, **kwargs)
return outputs
return wrapper
input_ids = input_args[0]
token_indices = torch.where(
input_ids[0] == pruning_paras['vision_token_index']
)[0]
pruning_paras['vision_token_length'] = token_indices.shape[0]

def get_seq_len_hook(module, args, kwargs, pruning_paras):
if kwargs['input_ids'] is not None:
pruning_paras['seq_len'] = kwargs['input_ids'].shape[1]
elif kwargs['inputs_embeds'] is not None:
pruning_paras['seq_len'] = kwargs['inputs_embeds'].shape[1]
else:
raise ValueError('You have to specify either input_ids or inputs_embeds')
return input_args

@prefill_wrapper
def get_any_states_hook(module, args, kwargs, layer_outs, pruning_paras, layer_idx):
from transformers.models.llama.modeling_llama import (
apply_rotary_pos_emb, repeat_kv)
if len(kwargs['position_ids'][0]) == 1:
return layer_outs

hidden_states = kwargs['hidden_states']
position_embeddings = kwargs['position_embeddings']
position_ids = kwargs['position_ids']
past_key_value = layer_outs[2]

bsz, q_len, _ = hidden_states.size()
query_states = module.q_proj(hidden_states)
key_states = module.k_proj(hidden_states)
value_states = module.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, module.num_heads, module.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, module.num_key_value_heads, module.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, module.num_key_value_heads, module.head_dim
).transpose(1, 2)

if position_embeddings is None:
cos, sin = module.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
key_states = past_key_value.key_cache[layer_idx]
value_states = past_key_value.value_cache[layer_idx]
key_states = repeat_kv(key_states, module.num_key_value_groups)
value_states = repeat_kv(value_states, module.num_key_value_groups)

pruning_paras['any_states'] = (query_states, key_states, value_states)
past_key_value = kwargs['past_key_value']
if past_key_value is None:
raise ValueError('DART needs past_key_value but got None.')
pruning_paras['any_states'] = past_key_value.key_cache[layer_idx]

return layer_outs

@prefill_wrapper
def pruning_hook(module, args, kwargs, pruning_paras, normlayer):

image_token_start_index = pruning_paras['image_token_start_index']
image_token_length = pruning_paras['image_token_length']
any_states = pruning_paras['any_states'][-2]
seq_length = pruning_paras['seq_len']
image_token_start_index = pruning_paras['vision_token_start_index']
image_token_length = pruning_paras['vision_token_length']
any_states = pruning_paras['any_states']

hidden_states = args[0]
attention_mask = kwargs['attention_mask']
seq_length = hidden_states.shape[1]
device = hidden_states.device
last_layer_state = normlayer(hidden_states)

@@ -140,27 +86,20 @@ def pruning_hook(module, args, kwargs, pruning_paras, normlayer):
kwargs['position_ids'].resize_as_(position_ids).copy_(position_ids.clone())

position_embeddings = kwargs['position_embeddings']
new_pe0 = position_embeddings[0][:, keep_indexs, :].clone()
new_pe1 = position_embeddings[1][:, keep_indexs, :].clone()
index_dim = 1 if position_embeddings[0].dim() == 3 else 2
new_pe0 = position_embeddings[0].index_select(index_dim, keep_indexs).clone()
new_pe1 = position_embeddings[1].index_select(index_dim, keep_indexs).clone()
position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1)

return (hidden_states,), kwargs

hook_fn = input_hook_llava(
self.model.vlm_model.prepare_inputs_labels_for_multimodal,
self.pruning_paras
)
self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
hook_fn, self.model.vlm_model
)

self.model.model.model.register_forward_pre_hook(
functools.partial(get_seq_len_hook, pruning_paras=self.pruning_paras),
with_kwargs=True
)
if self.special_config['vision_token_length'] is None:
self.model.embed_tokens.register_forward_pre_hook(
functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
)

self.blocks[self.pruning_loc - 1].self_attn.register_forward_hook(
self.blocks[self.pruning_loc - 1].register_forward_hook(
functools.partial(
get_any_states_hook,
pruning_paras=self.pruning_paras,
@@ -173,24 +112,21 @@ def pruning_hook(module, args, kwargs, pruning_paras, normlayer):
functools.partial(
pruning_hook,
pruning_paras=self.pruning_paras,
normlayer=self.model.model.model.norm
normlayer=self.model.language_model.norm
),
with_kwargs=True
)


def get_retained_image_token(pruning_paras, last_layer_state, any_states):
image_token_start_index = pruning_paras['image_token_start_index']
image_token_length = pruning_paras['image_token_length']
MAX_NUM_TRUNCTION = pruning_paras['max_num_trunction']
image_token_start_index = pruning_paras['vision_token_start_index']
image_token_length = pruning_paras['vision_token_length']
pivot_image_token = pruning_paras['pivot_image_token']
pivot_text_token = pruning_paras['pivot_text_token']
reduction_ratio = pruning_paras['reduction_ratio']
TOKEN_TOPK = math.ceil(
(
MAX_NUM_TRUNCTION if MAX_NUM_TRUNCTION is not None
else (image_token_length * (1 - reduction_ratio))
) // (pivot_image_token + pivot_text_token))
TOKEN_TOPK = int(
image_token_length * (1 - reduction_ratio) / (pivot_image_token + pivot_text_token)
)
device = last_layer_state.device

any_states = any_states.permute(0, 2, 1, 3)
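One way to sanity-check the new `TOKEN_TOPK` line: it replaces the `max_num_trunction`-driven budget with one derived purely from `reduction_ratio`. A small sketch comparing the two formulas, assuming 576 visual tokens and the pivot settings from dart.yml (illustrative numbers, not part of the diff):

```python
import math

# Sketch: old vs. new per-pivot top-k budget in DART (assumed inputs).
image_token_length = 576   # assumed LLaVA-1.5 visual sequence length
reduction_ratio = 0.778
pivot_image_token = 4
pivot_text_token = 4
MAX_NUM_TRUNCTION = 128    # the config key removed in this PR

# Old formula (removed): when max_num_trunction was set, it fixed the budget.
old_topk = math.ceil(MAX_NUM_TRUNCTION // (pivot_image_token + pivot_text_token))   # 16

# New formula (added): the budget follows reduction_ratio directly.
new_topk = int(
    image_token_length * (1 - reduction_ratio) / (pivot_image_token + pivot_text_token)
)  # 15

print(old_topk, new_topk)
```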
59 changes: 13 additions & 46 deletions llmc/compression/token_reduction/fastv.py
@@ -1,6 +1,4 @@
import functools
from functools import wraps
from types import MethodType

import torch

@@ -18,46 +16,22 @@ def __init__(self, config, model, blocks):
self.register_reduction_modules()

def add_sparse_config(self):

self.pruning_loc = self.special_config['pruning_loc']
self.special_config['image_token_length'] = \
self.model.pruning_config['image_token_length']
self.special_config['IMAGE_TOKEN_INDEX'] = \
self.model.pruning_config['IMAGE_TOKEN_INDEX']
self.special_config['attn_scores'] = None

self.pruning_paras = self.special_config

def register_reduction_modules(self):

@prefill_wrapper
def input_hook(module, input_args, pruning_paras):
def vtoken_length_hook(module, input_args, pruning_paras):
input_ids = input_args[0]
image_token_idxs = (input_ids[0] ==
pruning_paras['vision_token_index']).nonzero(as_tuple=True)[0]
pruning_paras['image_token_start_index'] = image_token_idxs[0].item()

token_indices = torch.where(
input_ids[0] == pruning_paras['vision_token_index']
)[0]
pruning_paras['vision_token_length'] = token_indices.shape[0]
return input_args

def input_hook_llava(fn, pruning_paras):
@wraps(fn)
def wrapper(self, *args, **kwargs):
if len(args) == 0:
return fn(*args, **kwargs)
input_args = args[0]
if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
return fn(*args, **kwargs)

input_ids = args[0]
attention_mask = args[2]
token_indices = \
input_ids[0][attention_mask[0]] == pruning_paras['IMAGE_TOKEN_INDEX']
pruning_paras['image_token_start_index'] = torch.where(token_indices)[0][0].item()

outputs = fn(*args, **kwargs)
return outputs
return wrapper

@prefill_wrapper
def update_output_attentions_hook(module, args, kwargs, pruning_paras):
kwargs['output_attentions'] = True
pruning_paras['attn_scores'] = module.__class__.forward(module, *args, **kwargs)[1]
@@ -68,8 +42,8 @@ def update_output_attentions_hook(module, args, kwargs, pruning_paras):
def fastv_pruning_hook(module, args, kwargs, pruning_paras):

rate = pruning_paras['rate']
image_token_start_index = pruning_paras['image_token_start_index']
image_token_length = pruning_paras['image_token_length']
image_token_start_index = pruning_paras['vision_token_start_index']
image_token_length = pruning_paras['vision_token_length']

hidden_states = args[0]
causal_mask = kwargs['attention_mask']
@@ -121,24 +95,17 @@ def fastv_pruning_hook(module, args, kwargs, pruning_paras):
kwargs['position_ids'].resize_as_(position_ids).copy_(position_ids.clone())

position_embeddings = kwargs['position_embeddings']
new_pe0 = position_embeddings[0][:, keep_indexs, :].clone()
new_pe1 = position_embeddings[1][:, keep_indexs, :].clone()
index_dim = 1 if position_embeddings[0].dim() == 3 else 2
new_pe0 = position_embeddings[0].index_select(index_dim, keep_indexs).clone()
new_pe1 = position_embeddings[1].index_select(index_dim, keep_indexs).clone()
position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1)

return (hidden_states,), kwargs

if self.model.__class__.__name__ == 'LlavaHf':
if self.special_config['vision_token_length'] is None:
self.model.embed_tokens.register_forward_pre_hook(
functools.partial(input_hook, pruning_paras=self.pruning_paras)
)
elif self.model.__class__.__name__ == 'Llava':
hook_fn = input_hook_llava(
self.model.vlm_model.prepare_inputs_labels_for_multimodal,
self.pruning_paras
)
self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
hook_fn, self.model.vlm_model
functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
)

self.blocks[self.pruning_loc - 1].register_forward_pre_hook(
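Both pruning hooks (dart.py and fastv.py) switch from hard-coded `[:, keep_indexs, :]` slicing of the rotary embeddings to `index_select` with a dimension picked at runtime. A sketch of what that buys, assuming the cos/sin tensors can be either 3-D `(bsz, seq_len, head_dim)` or 4-D with an extra broadcast dim, which appears to be the shape difference the `dim()` check guards against:

```python
import torch

# Sketch of the keep-index gather used by both pruning hooks (assumed shapes).
keep_indexs = torch.tensor([0, 1, 2, 5, 9])  # hypothetical surviving positions

def select_positions(pos_emb: torch.Tensor, keep_indexs: torch.Tensor) -> torch.Tensor:
    # Pick the sequence axis depending on rank, mirroring the diff's
    # `index_dim = 1 if position_embeddings[0].dim() == 3 else 2`.
    index_dim = 1 if pos_emb.dim() == 3 else 2
    return pos_emb.index_select(index_dim, keep_indexs).clone()

cos_3d = torch.randn(1, 12, 128)       # (bsz, seq_len, head_dim)
cos_4d = torch.randn(1, 1, 12, 128)    # (bsz, 1, seq_len, head_dim)
print(select_positions(cos_3d, keep_indexs).shape)  # torch.Size([1, 5, 128])
print(select_positions(cos_4d, keep_indexs).shape)  # torch.Size([1, 1, 5, 128])
```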
15 changes: 9 additions & 6 deletions llmc/compression/token_reduction/token_reduction_module.py
@@ -23,12 +23,15 @@ def set_sparse_config(self):
'video_token_length'
]
else:
self.special_config['vision_token_index'] = self.model.pruning_config[
'image_token_index'
]
self.special_config['vision_token_length'] = self.model.pruning_config[
'image_token_length'
]
self.special_config['vision_token_index'] = self.model.pruning_config.get(
'image_token_index', None
)
self.special_config['vision_token_start_index'] = self.model.pruning_config.get(
'vision_token_start_index', None
)
self.special_config['vision_token_length'] = self.model.pruning_config.get(
'image_token_length', None
)

def register_reduction_modules(self):
pass
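The switch to `.get(..., None)` means a model whose `pruning_config` lacks these keys no longer raises a `KeyError`; `vision_token_length` simply stays `None` and is filled at prefill time by the `vtoken_length_hook` registered in dart.py/fastv.py. A minimal sketch of that two-stage resolution (hypothetical config values, not the repo's classes):

```python
import torch

# Sketch of the lazy vision_token_length resolution (hypothetical values).
pruning_config = {'image_token_index': 32000}  # no 'image_token_length' entry

special_config = {
    'vision_token_index': pruning_config.get('image_token_index', None),
    'vision_token_start_index': pruning_config.get('vision_token_start_index', None),
    'vision_token_length': pruning_config.get('image_token_length', None),  # stays None here
}

# Later, the embed_tokens pre-hook counts image-token ids in the prompt,
# mirroring vtoken_length_hook in the diff above.
input_ids = torch.tensor([[1, 32000, 32000, 32000, 9]])
if special_config['vision_token_length'] is None:
    token_indices = torch.where(input_ids[0] == special_config['vision_token_index'])[0]
    special_config['vision_token_length'] = token_indices.shape[0]

print(special_config['vision_token_length'])  # 3
```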
12 changes: 3 additions & 9 deletions llmc/compression/token_reduction/utils.py
@@ -63,17 +63,12 @@ def make_tome_class(transformer_class):
class VisionZipTransformer(transformer_class):
"""
Modifications:
- Initialize r, token size, and token sources.
- Initialize r
"""

def forward(self, *args, **kwdargs) -> torch.Tensor:
def forward(self, *args, **kwargs) -> torch.Tensor:
self._info['r'] = parse_r(len(self.vision_model.encoder.layers), self.r)
# self._info["r"] = self.r

self._info['size'] = None
self._info['source'] = None

return super().forward(*args, **kwdargs)
return super().forward(*args, **kwargs)

return VisionZipTransformer

@@ -93,7 +88,6 @@ def apply_info(model, dominant_num, contextual_num):
for module in model.modules():
if isinstance(module, CLIPEncoderLayer):
module.self_attn.k_proj._info = model._info
module.self_attn.k_proj.metric = None


def add_post_hook_to_get_2dPool(model, post_hook_fn, pruning_paras):