diff --git a/llmc/compression/token_reduction/sparsevlm.py b/llmc/compression/token_reduction/sparsevlm.py
index 02cb2c74..aae8f722 100755
--- a/llmc/compression/token_reduction/sparsevlm.py
+++ b/llmc/compression/token_reduction/sparsevlm.py
@@ -17,6 +17,9 @@
 sparse_token_list_192 = []
 sparse_token_list_128 = []
 sparse_token_list_64 = []
+sparse_token_list_640 = []
+sparse_token_list_320 = []
+sparse_token_list_160 = []
 
 sparse_token_dict = {}
 
@@ -55,7 +58,7 @@ def input_hook(module, args, pruning_paras):
                     pre_prompt_length_list.append(0)
             pruning_paras['pre_prompt_length_list'] = pre_prompt_length_list
 
-        def input_hook_llava(fn, pruning_paras):
+        def input_hook_llava(fn, pruning_paras, llava_next=False):
             @wraps(fn)
             def wrapper(self, *args, **kwargs):
                 if args[0].shape[1] == 1:
@@ -81,11 +84,14 @@ def wrapper(self, *args, **kwargs):
 
                 pruning_paras['pre_prompt_length_list'] = pre_prompt_length_list
 
-                return fn(*args, **kwargs)
+                outs = fn(*args, **kwargs)
+                if llava_next:
+                    pruning_paras['vision_token_length'] = outs[-1]
+                return outs
             return wrapper
 
         @prefill_wrapper_model
-        def register_module_pars(module, args, kwargs, pruning_paras):
+        def register_module_paras(module, args, kwargs, pruning_paras):
             pre_prompt_length_list = pruning_paras['pre_prompt_length_list']
             hidden_states = kwargs['inputs_embeds']
             if hidden_states is None:
@@ -227,7 +233,8 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_paras, laye
                     text_token_start,
                     t_token_idx,
                     layer_idx,
-                    retained_tokens
+                    retained_tokens,
+                    pruning_paras['reduction_ratio']
                 )
                 if not prune_flag:
                     pred_score_vis = torch.zeros_like(relation_vis_text, dtype=bool)
@@ -353,7 +360,8 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):
             self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
                 input_hook_llava(
                     self.model.vlm_model.prepare_inputs_labels_for_multimodal,
-                    self.pruning_paras
+                    self.pruning_paras,
+                    llava_next=self.special_config['vision_token_length'] is None
                 ), self.model.vlm_model
             )
 
@@ -362,7 +370,7 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):
         elif self.model.__class__.__name__ == 'Llava':
             llama_model = self.model.model.model
             llama_model.register_forward_pre_hook(
-                functools.partial(register_module_pars, pruning_paras=self.pruning_paras),
+                functools.partial(register_module_paras, pruning_paras=self.pruning_paras),
                 with_kwargs=True
             )
 
@@ -417,6 +425,7 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):
 
 def update_list():
     global sparse_token_list_192, sparse_token_list_128, sparse_token_list_64
+    global sparse_token_list_640, sparse_token_list_320, sparse_token_list_160
     global prune_flag, merge_flag, sparse_token_dict
 
     if layer_dict == {2: 0, 6: 1, 15: 2}:  # 2*576 4*300 10*200 16*110
@@ -428,10 +437,16 @@ def update_list():
         sparse_token_list_192 = [180]
         sparse_token_list_128 = [114]
         sparse_token_list_64 = [48]
+        sparse_token_list_640 = [0.1979]
+        sparse_token_list_320 = [0.0833]
+        sparse_token_list_160 = [0.0261]
     elif prune_flag:
         sparse_token_list_192 = [192]
         sparse_token_list_128 = [128]
         sparse_token_list_64 = [64]
+        sparse_token_list_640 = [0.2222]
+        sparse_token_list_320 = [0.1111]
+        sparse_token_list_160 = [0.0555]
     elif merge_flag:
         sparse_token_list_192 = [149]
         sparse_token_list_128 = [78]
@@ -444,7 +459,10 @@ def update_list():
 
     sparse_token_dict = {
         192: sparse_token_list_192,
         128: sparse_token_list_128,
-        64: sparse_token_list_64
+        64: sparse_token_list_64,
+        640: sparse_token_list_640,
+        320: sparse_token_list_320,
+        160: sparse_token_list_160
     }
 
@@ -455,7 +473,8 @@ def attn_postprocess_topk(
         text_token_start,
         t_token_idx,
         layer_idx,
-        retained_tokens):
+        retained_tokens,
+        reduction_ratio):
     '''
     self_attn_weights: [B, H, L, L]
     '''
@@ -470,13 +489,17 @@ def attn_postprocess_topk(
 
     relation_vis = relation_vis_text
     s_flag = True  # s_flag controls whether token merge is needed.
-
-    sparse_token_list = sparse_token_dict[retained_tokens]
-
+    if retained_tokens in [192, 128, 64]:
+        sparse_token_list = sparse_token_dict[retained_tokens]
+    else:
+        sparse_token_list = sparse_token_dict[round((1 - reduction_ratio) * 2880)]
+    retained_tokens_prune = sparse_token_list[layer_dict[layer_idx]]
+    if retained_tokens_prune < 1:
+        retained_tokens_prune = round(retained_tokens_prune * v_token_num)
     if v_token_num != 0:
         mask = torch.zeros_like(relation_vis, dtype=bool)
         _, indices = torch.topk(relation_vis, min(
-            sparse_token_list[layer_dict[layer_idx]], v_token_num - 1), dim=1)
+            retained_tokens_prune, v_token_num - 1), dim=1)
         mask[0][indices] = 1
     else:
         mask = torch.ones_like(relation_vis_text, dtype=bool)
diff --git a/llmc/compression/token_reduction/token_reduction_module.py b/llmc/compression/token_reduction/token_reduction_module.py
index 5bd72676..bb3a1c5d 100644
--- a/llmc/compression/token_reduction/token_reduction_module.py
+++ b/llmc/compression/token_reduction/token_reduction_module.py
@@ -42,7 +42,7 @@ def wrapper(self, *args, **kwargs):
 
         message = (
             'To obtain the vision_token_length for LLaVA-1.6, you should append '
-            '`image_features.shape[1]` to the return value of the function '
+            '`image_features[0].shape[0]` to the return value of the function '
             '`prepare_inputs_labels_for_multimodal`, and modify the related code accordingly.'
         )
         outs = fn(*args, **kwargs)
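
Note on the LLaVA-1.6 path: when `special_config['vision_token_length']` is `None`, `input_hook_llava` is now built with `llava_next=True` and caches the vision token length from the last element returned by the (user-patched) `prepare_inputs_labels_for_multimodal`, which is also why the warning message now points at `image_features[0].shape[0]`. A minimal sketch of that wrapping pattern, assuming the method has already been modified to append that value to its return tuple (the helper name below is illustrative, not repository code):

```python
from functools import wraps
from types import MethodType


def capture_vision_token_length(fn, pruning_paras, llava_next=False):
    # Sketch of the wrapping pattern used in this patch. When llava_next is True,
    # the wrapped method is assumed to return (..., image_features[0].shape[0]),
    # so the vision token length can be cached for later pruning steps.
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        outs = fn(*args, **kwargs)
        if llava_next:
            pruning_paras['vision_token_length'] = outs[-1]
        return outs
    return wrapper


# Usage mirrors the patch (vlm_model stands in for whatever object owns the method):
# vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
#     capture_vision_token_length(
#         vlm_model.prepare_inputs_labels_for_multimodal, pruning_paras, llava_next=True
#     ),
#     vlm_model,
# )
```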
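Note on the new ratio-keyed budgets: for `retained_tokens` values outside {192, 128, 64}, `attn_postprocess_topk` now resolves its table entry via `round((1 - reduction_ratio) * 2880)`, and entries smaller than 1 are treated as a fraction of the current visual-token count rather than an absolute count. A standalone toy sketch of that lookup and the resulting top-k keep-mask (the helper name and the single-entry `layer_dict` are made up for illustration; this is not the repository's hook pipeline):

```python
import torch


def select_visual_tokens(relation_vis, reduction_ratio, v_token_num,
                         layer_idx, layer_dict, sparse_token_dict):
    # Resolve the per-layer budget from the ratio-keyed table,
    # e.g. reduction_ratio ~ 0.7778 -> key 640.
    key = round((1 - reduction_ratio) * 2880)
    budget = sparse_token_dict[key][layer_dict[layer_idx]]
    if budget < 1:
        # Fractional entries are scaled by the live visual-token count.
        budget = round(budget * v_token_num)
    # Keep the top-k visual tokens by text->vision relevance.
    mask = torch.zeros_like(relation_vis, dtype=torch.bool)
    _, indices = torch.topk(relation_vis, min(budget, v_token_num - 1), dim=1)
    mask[0][indices] = True
    return mask


# Example with the prune-only table from this patch: 640 -> [0.2222].
sparse_token_dict = {640: [0.2222], 320: [0.1111], 160: [0.0555]}
layer_dict = {2: 0}                 # hypothetical single pruning layer
relation_vis = torch.rand(1, 576)   # toy text->vision relevance scores
mask = select_visual_tokens(relation_vis, reduction_ratio=1 - 640 / 2880,
                            v_token_num=576, layer_idx=2,
                            layer_dict=layer_dict, sparse_token_dict=sparse_token_dict)
print(mask.sum().item())            # round(0.2222 * 576) = 128 tokens retained
```

With the prune-only table, a `reduction_ratio` of roughly 0.778 resolves to key 640 and keeps about 22% of the remaining visual tokens at the pruning layer.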