diff --git a/configs/sparsification/methods/SparseVLM/sparsevlm.yml b/configs/sparsification/methods/SparseVLM/sparsevlm.yml
index 84d91ada..8dc69d92 100644
--- a/configs/sparsification/methods/SparseVLM/sparsevlm.yml
+++ b/configs/sparsification/methods/SparseVLM/sparsevlm.yml
@@ -16,10 +16,10 @@ sparse:
     method: TokenReduction
     special:
         method: SparseVLM
-        pruning_loc: [2] # [2, 6, 15]
+        pruning_loc: [2, 6, 15]
         retained_tokens: 192
-        init_token_total_shape: 668
-        merge_flag: False
+        prune_flag: True
+        merge_flag: True
 save:
     save_trans: False
     save_fake: False
diff --git a/llmc/compression/token_reduction/sparsevlm.py b/llmc/compression/token_reduction/sparsevlm.py
index 997a88b7..22269cf7 100755
--- a/llmc/compression/token_reduction/sparsevlm.py
+++ b/llmc/compression/token_reduction/sparsevlm.py
@@ -13,6 +13,12 @@ from .utils import prefill_wrapper, prefill_wrapper_model
 
 layer_dict = {}
+prune_flag = True
+merge_flag = True
+sparse_token_list_192 = []
+sparse_token_list_128 = []
+sparse_token_list_64 = []
+sparse_token_dict = {}
 
 
 @TOKEN_REDUCTION_REGISTRY.register('SparseVLM')
@@ -26,13 +32,13 @@ def add_sparse_config(self):
         special_config = self.config.get('special', {})
 
         self.pruning_loc = special_config.get('pruning_loc', [2, 6, 15])
-        global layer_dict
+        global layer_dict, prune_flag, merge_flag
        layer_dict = {layer: idx for idx, layer in enumerate(self.pruning_loc)}
+        prune_flag = special_config.get('prune_flag', True)
+        merge_flag = special_config.get('merge_flag', True)
+        update_list()
         special_config['retained_tokens'] = special_config.get('retained_tokens', 192)
-        special_config['init_token_total_shape'] = special_config.get('init_token_total_shape', 668)
-        special_config['generate_process_count'] = 0
         special_config['pre_prompt_length_list'] = []
-        special_config['token_length_list'] = []
         special_config['image_shape'] = self.model.pruning_config['image_token_length']
         special_config['image_token_index'] = self.model.pruning_config['image_token_index']
         self.pruning_paras = special_config
@@ -42,7 +48,6 @@ def register_reduction_modules(self):
         def input_hook(module, input_args, pruning_pars):
             input_ids = input_args[0]
             pre_prompt_length_list = []
-            token_length_list = []
             IMAGE_TOKEN_INDEX = pruning_pars['image_token_index']
 
             # find the position of the first image token
@@ -54,10 +59,7 @@ def input_hook(module, input_args, pruning_pars):
                     pre_prompt_length_list.append(image_token_index[0].item())
                 else:
                     pre_prompt_length_list.append(0)
-                token_length_list.append(seq.shape[0])
-
             pruning_pars['pre_prompt_length_list'] = pre_prompt_length_list
-            pruning_pars['token_length_list'] = token_length_list
 
             return input_args
@@ -90,11 +92,7 @@ def wrapper(self, *args, **kwargs):
 
                 pruning_paras['pre_prompt_length_list'] = pre_prompt_length_list
 
-                outputs = fn(*args, **kwargs)
-
-                pruning_paras['token_length_list'] = outputs[2].sum(dim=1).tolist()
-
-                return outputs
+                return fn(*args, **kwargs)
             return wrapper
 
         @prefill_wrapper_model
@@ -106,12 +104,6 @@ def register_module_pars(module, args, kwargs, pruning_pars):
             B, L, _ = hidden_states.shape
             pruning_pars['B'] = B
 
-            init_n = pruning_pars['init_token_total_shape'] + \
-                pruning_pars['generate_process_count']  # 668
-            pruning_pars['prev_decision'] = torch.ones(
-                B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)
-            pruning_pars['policy'] = torch.ones(
-                B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)
 
             v_token_start = pre_prompt_length_list[0] if len(
                 pre_prompt_length_list) != 0 else 0
@@ -123,8 +115,8 @@ def register_module_pars(module, args, kwargs, pruning_pars):
             if (len(pre_prompt_length_list) != 0 and hidden_states.shape[1] != 1):
                 v_t = hidden_states[:, v_token_start: text_token_start, :]
                 t_t = hidden_states[:, text_token_start:, :]
-                m_v_t = v_t @ t_t.transpose(1, 2)  # [1, 576, 52]
-                m_v_t = m_v_t.softmax(2).mean(1)  # [1, 52]
+                m_v_t = v_t @ t_t.transpose(1, 2)
+                m_v_t = m_v_t.softmax(2).mean(1)
                 pruning_pars['t_token_idx'] = torch.where(m_v_t > m_v_t.mean())
 
             return args, kwargs
@@ -133,6 +125,7 @@ def update_output_attentions_hook(module, args, kwargs, pruning_pars, layer_idx)
             kwargs['output_attentions'] = True
             if layer_idx != self.pruning_loc[0]:
                 kwargs['position_ids'] = pruning_pars['position_ids']
+                kwargs['attention_mask'] = pruning_pars['attention_mask']
                 kwargs['cache_position'] = pruning_pars['cache_position']
                 kwargs['position_embeddings'] = pruning_pars['position_embeddings']
             return args, kwargs
@@ -143,8 +136,14 @@ def update_kwargs_hook(module, args, kwargs, pruning_pars, layer_idx):
                 return args, kwargs
             if layer_idx != self.pruning_loc[0]:
                 kwargs['position_ids'] = pruning_pars['position_ids']
+                kwargs['attention_mask'] = pruning_pars['attention_mask']
                 kwargs['cache_position'] = pruning_pars['cache_position']
                 kwargs['position_embeddings'] = pruning_pars['position_embeddings']
+            else:
+                pruning_pars['position_ids'] = kwargs['position_ids']
+                pruning_pars['attention_mask'] = kwargs['attention_mask']
+                pruning_pars['cache_position'] = kwargs['cache_position']
+                pruning_pars['position_embeddings'] = kwargs['position_embeddings']
             return args, kwargs
 
         def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_idx):
@@ -155,11 +154,6 @@ def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_i
             from transformers.models.llama.modeling_llama import \
                 apply_rotary_pos_emb
 
-            if layer_idx != self.pruning_loc[0]:
-                kwargs['position_ids'] = pruning_pars['position_ids']
-                kwargs['cache_position'] = pruning_pars['cache_position']
-                kwargs['position_embeddings'] = pruning_pars['position_embeddings']
-
             hidden_states = kwargs['hidden_states']
             position_embeddings = kwargs['position_embeddings']
             position_ids = kwargs['position_ids']
@@ -215,9 +209,10 @@ def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_i
 
         def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):
             if 'attn_logits' not in pruning_pars:
-                attn_logits = layer_outputs[1]
+                attn_logits = layer_outputs[1]  # for LlavaHf
             else:
                 attn_logits = pruning_pars['attn_logits']
+            prune_flag = pruning_pars.get('prune_flag', True)
             merge_flag = pruning_pars['merge_flag']
             v_token_start = pruning_pars['v_token_start']
             v_token_num = pruning_pars['v_token_num']
@@ -227,13 +222,11 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer
             B = pruning_pars['B']
             pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
             image_shape = pruning_pars['image_shape']
-            if layer_idx == self.pruning_loc[0]:
-                position_ids = kwargs['position_ids']
-                pruning_pars['position_ids'] = position_ids
-            else:
-                position_ids = pruning_pars['position_ids']
-
-            hidden_states = inputs[0]  # [B, L, D]
+            attention_mask = kwargs['attention_mask']
+            position_embeddings = kwargs['position_embeddings']
+
+            hidden_states = inputs[0]  # [B, L, D]
 
             pred_score_vis, s_flag, relation_vis_text = attn_postprocess_topk(
                 attn_logits,
                 v_token_start,
                 v_token_num,
                 text_token_start,
                 pruning_pars['t_token_idx'],
                 layer_idx,
                 retained_tokens
             )
-
+            if not prune_flag:
+                pred_score_vis = torch.zeros_like(relation_vis_text, dtype=torch.bool)
             policy = torch.ones(B, hidden_states.shape[1], dtype=hidden_states.dtype,
                                 device=hidden_states.device)
             policy[:, v_token_start:text_token_start] = \
                 pred_score_vis.type(dtype=hidden_states.dtype)
@@ -261,60 +255,91 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer
 
             # merge and cluster
             if s_flag and merge_flag and total_sparse_token_idx.shape[1] > 0:
-                total_sparse_token = batch_index_select(layer_outputs[0], total_sparse_token_idx)
+                total_sparse_token = batch_index_select(
+                    layer_outputs[0], total_sparse_token_idx
+                )
 
                 merge_token_idx_stage1 = torch.where(pred_score_vis == 0)[1]
                 merge_token_stage1 = relation_vis_text[0][merge_token_idx_stage1]
-                merge_token_num_stage1 = int(merge_token_idx_stage1.shape[0] * 0.3) + 1  # Top 30%
+                if prune_flag:
+                    merge_token_num_stage1 = int(merge_token_idx_stage1.shape[0] * 0.3) + 1  # Top 30%
+                else:
+                    merge_token_num_stage1 = (
+                        merge_token_idx_stage1.shape[0]
+                        - sparse_token_dict[retained_tokens][layer_dict[layer_idx]]
+                    )
                 merge_token_stage2_idx = merge_token_stage1.topk(merge_token_num_stage1)[1]
+                if not prune_flag:
+                    all_idx = torch.arange(
+                        merge_token_stage1.size(0),
+                        device=merge_token_stage1.device
+                    )
+                    non_topk_idx = all_idx[~torch.isin(all_idx, merge_token_stage2_idx)]
+                    pred_score_vis[0][non_topk_idx] = 1
+                    policy[:, v_token_start:text_token_start] = \
+                        pred_score_vis.type(dtype=hidden_states.dtype)
 
                 merge_token_stage2 = total_sparse_token[:, merge_token_stage2_idx, :]
                 cluster_num = int(merge_token_stage2.shape[1] / 10) + 1
                 if cluster_num == 0:
                     cluster_num = merge_token_stage2.shape[1]
+                merge_sparse_token, index_down = cluster_and_merge(merge_token_stage2, cluster_num)
 
-                merge_sparse_token = cluster_and_merge(merge_token_stage2, cluster_num)
-
+                cluster_idx = total_sparse_token_idx.squeeze(0)[merge_token_stage2_idx[index_down]]
+                cluster_idx = cluster_idx.squeeze(0)
                 select_token_idx = torch.where(policy == 1)[1].unsqueeze(0)
                 select_token = batch_index_select(layer_outputs[0], select_token_idx)
                 select_vis_token_num = pred_score_vis.sum()
-
+                keep_indexs = torch.cat(
+                    (
+                        select_token_idx.squeeze(0)[:v_token_start + select_vis_token_num],
+                        cluster_idx,
+                        select_token_idx.squeeze(0)[v_token_start + select_vis_token_num:]
+                    )
+                )
                 select_and_merge_token = torch.cat(
                     (
-                        select_token[:, :v_token_start +
-                                     select_vis_token_num, :],
+                        select_token[:, :v_token_start + select_vis_token_num, :],
                         merge_sparse_token,
-                        select_token[:, v_token_start +
-                                     select_vis_token_num:, :]
+                        select_token[:, v_token_start + select_vis_token_num:, :]
                     ),
                     dim=1
                 )
                 layer_outputs = (select_and_merge_token, layer_outputs[1])
-                position_ids = position_ids[:, :len(select_token_idx[0]) + cluster_num]
                 v_token_num = pred_score_vis.sum() + cluster_num
-                text_token_start = v_token_start + v_token_num
+
             else:
-                select_token_idx = torch.where(policy == 1)[1].unsqueeze(0)
+                keep_indexs = torch.where(policy == 1)[1]
+                select_token_idx = keep_indexs.unsqueeze(0)
                 layer_outputs = (batch_index_select(layer_outputs[0], select_token_idx),
                                  layer_outputs[1])
-                position_ids = position_ids[:, :len(select_token_idx[0])]
                 v_token_num = pred_score_vis.sum()
-                text_token_start = v_token_start + v_token_num
+            text_token_start = v_token_start + v_token_num
+            position_ids = keep_indexs.unsqueeze(0)
 
             new_output = layer_outputs
-            cache_position = position_ids.detach().clone()
+            cache_position = position_ids.squeeze(0)
+
+            if attention_mask is not None:
+                # gather query rows and key columns in two steps; a single
+                # [..., keep_indexs, keep_indexs] index would keep only the diagonal
+                attention_mask = attention_mask[:, :, keep_indexs][:, :, :, keep_indexs]
+            new_pe0 = position_embeddings[0][:, keep_indexs, :].clone()
+            new_pe1 = position_embeddings[1][:, keep_indexs, :].clone()
+            position_embeddings = (new_pe0, new_pe1)
 
             pruning_pars['v_token_num'] = v_token_num
             pruning_pars['text_token_start'] = text_token_start
+            pruning_pars['position_ids'] = position_ids
             pruning_pars['cache_position'] = cache_position
-            pruning_pars['position_embeddings'] = None
+            pruning_pars['position_embeddings'] = position_embeddings
+            pruning_pars['attention_mask'] = attention_mask
 
             return new_output
 
         @prefill_wrapper
         def read_parameter_hook(module, args, kwargs, pruning_pars):
             kwargs['position_ids'] = pruning_pars['position_ids']
+            kwargs['attention_mask'] = pruning_pars['attention_mask']
             kwargs['cache_position'] = pruning_pars['cache_position']
             kwargs['position_embeddings'] = pruning_pars['position_embeddings']
@@ -363,7 +388,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                     with_kwargs=True
                 )
             elif self.model.__class__.__name__ == 'Llava':
-                self.blocks[block_idx].self_attn.register_forward_pre_hook(
+                self.blocks[block_idx].register_forward_pre_hook(
                     functools.partial(
                         update_kwargs_hook,
                         pruning_pars=self.pruning_paras,
@@ -383,7 +408,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                 functools.partial(
                     decoder_attn_hook,
                     pruning_pars=self.pruning_paras,
-                    layer_idx=block_idx,
+                    layer_idx=block_idx
                 ),
                 with_kwargs=True
             )
@@ -397,17 +422,37 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
             )
 
 
-layer_dict = {2: 0, 6: 1, 15: 2}
-
-sparse_token_list_192 = [300, 200, 110]  # 2*576 4*300 10*200 16*110
-sparse_token_list_128 = [303, 110, 36]
-sparse_token_list_64 = [66, 30, 17]
+def update_list():
+    global sparse_token_list_192, sparse_token_list_128, sparse_token_list_64
+    global prune_flag, merge_flag, sparse_token_dict
+
+    if layer_dict == {2: 0, 6: 1, 15: 2}:  # 2*576 4*300 10*200 16*110
+        sparse_token_list_192 = [300, 200, 110]
+        sparse_token_list_128 = [303, 110, 36]
+        sparse_token_list_64 = [66, 30, 17]
+        prune_flag, merge_flag = True, True
+    elif prune_flag and merge_flag:
+        sparse_token_list_192 = [180]
+        sparse_token_list_128 = [114]
+        sparse_token_list_64 = [48]
+    elif prune_flag:
+        sparse_token_list_192 = [192]
+        sparse_token_list_128 = [128]
+        sparse_token_list_64 = [64]
+    elif merge_flag:
+        sparse_token_list_192 = [149]
+        sparse_token_list_128 = [78]
+        sparse_token_list_64 = [7]
+    else:
+        raise RuntimeError(
+            'Both prune_flag and merge_flag are False; SparseVLM is inactive.'
+        )
 
-sparse_token_dict = {
-    192: sparse_token_list_192,
-    128: sparse_token_list_128,
-    64: sparse_token_list_64
-}
+    sparse_token_dict = {
+        192: sparse_token_list_192,
+        128: sparse_token_list_128,
+        64: sparse_token_list_64
+    }
 
 
 def attn_postprocess_topk(
@@ -567,4 +612,4 @@ def cluster_and_merge(x, cluster_num):
                               source=source.reshape(B * N, C).type(x.dtype))
     x_merged = x_merged.reshape(B, cluster_num, C)
 
-    return x_merged
+    return x_merged, index_down
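
Editor's note: for reference, the `special` block of sparsevlm.yml after this patch reads as below. This is a minimal sketch assembled from the hunks above; the inline comments are editorial, the keys and values come from the diff itself.

```yaml
sparse:
    method: TokenReduction
    special:
        method: SparseVLM
        pruning_loc: [2, 6, 15]   # with this default, update_list() forces both flags on
        retained_tokens: 192      # picks the 192-token schedule [300, 200, 110]
        prune_flag: True          # flag combinations matter for single-stage pruning_loc,
        merge_flag: True          # e.g. [2]: prune+merge [180], prune-only [192], merge-only [149]
```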
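The merge-only branch (`prune_flag: False`) is the least obvious path in `decoder_attn_hook`: `pred_score_vis` starts all-zero, `merge_token_num_stage1` is sized to merge away exactly enough candidates to hit the `sparse_token_dict` budget, and every non-top-k token is then re-flagged as kept. A toy-size sketch of just that index arithmetic (names mirror the hook, but the sizes, scores, and budget below are made up):

```python
import torch

relation_vis_text = torch.rand(1, 12)             # toy text-relevance scores per visual token
pred_score_vis = torch.zeros_like(relation_vis_text, dtype=torch.bool)  # start: nothing kept

merge_token_idx_stage1 = torch.where(pred_score_vis == 0)[1]   # all 12 are merge candidates
target_kept = 7                                   # stand-in for sparse_token_dict[...] entry
merge_token_num_stage1 = merge_token_idx_stage1.shape[0] - target_kept  # merge away the rest

merge_token_stage1 = relation_vis_text[0][merge_token_idx_stage1]
merge_token_stage2_idx = merge_token_stage1.topk(merge_token_num_stage1)[1]

# Everything not selected for merging is re-flagged as kept, so only
# merging (not pruning) shrinks the sequence on this path.
all_idx = torch.arange(merge_token_stage1.size(0))
non_topk_idx = all_idx[~torch.isin(all_idx, merge_token_stage2_idx)]
pred_score_vis[0][non_topk_idx] = 1
assert pred_score_vis.sum().item() == target_kept
```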
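Since the patch now rebuilds `attention_mask`, `position_ids`, `cache_position`, and `position_embeddings` from `keep_indexs` instead of truncating `position_ids`, here is a standalone sketch of that re-indexing with toy shapes. It assumes a `[B, 1, L, L]` additive mask and transformers-style `(cos, sin)` rotary embeddings of shape `[B, L, head_dim]`; note the two-step mask gather, which is what the corrected line in the hunk does.

```python
import torch

B, L, D = 1, 10, 8
keep_indexs = torch.tensor([0, 1, 4, 7, 9])   # surviving token positions

attention_mask = torch.zeros(B, 1, L, L)       # additive mask, 0 = visible
cos, sin = torch.randn(B, L, D), torch.randn(B, L, D)

# Query rows and key columns are gathered separately;
# mask[:, :, keep, keep] would return only the diagonal.
attention_mask = attention_mask[:, :, keep_indexs][:, :, :, keep_indexs]

position_ids = keep_indexs.unsqueeze(0)        # [1, K], original positions are preserved
cache_position = position_ids.squeeze(0)       # 1-D, as cache_position expects
position_embeddings = (cos[:, keep_indexs, :], sin[:, keep_indexs, :])

assert attention_mask.shape == (B, 1, 5, 5)
assert position_embeddings[0].shape == (B, 5, D)
```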
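`cluster_and_merge` now also returns `index_down`, the seed-token index of each cluster center, which is what lets the hook splice merged tokens back in at real sequence positions. A usage sketch, assuming `llmc` is importable and `index_down` has shape `[B, cluster_num]`; the `total_sparse_token_idx` here is a toy stand-in for the hook's real index map:

```python
import torch

from llmc.compression.token_reduction.sparsevlm import cluster_and_merge

tokens = torch.randn(1, 30, 16)              # [B, N, C] merge candidates
cluster_num = int(tokens.shape[1] / 10) + 1  # same heuristic as decoder_attn_hook

# merged: [1, cluster_num, 16]; index_down: [1, cluster_num] indices into N
merged, index_down = cluster_and_merge(tokens, cluster_num)

# Map each cluster seed back to a global sequence position, as the hook
# does via total_sparse_token_idx (toy values below):
total_sparse_token_idx = torch.arange(100, 130).unsqueeze(0)
cluster_idx = total_sparse_token_idx.squeeze(0)[index_down.squeeze(0)]
```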