import functools
+from functools import wraps
+from types import MethodType

import einops as ein
import torch
@@ -27,7 +29,7 @@ def add_sparse_config(self):
        special_config['token_length_list'] = []
        special_config['image_shape'] = self.model.pruning_config['image_token_length']
        special_config['image_token_index'] = self.model.pruning_config['image_token_index']
-        self.model.model.parameters = special_config
+        self.pruning_paras = special_config

    def register_reduction_modules(self):
        @prefill_wrapper
@@ -52,16 +54,48 @@ def input_hook(module, input_args, pruning_pars):

            return input_args

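+        # Wrapper used to monkey-patch LLaVA's prepare_inputs_labels_for_multimodal:
+        # before delegating to the original bound method `fn`, it records each
+        # sample's pre-image prompt length; afterwards it records the merged
+        # sequence lengths. Single-token (decode-step) calls pass through untouched.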
+        def input_hook_llava(fn, pruning_paras):
+            @wraps(fn)
+            def wrapper(self, *args, **kwargs):
+                if len(args) == 0:
+                    return fn(*args, **kwargs)
+                input_args = args[0]
+                if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
+                    return fn(*args, **kwargs)
+
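+                # Positional args are assumed to follow LLaVA's
+                # prepare_inputs_labels_for_multimodal(input_ids, position_ids,
+                # attention_mask, past_key_values, labels, images, ...) signature,
+                # hence args[0] and args[2] below.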
+                input_ids = args[0]
+                attention_mask = args[2]
+
+                pre_prompt_length_list = []
+                for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask):
+                    seq = cur_input_ids[cur_attention_mask]
+                    image_token_index = torch.where(seq == IMAGE_TOKEN_INDEX)[0].tolist()
+                    if len(image_token_index) > 0:
+                        pre_prompt_length_list.append(image_token_index[0])
+                    else:
+                        pre_prompt_length_list.append(0)
+                pruning_paras['pre_prompt_length_list'] = pre_prompt_length_list
+
+                outputs = fn(*args, **kwargs)
+
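+                # By LLaVA's return convention, outputs[2] is the expanded
+                # attention_mask produced after image features are spliced in.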
+                token_length_list = []
+                for cur_attention_mask in outputs[2]:
+                    token_length_list.append(cur_attention_mask.sum().item())
+                pruning_paras['token_length_list'] = token_length_list
+
+                return outputs
+            return wrapper
+
        @prefill_wrapper_model
        def register_module_pars(module, args, kwargs, pruning_pars):
            pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
            inputs_embeds = kwargs['inputs_embeds']
            if inputs_embeds is None:
-                inputs_embeds = self.embed_tokens(kwargs['input_ids'])
+                inputs_embeds = module.embed_tokens(kwargs['input_ids'])
            hidden_states = inputs_embeds  # shape: (B, L, C)

-            pruning_pars['B'], L, _ = hidden_states.shape
-            B = pruning_pars['B']
+            B, L, _ = hidden_states.shape
+            pruning_pars['B'] = B
            init_n = pruning_pars['init_token_total_shape'] + \
                pruning_pars['generate_process_count']  # 668
            pruning_pars['prev_decision'] = torch.ones(
@@ -80,7 +114,7 @@ def register_module_pars(module, args, kwargs, pruning_pars):
            if (len(pre_prompt_length_list) != 0 and hidden_states.shape[1] != 1):
                v_t = hidden_states[:, v_token_start:text_token_start, :]
                t_t = hidden_states[:, text_token_start:, :]
-                m_v_t = v_t @ t_t.transpose(1, 2)  # [1, 576, 53]
+                m_v_t = v_t @ t_t.transpose(1, 2)  # [1, 576, 53] # 52?
                m_v_t = m_v_t.softmax(2).mean(1)  # [1, 53]
                pruning_pars['t_token_idx'] = torch.where(m_v_t > m_v_t.mean())

@@ -206,17 +240,31 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):

            return args, kwargs

-        self.model.embed_tokens.register_forward_pre_hook(
-            functools.partial(
-                input_hook,
-                pruning_pars=self.model.model.parameters
+        if self.model.__class__.__name__ == 'LlavaHf':
+            self.model.embed_tokens.register_forward_pre_hook(
+                functools.partial(
+                    input_hook,
+                    pruning_pars=self.pruning_paras
+                )
+            )
+        elif self.model.__class__.__name__ == 'Llava':
+            from llava.constants import IMAGE_TOKEN_INDEX
+            hook_fn = input_hook_llava(
+                self.model.vlm_model.prepare_inputs_labels_for_multimodal,
+                self.pruning_paras
+            )
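+            # MethodType binds the wrapper to vlm_model so its `self` argument is
+            # the model instance; the original method is still reachable via the
+            # already-bound `fn` captured above.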
+            self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
+                hook_fn, self.model.vlm_model
            )
-        )

-        self.model.model.register_forward_pre_hook(
+        if self.model.__class__.__name__ == 'LlavaHf':
+            llama_model = self.model.model
+        elif self.model.__class__.__name__ == 'Llava':
+            llama_model = self.model.model.model
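+            # the original-repo Llava wrapper nests the HF LlamaModel one
+            # attribute level deeper than LlavaHf, hence the extra `.model`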
+        llama_model.register_forward_pre_hook(
            functools.partial(
                register_module_pars,
-                pruning_pars=self.model.model.parameters),
+                pruning_pars=self.pruning_paras),
            with_kwargs=True
        )

@@ -228,15 +276,15 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
            self.blocks[block_idx].register_forward_pre_hook(
                functools.partial(
                    update_output_attentions_hook,
-                    pruning_pars=self.model.model.parameters,
+                    pruning_pars=self.pruning_paras,
                    layer_idx=block_idx,
                ),
                with_kwargs=True
            )
            self.blocks[block_idx].register_forward_hook(
                functools.partial(
                    decoder_attn_hook,
-                    pruning_pars=self.model.model.parameters,
+                    pruning_pars=self.pruning_paras,
                    layer_idx=block_idx,
                ),
                with_kwargs=True
@@ -245,7 +293,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
            self.blocks[block_idx].register_forward_pre_hook(
                functools.partial(
                    read_parameter_hook,
-                    pruning_pars=self.model.model.parameters
+                    pruning_pars=self.pruning_paras
                ),
                with_kwargs=True
            )
@@ -278,6 +326,7 @@ def attn_postprocess_topk(
    self_attn_weights = self_attn_weights.mean(1)  # B, L[Q], L[K]

    t_token_idx = t_token_idx[1] + text_token_start
+
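+    # t_token_idx comes from torch.where over m_v_t, so [1] holds the text-token
+    # columns; the text_token_start offset above maps them to absolute positions.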
    relation_vis_text = self_attn_weights[:, t_token_idx,
                                          v_token_start:v_token_start + v_token_num]  # B, L2, L1