@@ -17,6 +17,9 @@
 sparse_token_list_192 = []
 sparse_token_list_128 = []
 sparse_token_list_64 = []
+sparse_token_list_640 = []
+sparse_token_list_320 = []
+sparse_token_list_160 = []
 sparse_token_dict = {}
 
 
@@ -55,7 +58,7 @@ def input_hook(module, args, pruning_paras):
                 pre_prompt_length_list.append(0)
             pruning_paras['pre_prompt_length_list'] = pre_prompt_length_list
 
-        def input_hook_llava(fn, pruning_paras):
+        def input_hook_llava(fn, pruning_paras, llava_next=False):
             @wraps(fn)
             def wrapper(self, *args, **kwargs):
                 if args[0].shape[1] == 1:
@@ -81,11 +84,14 @@ def wrapper(self, *args, **kwargs):
 
                 pruning_paras['pre_prompt_length_list'] = pre_prompt_length_list
 
-                return fn(*args, **kwargs)
+                outs = fn(*args, **kwargs)
+                if llava_next:
+                    pruning_paras['vision_token_length'] = outs[-1]
+                return outs
             return wrapper
 
         @prefill_wrapper_model
-        def register_module_pars(module, args, kwargs, pruning_paras):
+        def register_module_paras(module, args, kwargs, pruning_paras):
             pre_prompt_length_list = pruning_paras['pre_prompt_length_list']
             hidden_states = kwargs['inputs_embeds']
             if hidden_states is None:
@@ -227,7 +233,8 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_paras, laye
                 text_token_start,
                 t_token_idx,
                 layer_idx,
-                retained_tokens
+                retained_tokens,
+                pruning_paras['reduction_ratio']
             )
             if not prune_flag:
                 pred_score_vis = torch.zeros_like(relation_vis_text, dtype=bool)
@@ -353,7 +360,8 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):
             self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
                 input_hook_llava(
                     self.model.vlm_model.prepare_inputs_labels_for_multimodal,
-                    self.pruning_paras
+                    self.pruning_paras,
+                    llava_next=self.special_config['vision_token_length'] is None
                 ), self.model.vlm_model
             )
 
@@ -362,7 +370,7 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):
         elif self.model.__class__.__name__ == 'Llava':
             llama_model = self.model.model.model
             llama_model.register_forward_pre_hook(
-                functools.partial(register_module_pars, pruning_paras=self.pruning_paras),
+                functools.partial(register_module_paras, pruning_paras=self.pruning_paras),
                 with_kwargs=True
             )
 
@@ -417,6 +425,7 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):
 
 def update_list():
     global sparse_token_list_192, sparse_token_list_128, sparse_token_list_64
+    global sparse_token_list_640, sparse_token_list_320, sparse_token_list_160
     global prune_flag, merge_flag, sparse_token_dict
 
     if layer_dict == {2: 0, 6: 1, 15: 2}:  # 2*576 4*300 10*200 16*110
@@ -428,10 +437,16 @@ def update_list():
         sparse_token_list_192 = [180]
         sparse_token_list_128 = [114]
         sparse_token_list_64 = [48]
+        sparse_token_list_640 = [0.1979]
+        sparse_token_list_320 = [0.0833]
+        sparse_token_list_160 = [0.0261]
     elif prune_flag:
         sparse_token_list_192 = [192]
         sparse_token_list_128 = [128]
         sparse_token_list_64 = [64]
+        sparse_token_list_640 = [0.2222]
+        sparse_token_list_320 = [0.1111]
+        sparse_token_list_160 = [0.0555]
     elif merge_flag:
         sparse_token_list_192 = [149]
         sparse_token_list_128 = [78]
@@ -444,7 +459,10 @@ def update_list():
     sparse_token_dict = {
         192: sparse_token_list_192,
         128: sparse_token_list_128,
-        64: sparse_token_list_64
+        64: sparse_token_list_64,
+        640: sparse_token_list_640,
+        320: sparse_token_list_320,
+        160: sparse_token_list_160
     }
 
 
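For reference, the new 640/320/160 entries are keyed by the absolute token count implied by a reduction ratio over the 2880-token budget used later in the diff, and their fractional values are per-layer keep ratios rather than absolute counts. The snippet below is a minimal standalone sketch of that lookup; the helper name resolve_retained_tokens and the example numbers are illustrative, not the repository's API.

def resolve_retained_tokens(reduction_ratio, layer_slot, v_token_num,
                            sparse_token_dict, budget=2880):
    # Map the ratio onto the budget to recover the dictionary key,
    # e.g. reduction_ratio = 0.7778 -> round(0.2222 * 2880) = 640.
    key = round((1 - reduction_ratio) * budget)
    entry = sparse_token_dict[key][layer_slot]
    # Fractional entries are keep ratios and are scaled by the actual
    # number of vision tokens; integer entries pass through unchanged.
    return round(entry * v_token_num) if entry < 1 else entry

# Example: with {640: [0.2222]} and 576 vision tokens,
# resolve_retained_tokens(0.7778, 0, 576, {640: [0.2222]}) == 128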
@@ -455,7 +473,8 @@ def attn_postprocess_topk(
         text_token_start,
         t_token_idx,
         layer_idx,
-        retained_tokens):
+        retained_tokens,
+        reduction_ratio):
     '''
     self_attn_weights: [B, H, L, L]
     '''
@@ -470,13 +489,17 @@ def attn_postprocess_topk(
 
     relation_vis = relation_vis_text
     s_flag = True  # s_flag controls whether token merge is needed.
-
-    sparse_token_list = sparse_token_dict[retained_tokens]
-
+    if retained_tokens in [192, 128, 64]:
+        sparse_token_list = sparse_token_dict[retained_tokens]
+    else:
+        sparse_token_list = sparse_token_dict[round((1 - reduction_ratio) * 2880)]
+    retained_tokens_prune = sparse_token_list[layer_dict[layer_idx]]
+    if retained_tokens_prune < 1:
+        retained_tokens_prune = round(retained_tokens_prune * v_token_num)
     if v_token_num != 0:
         mask = torch.zeros_like(relation_vis, dtype=bool)
         _, indices = torch.topk(relation_vis, min(
-            sparse_token_list[layer_dict[layer_idx]], v_token_num - 1), dim=1)
+            retained_tokens_prune, v_token_num - 1), dim=1)
         mask[0][indices] = 1
     else:
         mask = torch.ones_like(relation_vis_text, dtype=bool)
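As context for the last hunk, the retained vision tokens are then selected with a plain top-k over the per-token relation scores and turned into a boolean mask. The following is a self-contained illustration with made-up shapes (one batch, 576 vision tokens); the tensor names mirror the diff, but nothing else is taken from the repository.

import torch

v_token_num = 576
retained_tokens_prune = 128
relation_vis = torch.rand(1, v_token_num)  # one relevance score per vision token

mask = torch.zeros_like(relation_vis, dtype=torch.bool)
_, indices = torch.topk(relation_vis, min(retained_tokens_prune, v_token_num - 1), dim=1)
mask[0][indices] = 1  # keep only the highest-scoring vision tokens
assert int(mask.sum()) == retained_tokens_prune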