2 changes: 1 addition & 1 deletion llmc/compression/token_reduction/holitom.py
@@ -35,7 +35,7 @@ def SigLipEncoder_forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
) -> Union[Tuple]:

high

The return type hint was changed to Union[Tuple] (which is equivalent to plain Tuple), but the function can still return a BaseModelOutput instance (lines 86-90) when return_dict is true, so the annotation no longer matches what the function can return.

Suggested change
) -> Union[Tuple]:
) -> Union[Tuple, BaseModelOutput]:

output_attentions = (
output_attentions
if output_attentions is not None
69 changes: 61 additions & 8 deletions llmc/compression/token_reduction/pyramiddrop.py
@@ -1,5 +1,7 @@
import functools
import math
from functools import wraps
from types import MethodType

import torch
from torch import nn
@@ -26,13 +28,17 @@ def add_sparse_config(self):
image_token_ratio_list = self.special_config['image_token_ratio_list']
image_token_ratio_list.insert(0, 1.0)
self.special_config['image_token_ratio_list'] = image_token_ratio_list
if self.model.__class__.__name__ == 'LlavaHf':
llama_model = self.model.vlm_model.language_model.model
elif self.model.__class__.__name__ == 'Llava':
llama_model = self.model.vlm_model.model
Comment on lines +31 to +34

medium

This if/elif block to get the llama_model based on the model's class name is repeated later in this file inside pruning_hook (lines 223-226). This code duplication makes the code harder to maintain. Consider refactoring this logic into a helper method within the PyramidDrop class.
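For illustration, a minimal sketch of such a helper (the method name _get_language_model is hypothetical; the attribute paths are taken from the two branches in this diff):

```python
def _get_language_model(self):
    # Resolve the underlying decoder once, based on the wrapper class,
    # so add_sparse_config and pruning_hook can share the same lookup.
    if self.model.__class__.__name__ == 'LlavaHf':
        return self.model.vlm_model.language_model.model
    elif self.model.__class__.__name__ == 'Llava':
        return self.model.vlm_model.model
    raise NotImplementedError(
        f'Unsupported model class: {self.model.__class__.__name__}'
    )
```

Both call sites would then reduce to llama_model = self._get_language_model().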

self.special_config['tokenizer_padding_side'] = getattr(
self.model.vlm_model.language_model.model.config,
llama_model.config,
'tokenizer_padding_side',
'right',
)

self.model.model.parameters = self.special_config
self.pruning_paras = self.special_config

def register_reduction_modules(self):
@prefill_wrapper
@@ -214,8 +220,12 @@ def pruning_hook(module, args, kwargs, pruning_pars, cur_num, layer_idx):
attention_mask_list.append(new_attention_mask)

# Truncate sequences to max length as image embeddings can make the sequence longer
if self.model.__class__.__name__ == 'LlavaHf':
llama_model = self.model.vlm_model.language_model.model
elif self.model.__class__.__name__ == 'Llava':
llama_model = self.model.vlm_model.model
tokenizer_model_max_length = getattr(
self.model.vlm_model.language_model.model.config,
llama_model.config,
'tokenizer_model_max_length',
2048,
)
@@ -321,6 +331,39 @@ def input_hook(module, input_args, pruning_pars):

return input_args

def input_hook_llava(fn, pruning_paras):
@wraps(fn)
def wrapper(self, *args, **kwargs):
if len(args) == 0:
return fn(*args, **kwargs)
input_args = args[0]
if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
return fn(*args, **kwargs)

input_ids = args[0]
attention_mask = args[2]

image_token_posi = []
prompt_len = []
vision_tokens = []
for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask):
seq = cur_input_ids[cur_attention_mask]
image_index = torch.where(seq == IMAGE_TOKEN_INDEX)[0].tolist()
if image_index == []:
image_token_posi.append(-1)
prompt_len.append(cur_input_ids.shape[0])
else:
image_token_posi.append(image_index[0])
prompt_len.append(cur_input_ids.shape[0] - 1)
vision_tokens.append(pruning_paras['vision_token_length'])

pruning_paras['image_token_posi'] = image_token_posi
pruning_paras['prompt_len'] = prompt_len
pruning_paras['image_tokens'] = vision_tokens

return fn(*args, **kwargs)
return wrapper

@prefill_wrapper
def read_parameter_hook(module, args, kwargs, pruning_pars):
kwargs['attention_mask'] = pruning_pars['attention_mask']
@@ -330,17 +373,27 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):

return args, kwargs

self.model.embed_tokens.register_forward_pre_hook(
functools.partial(input_hook, pruning_pars=self.model.model.parameters)
)
if self.model.__class__.__name__ == 'LlavaHf':
self.model.embed_tokens.register_forward_pre_hook(
functools.partial(input_hook, pruning_pars=self.pruning_paras)
)
elif self.model.__class__.__name__ == 'Llava':
from llava.constants import IMAGE_TOKEN_INDEX
hook_fn = input_hook_llava(
self.model.vlm_model.prepare_inputs_labels_for_multimodal,
self.pruning_paras
)
self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
hook_fn, self.model.vlm_model
)
Comment on lines +376 to +388

medium

This if/elif block to register hooks based on the model type introduces logic specific to LlavaHf and Llava. A similar pattern is used in sparsevlm.py. To improve maintainability and reduce code duplication across different token reduction modules, consider abstracting this model-specific hook registration logic.
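One possible shape for that abstraction is a shared helper that both modules call from register_reduction_modules; the function name register_llava_input_hook and its signature are assumptions, and only the two wrapper classes handled in this diff are covered:

```python
import functools
from types import MethodType


def register_llava_input_hook(module, hf_hook, llava_hook_factory, pruning_paras):
    # Shared dispatch: LlavaHf gets a forward-pre-hook on embed_tokens,
    # while Llava has prepare_inputs_labels_for_multimodal wrapped instead.
    model = module.model
    if model.__class__.__name__ == 'LlavaHf':
        model.embed_tokens.register_forward_pre_hook(
            functools.partial(hf_hook, pruning_pars=pruning_paras)
        )
    elif model.__class__.__name__ == 'Llava':
        wrapped = llava_hook_factory(
            model.vlm_model.prepare_inputs_labels_for_multimodal,
            pruning_paras,
        )
        model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
            wrapped, model.vlm_model
        )
```

Each token-reduction module would then pass its own input_hook and input_hook_llava instead of repeating this block; the IMAGE_TOKEN_INDEX import would still live next to the input_hook_llava definition in the caller.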


for layer_idx in range(self.pruning_loc[0], len(self.blocks)):
if layer_idx in self.pruning_loc:
stage = self.pruning_loc.index(layer_idx)
self.blocks[layer_idx].register_forward_pre_hook(
functools.partial(
pruning_hook,
pruning_pars=self.model.model.parameters,
pruning_pars=self.pruning_paras,
cur_num=stage,
layer_idx=layer_idx,
),
@@ -349,7 +402,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
else:
self.blocks[layer_idx].register_forward_pre_hook(
functools.partial(
read_parameter_hook, pruning_pars=self.model.model.parameters
read_parameter_hook, pruning_pars=self.pruning_paras
),
with_kwargs=True,
)
79 changes: 64 additions & 15 deletions llmc/compression/token_reduction/sparsevlm.py
@@ -1,4 +1,6 @@
import functools
from functools import wraps
from types import MethodType

import einops as ein
import torch
@@ -27,7 +29,7 @@ def add_sparse_config(self):
special_config['token_length_list'] = []
special_config['image_shape'] = self.model.pruning_config['image_token_length']
special_config['image_token_index'] = self.model.pruning_config['image_token_index']
self.model.model.parameters = special_config
self.pruning_paras = special_config

def register_reduction_modules(self):
@prefill_wrapper
@@ -52,16 +54,48 @@ def input_hook(module, input_args, pruning_pars):

return input_args

def input_hook_llava(fn, pruning_paras):
@wraps(fn)
def wrapper(self, *args, **kwargs):
if len(args) == 0:
return fn(*args, **kwargs)
input_args = args[0]
if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
return fn(*args, **kwargs)

input_ids = args[0]
attention_mask = args[2]

pre_prompt_length_list = []
for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask):
seq = cur_input_ids[cur_attention_mask]
image_token_index = torch.where(seq == IMAGE_TOKEN_INDEX)[0].tolist()
if len(image_token_index) > 0:
pre_prompt_length_list.append(image_token_index[0])
else:
pre_prompt_length_list.append(0)
pruning_paras['pre_prompt_length_list'] = pre_prompt_length_list

outputs = fn(*args, **kwargs)

token_length_list = []
for cur_attention_mask in outputs[2]:
token_length_list.append(cur_attention_mask.sum().item())
pruning_paras['token_length_list'] = token_length_list

return outputs
return wrapper

@prefill_wrapper_model
def register_module_pars(module, args, kwargs, pruning_pars):
pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
inputs_embeds = kwargs['inputs_embeds']
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(kwargs['input_ids'])
inputs_embeds = module.embed_tokens(kwargs['input_ids'])
hidden_states = inputs_embeds # shape: (B, L, C)

pruning_pars['B'], L, _ = hidden_states.shape
B = pruning_pars['B']
B, L, _ = hidden_states.shape
pruning_pars['B'] = B
init_n = pruning_pars['init_token_total_shape'] + \
pruning_pars['generate_process_count'] # 668
pruning_pars['prev_decision'] = torch.ones(
@@ -80,7 +114,7 @@ def register_module_pars(module, args, kwargs, pruning_pars):
if (len(pre_prompt_length_list) != 0 and hidden_states.shape[1] != 1):
v_t = hidden_states[:, v_token_start: text_token_start, :]
t_t = hidden_states[:, text_token_start:, :]
m_v_t = v_t @ t_t.transpose(1, 2) # [1, 576, 53]
m_v_t = v_t @ t_t.transpose(1, 2) # [1, 576, 53] # 52?

medium

This comment # 52? appears to be a temporary debugging note and should be removed.

Suggested change
m_v_t = v_t @ t_t.transpose(1, 2) # [1, 576, 53] # 52?
m_v_t = v_t @ t_t.transpose(1, 2) # [1, 576, 53]

m_v_t = m_v_t.softmax(2).mean(1) # [1, 53]
pruning_pars['t_token_idx'] = torch.where(m_v_t > m_v_t.mean())

@@ -206,17 +240,31 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):

return args, kwargs

self.model.embed_tokens.register_forward_pre_hook(
functools.partial(
input_hook,
pruning_pars=self.model.model.parameters
if self.model.__class__.__name__ == 'LlavaHf':
self.model.embed_tokens.register_forward_pre_hook(
functools.partial(
input_hook,
pruning_pars=self.pruning_paras
)
)
elif self.model.__class__.__name__ == 'Llava':
from llava.constants import IMAGE_TOKEN_INDEX
hook_fn = input_hook_llava(
self.model.vlm_model.prepare_inputs_labels_for_multimodal,
self.pruning_paras
)
self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
hook_fn, self.model.vlm_model
)
)

self.model.model.register_forward_pre_hook(
if self.model.__class__.__name__ == 'LlavaHf':
llama_model = self.model.model
elif self.model.__class__.__name__ == 'Llava':
llama_model = self.model.model.model
llama_model.register_forward_pre_hook(
functools.partial(
register_module_pars,
pruning_pars=self.model.model.parameters),
pruning_pars=self.pruning_paras),
with_kwargs=True
)
Comment on lines +243 to 269

medium

There are two consecutive if/elif blocks here that check self.model.__class__.__name__. This introduces duplicated conditional logic and makes the code harder to read and maintain. This logic is also very similar to what's in llmc/compression/token_reduction/pyramiddrop.py. Consider determining the model type once and storing it as a boolean flag or refactoring the logic to get llama_model into a helper method.
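A hedged sketch of the flag-based variant, meant to sit inside register_reduction_modules (attribute paths are taken from this diff; only these two wrapper classes are considered):

```python
# Hypothetical: resolve the wrapper type once, then reuse the flag below
# instead of repeating the class-name check for every registration.
is_llava_hf = self.model.__class__.__name__ == 'LlavaHf'
llama_model = self.model.model if is_llava_hf else self.model.model.model
```

The same is_llava_hf flag could then drive the choice between hooking embed_tokens (LlavaHf) and wrapping prepare_inputs_labels_for_multimodal (Llava) above.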


@@ -228,15 +276,15 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
self.blocks[block_idx].register_forward_pre_hook(
functools.partial(
update_output_attentions_hook,
pruning_pars=self.model.model.parameters,
pruning_pars=self.pruning_paras,
layer_idx=block_idx,
),
with_kwargs=True
)
self.blocks[block_idx].register_forward_hook(
functools.partial(
decoder_attn_hook,
pruning_pars=self.model.model.parameters,
pruning_pars=self.pruning_paras,
layer_idx=block_idx,
),
with_kwargs=True
@@ -245,7 +293,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
self.blocks[block_idx].register_forward_pre_hook(
functools.partial(
read_parameter_hook,
pruning_pars=self.model.model.parameters
pruning_pars=self.pruning_paras
),
with_kwargs=True
)
@@ -278,6 +326,7 @@ def attn_postprocess_topk(
self_attn_weights = self_attn_weights.mean(1) # B, L[Q], L[K]

t_token_idx = t_token_idx[1] + text_token_start

relation_vis_text = self_attn_weights[:, t_token_idx,
v_token_start: v_token_start + v_token_num] # B, L2, L1

2 changes: 1 addition & 1 deletion llmc/compression/token_reduction/tome.py
@@ -43,7 +43,7 @@ def add_sparse_config(self):
else:
raise ValueError('Invalid r format. Expected int or (start, step) tuple.')

self.model.model.parameters = special_config
self.pruning_paras = special_config

def patch_layer(self):
for idx, block in enumerate(self.blocks):