6 changes: 3 additions & 3 deletions configs/sparsification/methods/SparseVLM/sparsevlm.yml
@@ -16,10 +16,10 @@ sparse:
method: TokenReduction
special:
method: SparseVLM
pruning_loc: [2] # [2, 6, 15]
pruning_loc: [2, 6, 15]
retained_tokens: 192
init_token_total_shape: 668
merge_flag: False
prune_flag: True
merge_flag: True
save:
save_trans: False
save_fake: False
171 changes: 108 additions & 63 deletions llmc/compression/token_reduction/sparsevlm.py
@@ -13,6 +13,12 @@
from .utils import prefill_wrapper, prefill_wrapper_model

layer_dict = {}
prune_flag = True
merge_flag = True
sparse_token_list_192 = []
sparse_token_list_128 = []
sparse_token_list_64 = []
sparse_token_dict = {}
Comment on lines +16 to +21

high

The introduction of these module-level global variables can cause unexpected behavior, especially if multiple SparseVLM instances are created. Each instance could overwrite the global configuration, leading to race conditions or incorrect configurations. Consider encapsulating these variables as instance attributes of the SparseVLM class to ensure that each instance manages its own state.
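A minimal sketch of the suggested encapsulation (the class stub, constructor, and defaults below are illustrative, not the actual llmc implementation):

```python
class SparseVLM:  # illustrative stub; the real class is registered via TOKEN_REDUCTION_REGISTRY
    def __init__(self, config, model=None):
        self.config = config
        self.model = model

    def add_sparse_config(self):
        special_config = self.config.get('special', {})
        self.pruning_loc = special_config.get('pruning_loc', [2, 6, 15])
        # per-instance state replaces the module-level globals
        self.layer_dict = {layer: idx for idx, layer in enumerate(self.pruning_loc)}
        self.prune_flag = special_config.get('prune_flag', True)
        self.merge_flag = special_config.get('merge_flag', True)
        self.pruning_paras = special_config
```

Hooks that currently read the globals would then take these values from the instance (or from self.pruning_paras) instead.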



@TOKEN_REDUCTION_REGISTRY.register('SparseVLM')
@@ -26,13 +32,13 @@ def add_sparse_config(self):
special_config = self.config.get('special', {})

self.pruning_loc = special_config.get('pruning_loc', [2, 6, 15])
global layer_dict
global layer_dict, prune_flag, merge_flag
layer_dict = {layer: idx for idx, layer in enumerate(self.pruning_loc)}
prune_flag = special_config.get('prune_flag', True)
merge_flag = special_config.get('merge_flag', True)
update_list()
special_config['retained_tokens'] = special_config.get('retained_tokens', 192)
special_config['init_token_total_shape'] = special_config.get('init_token_total_shape', 668)
special_config['generate_process_count'] = 0
special_config['pre_prompt_length_list'] = []
special_config['token_length_list'] = []
special_config['image_shape'] = self.model.pruning_config['image_token_length']
special_config['image_token_index'] = self.model.pruning_config['image_token_index']
self.pruning_paras = special_config
@@ -42,7 +48,6 @@ def register_reduction_modules(self):
def input_hook(module, input_args, pruning_pars):
input_ids = input_args[0]
pre_prompt_length_list = []
token_length_list = []
IMAGE_TOKEN_INDEX = pruning_pars['image_token_index']

# find the position of the first image token
@@ -54,10 +59,7 @@ def input_hook(module, input_args, pruning_pars):
pre_prompt_length_list.append(image_token_index[0].item())
else:
pre_prompt_length_list.append(0)
token_length_list.append(seq.shape[0])

pruning_pars['pre_prompt_length_list'] = pre_prompt_length_list
pruning_pars['token_length_list'] = token_length_list

return input_args

@@ -90,11 +92,7 @@ def wrapper(self, *args, **kwargs):

pruning_paras['pre_prompt_length_list'] = pre_prompt_length_list

outputs = fn(*args, **kwargs)

pruning_paras['token_length_list'] = outputs[2].sum(dim=1).tolist()

return outputs
return fn(*args, **kwargs)
return wrapper

@prefill_wrapper_model
@@ -106,12 +104,6 @@ def register_module_pars(module, args, kwargs, pruning_pars):

B, L, _ = hidden_states.shape
pruning_pars['B'] = B
init_n = pruning_pars['init_token_total_shape'] + \
pruning_pars['generate_process_count'] # 668
pruning_pars['prev_decision'] = torch.ones(
B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)
pruning_pars['policy'] = torch.ones(
B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)

v_token_start = pre_prompt_length_list[0] if len(
pre_prompt_length_list) != 0 else 0
@@ -123,8 +115,8 @@ def register_module_pars(module, args, kwargs, pruning_pars):
if (len(pre_prompt_length_list) != 0 and hidden_states.shape[1] != 1):
v_t = hidden_states[:, v_token_start: text_token_start, :]
t_t = hidden_states[:, text_token_start:, :]
m_v_t = v_t @ t_t.transpose(1, 2) # [1, 576, 52]
m_v_t = m_v_t.softmax(2).mean(1) # [1, 52]
m_v_t = v_t @ t_t.transpose(1, 2)
m_v_t = m_v_t.softmax(2).mean(1)
pruning_pars['t_token_idx'] = torch.where(m_v_t > m_v_t.mean())

return args, kwargs
@@ -133,6 +125,7 @@ def update_output_attentions_hook(module, args, kwargs, pruning_pars, layer_idx)
kwargs['output_attentions'] = True
if layer_idx != self.pruning_loc[0]:
kwargs['position_ids'] = pruning_pars['position_ids']
kwargs['attention_mask'] = pruning_pars['attention_mask']
kwargs['cache_position'] = pruning_pars['cache_position']
kwargs['position_embeddings'] = pruning_pars['position_embeddings']
return args, kwargs
@@ -143,8 +136,14 @@ def update_kwargs_hook(module, args, kwargs, pruning_pars, layer_idx):
return args, kwargs
if layer_idx != self.pruning_loc[0]:
kwargs['position_ids'] = pruning_pars['position_ids']
kwargs['attention_mask'] = pruning_pars['attention_mask']
kwargs['cache_position'] = pruning_pars['cache_position']
kwargs['position_embeddings'] = pruning_pars['position_embeddings']
else:
pruning_pars['position_ids'] = kwargs['position_ids']
pruning_pars['attention_mask'] = kwargs['attention_mask']
pruning_pars['cache_position'] = kwargs['cache_position']
pruning_pars['position_embeddings'] = kwargs['position_embeddings']
return args, kwargs

def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_idx):
@@ -155,11 +154,6 @@ def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_i
from transformers.models.llama.modeling_llama import \
apply_rotary_pos_emb

if layer_idx != self.pruning_loc[0]:
kwargs['position_ids'] = pruning_pars['position_ids']
kwargs['cache_position'] = pruning_pars['cache_position']
kwargs['position_embeddings'] = pruning_pars['position_embeddings']

hidden_states = kwargs['hidden_states']
position_embeddings = kwargs['position_embeddings']
position_ids = kwargs['position_ids']
@@ -215,9 +209,10 @@ def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_i
def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):

if 'attn_logits' not in pruning_pars:
attn_logits = layer_outputs[1]
attn_logits = layer_outputs[1] # for LlavaHf
else:
attn_logits = pruning_pars['attn_logits']
prune_flag = pruning_pars.get('prune_flag', True)

medium

For consistency, merge_flag should also be accessed using .get() to prevent potential KeyError exceptions, similar to how prune_flag is handled.
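A sketch of the suggested access (the default of True is an assumption, chosen to mirror prune_flag):

```python
pruning_pars = {}  # e.g. when 'merge_flag' was never written upstream
merge_flag = pruning_pars.get('merge_flag', True)  # no KeyError, falls back to True
```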


merge_flag = pruning_pars['merge_flag']
v_token_start = pruning_pars['v_token_start']
v_token_num = pruning_pars['v_token_num']
@@ -227,13 +222,11 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer
B = pruning_pars['B']
pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
image_shape = pruning_pars['image_shape']
if layer_idx == self.pruning_loc[0]:
position_ids = kwargs['position_ids']
pruning_pars['position_ids'] = position_ids
else:
position_ids = pruning_pars['position_ids']
hidden_states = inputs[0] # [B, L, D]

attention_mask = kwargs['attention_mask']
position_embeddings = kwargs['position_embeddings']

hidden_states = inputs[0] # [B, L, D]
pred_score_vis, s_flag, relation_vis_text = attn_postprocess_topk(
attn_logits,
v_token_start,
@@ -243,7 +236,8 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer
layer_idx,
retained_tokens
)

if not prune_flag:
pred_score_vis = torch.zeros_like(relation_vis_text, dtype=bool)
policy = torch.ones(B, hidden_states.shape[1], dtype=hidden_states.dtype,
device=hidden_states.device)
policy[:, v_token_start:text_token_start] = \
@@ -261,60 +255,91 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer

# merge and cluster
if s_flag and merge_flag and total_sparse_token_idx.shape[1] > 0:
total_sparse_token = batch_index_select(layer_outputs[0], total_sparse_token_idx)
total_sparse_token = batch_index_select(
layer_outputs[0], total_sparse_token_idx
)

merge_token_idx_stage1 = torch.where(pred_score_vis == 0)[1]
merge_token_stage1 = relation_vis_text[0][merge_token_idx_stage1]
merge_token_num_stage1 = int(merge_token_idx_stage1.shape[0] * 0.3) + 1 # Top 30%
if prune_flag:

medium

The value 0.3 is a magic number. Define it as a named constant to improve readability and maintainability.
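One possible naming (the constant name and the stand-in tensor are illustrative; the 0.3 ratio is taken from the diff):

```python
import torch

MERGE_RATIO_STAGE1 = 0.3  # keeps the top-30% behaviour from the diff

merge_token_idx_stage1 = torch.arange(20)  # stand-in for the real index tensor
merge_token_num_stage1 = int(merge_token_idx_stage1.shape[0] * MERGE_RATIO_STAGE1) + 1  # -> 7
```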

merge_token_num_stage1 = int(merge_token_idx_stage1.shape[0] * 0.3) + 1
else:
merge_token_num_stage1 = (
merge_token_idx_stage1.shape[0]
- sparse_token_dict[retained_tokens][layer_dict[layer_idx]]
)
merge_token_stage2_idx = merge_token_stage1.topk(merge_token_num_stage1)[1]
if not prune_flag:
all_idx = torch.arange(
merge_token_stage1.size(0),
device=merge_token_stage1.device
)
non_topk_idx = all_idx[~torch.isin(all_idx, merge_token_stage2_idx)]
pred_score_vis[0][non_topk_idx] = 1
policy[:, v_token_start:text_token_start] = \
pred_score_vis.type(dtype=hidden_states.dtype)

merge_token_stage2 = total_sparse_token[:, merge_token_stage2_idx, :]
cluster_num = int(merge_token_stage2.shape[1] / 10) + 1

medium

The value 10 is a magic number used to determine the number of clusters. Define it as a named constant to improve code clarity and maintainability.
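Same idea here (constant name and stand-in shape are illustrative):

```python
import torch

TOKENS_PER_CLUSTER = 10  # roughly one cluster per 10 merged tokens, as in the diff

merge_token_stage2 = torch.randn(1, 23, 4096)  # stand-in [B, N, C]
cluster_num = int(merge_token_stage2.shape[1] / TOKENS_PER_CLUSTER) + 1  # -> 3
```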

if cluster_num == 0:
cluster_num = merge_token_stage2.shape[1]
merge_sparse_token, index_down = cluster_and_merge(merge_token_stage2, cluster_num)

merge_sparse_token = cluster_and_merge(merge_token_stage2, cluster_num)

cluster_idx = total_sparse_token_idx.squeeze(0)[merge_token_stage2_idx[index_down]]
cluster_idx = cluster_idx.squeeze(0)

high

The tensor cluster_idx is 1-dimensional. If its size is 1 (i.e., shape is (1,)), squeeze(0) will convert it into a 0-dimensional scalar, and the torch.cat operation on line 293 will then fail because zero-dimensional tensors cannot be concatenated. Remove this line.
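A minimal repro of that failure mode:

```python
import torch

cluster_idx = torch.tensor([7])   # 1-D with a single element, shape (1,)
scalar = cluster_idx.squeeze(0)   # now 0-dimensional, shape ()
try:
    torch.cat((torch.arange(3), scalar))
except RuntimeError as err:
    print(err)  # zero-dimensional tensor (at position 1) cannot be concatenated
```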

select_token_idx = torch.where(policy == 1)[1].unsqueeze(0)
select_token = batch_index_select(layer_outputs[0], select_token_idx)
select_vis_token_num = pred_score_vis.sum()

keep_indexs = torch.cat(
(
select_token_idx.squeeze(0)[:v_token_start + select_vis_token_num],
cluster_idx,
select_token_idx.squeeze(0)[v_token_start + select_vis_token_num:]
)
)
select_and_merge_token = torch.cat(
(
select_token[:, :v_token_start +
select_vis_token_num, :],
select_token[:, :v_token_start + select_vis_token_num, :],
merge_sparse_token,
select_token[:, v_token_start +
select_vis_token_num:, :]
select_token[:, v_token_start + select_vis_token_num:, :]
),
dim=1
)
layer_outputs = (select_and_merge_token, layer_outputs[1])
position_ids = position_ids[:, :len(select_token_idx[0]) + cluster_num]
v_token_num = pred_score_vis.sum() + cluster_num
text_token_start = v_token_start + v_token_num

else:
select_token_idx = torch.where(policy == 1)[1].unsqueeze(0)
keep_indexs = torch.where(policy == 1)[1]
select_token_idx = keep_indexs.unsqueeze(0)
layer_outputs = (batch_index_select(layer_outputs[0], select_token_idx),
layer_outputs[1])
position_ids = position_ids[:, :len(select_token_idx[0])]
v_token_num = pred_score_vis.sum()
text_token_start = v_token_start + v_token_num

text_token_start = v_token_start + v_token_num
position_ids = keep_indexs.unsqueeze(0)
new_output = layer_outputs
cache_position = position_ids.detach().clone()
cache_position = position_ids.squeeze(0)

if attention_mask is not None:
attention_mask = attention_mask[:, :, keep_indexs, keep_indexs]
new_pe0 = position_embeddings[0][:, keep_indexs, :].clone()
new_pe1 = position_embeddings[1][:, keep_indexs, :].clone()
position_embeddings = (new_pe0, new_pe1)

pruning_pars['v_token_num'] = v_token_num
pruning_pars['text_token_start'] = text_token_start

pruning_pars['position_ids'] = position_ids
pruning_pars['cache_position'] = cache_position
pruning_pars['position_embeddings'] = None
pruning_pars['position_embeddings'] = position_embeddings
pruning_pars['attention_mask'] = attention_mask

return new_output

@prefill_wrapper
def read_parameter_hook(module, args, kwargs, pruning_pars):
kwargs['position_ids'] = pruning_pars['position_ids']
kwargs['attention_mask'] = pruning_pars['attention_mask']
kwargs['cache_position'] = pruning_pars['cache_position']
kwargs['position_embeddings'] = pruning_pars['position_embeddings']

@@ -363,7 +388,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
with_kwargs=True
)
elif self.model.__class__.__name__ == 'Llava':
self.blocks[block_idx].self_attn.register_forward_pre_hook(
self.blocks[block_idx].register_forward_pre_hook(

medium

Consider registering the forward pre-hook directly to the block instead of the self_attn module. This simplifies the hook registration and ensures that the hook is applied to the entire block.

self.blocks[block_idx].register_forward_pre_hook(
    functools.partial(
        update_kwargs_hook,
        pruning_pars=self.pruning_paras,
        layer_idx=block_idx
    ),
    with_kwargs=True
)

functools.partial(
update_kwargs_hook,
pruning_pars=self.pruning_paras,
@@ -383,7 +408,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
functools.partial(
decoder_attn_hook,
pruning_pars=self.pruning_paras,
layer_idx=block_idx,
layer_idx=block_idx
),
with_kwargs=True
)
@@ -397,17 +422,37 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
)


layer_dict = {2: 0, 6: 1, 15: 2}

sparse_token_list_192 = [300, 200, 110] # 2*576 4*300 10*200 16*110
sparse_token_list_128 = [303, 110, 36]
sparse_token_list_64 = [66, 30, 17]
def update_list():
global sparse_token_list_192, sparse_token_list_128, sparse_token_list_64
global prune_flag, merge_flag, sparse_token_dict

if layer_dict == {2: 0, 6: 1, 15: 2}: # 2*576 4*300 10*200 16*110
sparse_token_list_192 = [300, 200, 110]
sparse_token_list_128 = [303, 110, 36]
sparse_token_list_64 = [66, 30, 17]
prune_flag, merge_flag = True, True
elif prune_flag and merge_flag:
sparse_token_list_192 = [180]
sparse_token_list_128 = [114]
sparse_token_list_64 = [48]
elif prune_flag:
sparse_token_list_192 = [192]
sparse_token_list_128 = [128]
sparse_token_list_64 = [64]
elif merge_flag:
sparse_token_list_192 = [149]
sparse_token_list_128 = [78]
sparse_token_list_64 = [7]
else:
raise RuntimeError(
'Both prune_flag and merge_flag are False — sparseVLM is inactive.'
)

sparse_token_dict = {
192: sparse_token_list_192,
128: sparse_token_list_128,
64: sparse_token_list_64
}
sparse_token_dict = {
192: sparse_token_list_192,
128: sparse_token_list_128,
64: sparse_token_list_64
}
Comment on lines +425 to +455

high

This function relies on and mutates global variables, which hurts maintainability. The condition if layer_dict == {2: 0, 6: 1, 15: 2}: is also brittle: it hardcodes one specific pruning schedule, making the code less flexible. Refactor this into a private method of SparseVLM that operates on instance attributes.
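A rough sketch of that refactor as an instance method, continuing the instance-attribute approach from the earlier comment (method name and structure are illustrative; the token budgets are copied from the diff):

```python
def _update_sparse_token_dict(self):
    # assumes self.layer_dict, self.prune_flag, self.merge_flag were set in add_sparse_config;
    # the hardcoded schedule check below could itself be replaced by an explicit config option
    if self.layer_dict == {2: 0, 6: 1, 15: 2}:       # default 3-stage schedule
        self.prune_flag = self.merge_flag = True      # the diff forces both flags here
        budgets = {192: [300, 200, 110], 128: [303, 110, 36], 64: [66, 30, 17]}
    elif self.prune_flag and self.merge_flag:
        budgets = {192: [180], 128: [114], 64: [48]}
    elif self.prune_flag:
        budgets = {192: [192], 128: [128], 64: [64]}
    elif self.merge_flag:
        budgets = {192: [149], 128: [78], 64: [7]}
    else:
        raise RuntimeError('Both prune_flag and merge_flag are False — sparseVLM is inactive.')
    self.sparse_token_dict = budgets
```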



def attn_postprocess_topk(
Expand Down Expand Up @@ -567,4 +612,4 @@ def cluster_and_merge(x, cluster_num):
source=source.reshape(B * N, C).type(x.dtype))
x_merged = x_merged.reshape(B, cluster_num, C)

return x_merged
return x_merged, index_down