refine sparsevlm for llava #418
@@ -13,6 +13,12 @@
 from .utils import prefill_wrapper, prefill_wrapper_model

+layer_dict = {}
+prune_flag = True
+merge_flag = True
+sparse_token_list_192 = []
+sparse_token_list_128 = []
+sparse_token_list_64 = []
+sparse_token_dict = {}

 @TOKEN_REDUCTION_REGISTRY.register('SparseVLM')
|
|
@@ -26,13 +32,13 @@ def add_sparse_config(self):
         special_config = self.config.get('special', {})

         self.pruning_loc = special_config.get('pruning_loc', [2, 6, 15])
-        global layer_dict
+        global layer_dict, prune_flag, merge_flag
         layer_dict = {layer: idx for idx, layer in enumerate(self.pruning_loc)}
+        prune_flag = special_config.get('prune_flag', True)
+        merge_flag = special_config.get('merge_flag', True)
+        update_list()
         special_config['retained_tokens'] = special_config.get('retained_tokens', 192)
         special_config['init_token_total_shape'] = special_config.get('init_token_total_shape', 668)
         special_config['generate_process_count'] = 0
         special_config['pre_prompt_length_list'] = []
         special_config['token_length_list'] = []
         special_config['image_shape'] = self.model.pruning_config['image_token_length']
         special_config['image_token_index'] = self.model.pruning_config['image_token_index']
         self.pruning_paras = special_config
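The two new flags ride in through the same `special` config block as the existing knobs. A minimal sketch of the relevant section (only the keys shown here appear in this diff; the surrounding config layout is assumed):

    # Hypothetical config snippet for illustration.
    config = {
        'special': {
            'pruning_loc': [2, 6, 15],   # decoder layers where token reduction runs
            'prune_flag': True,          # drop low-relevance visual tokens
            'merge_flag': True,          # cluster-and-merge dropped tokens instead of discarding
            'retained_tokens': 192,      # visual-token budget (192, 128, or 64)
        }
    }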
|
|
@@ -42,7 +48,6 @@ def register_reduction_modules(self):
         def input_hook(module, input_args, pruning_pars):
             input_ids = input_args[0]
             pre_prompt_length_list = []
-            token_length_list = []
             IMAGE_TOKEN_INDEX = pruning_pars['image_token_index']

             # find the position of the first image token
|
|
@@ -54,10 +59,7 @@ def input_hook(module, input_args, pruning_pars):
                     pre_prompt_length_list.append(image_token_index[0].item())
                 else:
                     pre_prompt_length_list.append(0)
-                token_length_list.append(seq.shape[0])

             pruning_pars['pre_prompt_length_list'] = pre_prompt_length_list
-            pruning_pars['token_length_list'] = token_length_list

             return input_args
|
|
@@ -90,11 +92,7 @@ def wrapper(self, *args, **kwargs):

                 pruning_paras['pre_prompt_length_list'] = pre_prompt_length_list

-                outputs = fn(*args, **kwargs)
-
-                pruning_paras['token_length_list'] = outputs[2].sum(dim=1).tolist()
-
-                return outputs
+                return fn(*args, **kwargs)
             return wrapper

         @prefill_wrapper_model
|
|
@@ -106,12 +104,6 @@ def register_module_pars(module, args, kwargs, pruning_pars):

             B, L, _ = hidden_states.shape
             pruning_pars['B'] = B
-            init_n = pruning_pars['init_token_total_shape'] + \
-                pruning_pars['generate_process_count']  # 668
-            pruning_pars['prev_decision'] = torch.ones(
-                B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)
-            pruning_pars['policy'] = torch.ones(
-                B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)

             v_token_start = pre_prompt_length_list[0] if len(
                 pre_prompt_length_list) != 0 else 0
|
|
@@ -123,8 +115,8 @@ def register_module_pars(module, args, kwargs, pruning_pars):
             if (len(pre_prompt_length_list) != 0 and hidden_states.shape[1] != 1):
                 v_t = hidden_states[:, v_token_start: text_token_start, :]
                 t_t = hidden_states[:, text_token_start:, :]
-                m_v_t = v_t @ t_t.transpose(1, 2)  # [1, 576, 52]
-                m_v_t = m_v_t.softmax(2).mean(1)  # [1, 52]
+                m_v_t = v_t @ t_t.transpose(1, 2)
+                m_v_t = m_v_t.softmax(2).mean(1)
                 pruning_pars['t_token_idx'] = torch.where(m_v_t > m_v_t.mean())

             return args, kwargs
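For reference, `m_v_t` averages each text token's softmax-normalized similarity across the visual tokens, and `t_token_idx` keeps the above-average text tokens. A toy re-run of those three lines with made-up sizes (4 visual tokens, 3 text tokens, hidden size 8):

    import torch

    v_t = torch.randn(1, 4, 8)                       # visual tokens [B, V, D]
    t_t = torch.randn(1, 3, 8)                       # text tokens   [B, T, D]
    m_v_t = v_t @ t_t.transpose(1, 2)                # similarity    [B, V, T]
    m_v_t = m_v_t.softmax(2).mean(1)                 # relevance     [B, T]
    t_token_idx = torch.where(m_v_t > m_v_t.mean())  # above-average text tokens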
|
|
@@ -133,6 +125,7 @@ def update_output_attentions_hook(module, args, kwargs, pruning_pars, layer_idx):
         kwargs['output_attentions'] = True
         if layer_idx != self.pruning_loc[0]:
             kwargs['position_ids'] = pruning_pars['position_ids']
+            kwargs['attention_mask'] = pruning_pars['attention_mask']
             kwargs['cache_position'] = pruning_pars['cache_position']
             kwargs['position_embeddings'] = pruning_pars['position_embeddings']
         return args, kwargs
|
|
@@ -143,8 +136,14 @@ def update_kwargs_hook(module, args, kwargs, pruning_pars, layer_idx):
             return args, kwargs
         if layer_idx != self.pruning_loc[0]:
             kwargs['position_ids'] = pruning_pars['position_ids']
+            kwargs['attention_mask'] = pruning_pars['attention_mask']
             kwargs['cache_position'] = pruning_pars['cache_position']
             kwargs['position_embeddings'] = pruning_pars['position_embeddings']
+        else:
+            pruning_pars['position_ids'] = kwargs['position_ids']
+            pruning_pars['attention_mask'] = kwargs['attention_mask']
+            pruning_pars['cache_position'] = kwargs['cache_position']
+            pruning_pars['position_embeddings'] = kwargs['position_embeddings']
         return args, kwargs
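The pattern here is stash-and-restore: the first pruning layer writes its kwargs (`position_ids`, `attention_mask`, `cache_position`, `position_embeddings`) into the shared `pruning_pars` dict, and every later pruning layer overwrites its own kwargs from that dict. A self-contained sketch of the same mechanism with a dummy module (the `Scale` module and `scale` kwarg are invented for illustration; `register_forward_pre_hook(..., with_kwargs=True)` is the real PyTorch 2.x API):

    import torch
    import torch.nn as nn

    shared = {}

    def stash_hook(module, args, kwargs):
        shared['scale'] = kwargs.get('scale', 1.0)   # first layer records the kwarg
        return args, kwargs

    def restore_hook(module, args, kwargs):
        kwargs['scale'] = shared['scale']            # later layers reuse it
        return args, kwargs

    class Scale(nn.Module):
        def forward(self, x, scale=1.0):
            return x * scale

    first, later = Scale(), Scale()
    first.register_forward_pre_hook(stash_hook, with_kwargs=True)
    later.register_forward_pre_hook(restore_hook, with_kwargs=True)
    print(later(first(torch.ones(2), scale=0.5)))    # tensor([0.2500, 0.2500])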
|
|
        def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_idx):
|
|
@@ -155,11 +154,6 @@ def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_idx):
             from transformers.models.llama.modeling_llama import \
                 apply_rotary_pos_emb

-            if layer_idx != self.pruning_loc[0]:
-                kwargs['position_ids'] = pruning_pars['position_ids']
-                kwargs['cache_position'] = pruning_pars['cache_position']
-                kwargs['position_embeddings'] = pruning_pars['position_embeddings']

             hidden_states = kwargs['hidden_states']
             position_embeddings = kwargs['position_embeddings']
             position_ids = kwargs['position_ids']
|
|
@@ -215,9 +209,10 @@ def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_idx):
     def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):

         if 'attn_logits' not in pruning_pars:
-            attn_logits = layer_outputs[1]
+            attn_logits = layer_outputs[1]  # for LlavaHf
         else:
             attn_logits = pruning_pars['attn_logits']
+        prune_flag = pruning_pars.get('prune_flag', True)
         merge_flag = pruning_pars['merge_flag']
         v_token_start = pruning_pars['v_token_start']
         v_token_num = pruning_pars['v_token_num']
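Two paths feed `attn_logits` here: for LlavaHf, `update_output_attentions_hook` forces `output_attentions=True` so the decoder layer itself returns its attention weights in `layer_outputs[1]`; for Llava, `get_attn_logits_hook` recomputes them (applying `apply_rotary_pos_emb` manually) and stashes the result in `pruning_pars['attn_logits']`.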
|
|
@@ -227,13 +222,11 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):
         B = pruning_pars['B']
         pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
         image_shape = pruning_pars['image_shape']
-        if layer_idx == self.pruning_loc[0]:
-            position_ids = kwargs['position_ids']
-            pruning_pars['position_ids'] = position_ids
-        else:
-            position_ids = pruning_pars['position_ids']
-        hidden_states = inputs[0]  # [B, L, D]
+        attention_mask = kwargs['attention_mask']
+        position_embeddings = kwargs['position_embeddings']

+        hidden_states = inputs[0]  # [B, L, D]
         pred_score_vis, s_flag, relation_vis_text = attn_postprocess_topk(
             attn_logits,
             v_token_start,

@@ -243,7 +236,8 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):
             layer_idx,
             retained_tokens
         )

+        if not prune_flag:
+            pred_score_vis = torch.zeros_like(relation_vis_text, dtype=bool)
         policy = torch.ones(B, hidden_states.shape[1], dtype=hidden_states.dtype,
                             device=hidden_states.device)
         policy[:, v_token_start:text_token_start] = \
|
@@ -261,60 +255,91 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):

         # merge and cluster
         if s_flag and merge_flag and total_sparse_token_idx.shape[1] > 0:
-            total_sparse_token = batch_index_select(layer_outputs[0], total_sparse_token_idx)
+            total_sparse_token = batch_index_select(
+                layer_outputs[0], total_sparse_token_idx
+            )

             merge_token_idx_stage1 = torch.where(pred_score_vis == 0)[1]
             merge_token_stage1 = relation_vis_text[0][merge_token_idx_stage1]
-            merge_token_num_stage1 = int(merge_token_idx_stage1.shape[0] * 0.3) + 1  # Top 30%
+            if prune_flag:
+                merge_token_num_stage1 = int(merge_token_idx_stage1.shape[0] * 0.3) + 1
+            else:
+                merge_token_num_stage1 = (
+                    merge_token_idx_stage1.shape[0]
+                    - sparse_token_dict[retained_tokens][layer_dict[layer_idx]]
+                )
             merge_token_stage2_idx = merge_token_stage1.topk(merge_token_num_stage1)[1]
+            if not prune_flag:
+                all_idx = torch.arange(
+                    merge_token_stage1.size(0),
+                    device=merge_token_stage1.device
+                )
+                non_topk_idx = all_idx[~torch.isin(all_idx, merge_token_stage2_idx)]
+                pred_score_vis[0][non_topk_idx] = 1
+                policy[:, v_token_start:text_token_start] = \
+                    pred_score_vis.type(dtype=hidden_states.dtype)

             merge_token_stage2 = total_sparse_token[:, merge_token_stage2_idx, :]
             cluster_num = int(merge_token_stage2.shape[1] / 10) + 1
             if cluster_num == 0:
                 cluster_num = merge_token_stage2.shape[1]
-            merge_sparse_token = cluster_and_merge(merge_token_stage2, cluster_num)
+            merge_sparse_token, index_down = cluster_and_merge(merge_token_stage2, cluster_num)

+            cluster_idx = total_sparse_token_idx.squeeze(0)[merge_token_stage2_idx[index_down]]
+            cluster_idx = cluster_idx.squeeze(0)

             select_token_idx = torch.where(policy == 1)[1].unsqueeze(0)
             select_token = batch_index_select(layer_outputs[0], select_token_idx)
             select_vis_token_num = pred_score_vis.sum()

+            keep_indexs = torch.cat(
+                (
+                    select_token_idx.squeeze(0)[:v_token_start + select_vis_token_num],
+                    cluster_idx,
+                    select_token_idx.squeeze(0)[v_token_start + select_vis_token_num:]
+                )
+            )
             select_and_merge_token = torch.cat(
                 (
-                    select_token[:, :v_token_start +
-                                 select_vis_token_num, :],
+                    select_token[:, :v_token_start + select_vis_token_num, :],
                     merge_sparse_token,
-                    select_token[:, v_token_start +
-                                 select_vis_token_num:, :]
+                    select_token[:, v_token_start + select_vis_token_num:, :]
                 ),
                 dim=1
             )
             layer_outputs = (select_and_merge_token, layer_outputs[1])
-            position_ids = position_ids[:, :len(select_token_idx[0]) + cluster_num]
             v_token_num = pred_score_vis.sum() + cluster_num
-            text_token_start = v_token_start + v_token_num

         else:
-            select_token_idx = torch.where(policy == 1)[1].unsqueeze(0)
+            keep_indexs = torch.where(policy == 1)[1]
+            select_token_idx = keep_indexs.unsqueeze(0)
             layer_outputs = (batch_index_select(layer_outputs[0], select_token_idx),
                              layer_outputs[1])
-            position_ids = position_ids[:, :len(select_token_idx[0])]
             v_token_num = pred_score_vis.sum()
-            text_token_start = v_token_start + v_token_num

+        text_token_start = v_token_start + v_token_num
+        position_ids = keep_indexs.unsqueeze(0)
         new_output = layer_outputs
-        cache_position = position_ids.detach().clone()
+        cache_position = position_ids.squeeze(0)

+        if attention_mask is not None:
+            attention_mask = attention_mask[:, :, keep_indexs, keep_indexs]
+        new_pe0 = position_embeddings[0][:, keep_indexs, :].clone()
+        new_pe1 = position_embeddings[1][:, keep_indexs, :].clone()
+        position_embeddings = (new_pe0, new_pe1)

         pruning_pars['v_token_num'] = v_token_num
         pruning_pars['text_token_start'] = text_token_start

         pruning_pars['position_ids'] = position_ids
         pruning_pars['cache_position'] = cache_position
-        pruning_pars['position_embeddings'] = None
+        pruning_pars['position_embeddings'] = position_embeddings
+        pruning_pars['attention_mask'] = attention_mask

         return new_output

     @prefill_wrapper
     def read_parameter_hook(module, args, kwargs, pruning_pars):
         kwargs['position_ids'] = pruning_pars['position_ids']
+        kwargs['attention_mask'] = pruning_pars['attention_mask']
         kwargs['cache_position'] = pruning_pars['cache_position']
         kwargs['position_embeddings'] = pruning_pars['position_embeddings']
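`batch_index_select` is used throughout this hook but its definition is not part of the diff; a common implementation matching that call signature (an assumption, not necessarily this repo's code) gathers rows per batch element:

    import torch

    def batch_index_select(x, idx):
        # x: [B, L, D], idx: [B, K] (long) -> [B, K, D]
        B, L, D = x.shape
        return torch.gather(x, 1, idx.unsqueeze(-1).expand(B, idx.shape[1], D))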
|
|
|
|
@@ -363,7 +388,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                     with_kwargs=True
                 )
             elif self.model.__class__.__name__ == 'Llava':
-                self.blocks[block_idx].self_attn.register_forward_pre_hook(
+                self.blocks[block_idx].register_forward_pre_hook(
                     functools.partial(
                         update_kwargs_hook,
                         pruning_pars=self.pruning_paras,

Review comment: Consider registering the forward pre-hook directly to the block instead of the self_attn module. This simplifies the hook registration and ensures that the hook is applied to the entire block.

    self.blocks[block_idx].register_forward_pre_hook(
        functools.partial(
            update_kwargs_hook,
            pruning_pars=self.pruning_paras,
            layer_idx=block_idx
        ),
        with_kwargs=True
    )
|
@@ -383,7 +408,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                 functools.partial(
                     decoder_attn_hook,
                     pruning_pars=self.pruning_paras,
-                    layer_idx=block_idx,
+                    layer_idx=block_idx
                 ),
                 with_kwargs=True
             )
|
|
@@ -397,17 +422,37 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
             )

-layer_dict = {2: 0, 6: 1, 15: 2}
-
-sparse_token_list_192 = [300, 200, 110]  # 2*576 4*300 10*200 16*110
-sparse_token_list_128 = [303, 110, 36]
-sparse_token_list_64 = [66, 30, 17]
+def update_list():
+    global sparse_token_list_192, sparse_token_list_128, sparse_token_list_64
+    global prune_flag, merge_flag, sparse_token_dict
+
+    if layer_dict == {2: 0, 6: 1, 15: 2}:  # 2*576 4*300 10*200 16*110
+        sparse_token_list_192 = [300, 200, 110]
+        sparse_token_list_128 = [303, 110, 36]
+        sparse_token_list_64 = [66, 30, 17]
+        prune_flag, merge_flag = True, True
+    elif prune_flag and merge_flag:
+        sparse_token_list_192 = [180]
+        sparse_token_list_128 = [114]
+        sparse_token_list_64 = [48]
+    elif prune_flag:
+        sparse_token_list_192 = [192]
+        sparse_token_list_128 = [128]
+        sparse_token_list_64 = [64]
+    elif merge_flag:
+        sparse_token_list_192 = [149]
+        sparse_token_list_128 = [78]
+        sparse_token_list_64 = [7]
+    else:
+        raise RuntimeError(
+            'Both prune_flag and merge_flag are False — sparseVLM is inactive.'
+        )

-sparse_token_dict = {
-    192: sparse_token_list_192,
-    128: sparse_token_list_128,
-    64: sparse_token_list_64
-}
+    sparse_token_dict = {
+        192: sparse_token_list_192,
+        128: sparse_token_list_128,
+        64: sparse_token_list_64
+    }
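Concretely, with the default `pruning_loc` of `[2, 6, 15]` and `retained_tokens=192`, `update_list()` reproduces the original schedule, and the per-layer budget is looked up the same way `decoder_attn_hook`'s merge path does:

    layer_dict = {2: 0, 6: 1, 15: 2}
    sparse_token_dict = {192: [300, 200, 110], 128: [303, 110, 36], 64: [66, 30, 17]}
    budget = sparse_token_dict[192][layer_dict[6]]  # layer 6 keeps 200 visual tokens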
|
Review comment on lines +425 to +455: This function heavily relies on and modifies global variables, contributing to maintainability issues.
    def attn_postprocess_topk(
|
|
@@ -567,4 +612,4 @@ def cluster_and_merge(x, cluster_num):
                       source=source.reshape(B * N, C).type(x.dtype))
     x_merged = x_merged.reshape(B, cluster_num, C)

-    return x_merged
+    return x_merged, index_down
Review comment: The introduction of these module-level global variables can cause unexpected behavior, especially if multiple SparseVLM instances are created. Each instance could overwrite the global configuration, leading to race conditions or incorrect configurations. Consider encapsulating these variables as instance attributes of the SparseVLM class to ensure that each instance manages its own state.
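One way to act on this suggestion (a sketch, not part of the PR) is to hang the same state off the instance; `_build_token_schedule` below is a hypothetical helper standing in for `update_list`:

    class SparseVLM:
        def add_sparse_config(self):
            special = self.config.get('special', {})
            self.pruning_loc = special.get('pruning_loc', [2, 6, 15])
            # instance state instead of module-level globals
            self.layer_dict = {layer: i for i, layer in enumerate(self.pruning_loc)}
            self.prune_flag = special.get('prune_flag', True)
            self.merge_flag = special.get('merge_flag', True)
            self.sparse_token_dict = self._build_token_schedule()  # hypothetical helper

The hooks already receive `pruning_paras=self.pruning_paras` via `functools.partial`, so per-instance state can reach them through the same channel without any global lookups.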