+import copy
 import functools
+import math
 from functools import wraps
 from types import MethodType
 
@@ -66,22 +68,26 @@ def wrapper(self, *args, **kwargs):
         input_ids = args[0]
         attention_mask = args[2]
 
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+        else:
+            attention_mask = attention_mask.bool()
+
         pre_prompt_length_list = []
         for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask):
             seq = cur_input_ids[cur_attention_mask]
-            image_token_index = torch.where(seq == IMAGE_TOKEN_INDEX)[0].tolist()
-            if len(image_token_index) > 0:
-                pre_prompt_length_list.append(image_token_index[0])
-            else:
-                pre_prompt_length_list.append(0)
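+            # Sentinels at both ends make the lookup safe: entry [1] is the first image-token
+            # position, or the sequence length when the prompt contains no image token.
+            # e.g. seq = [t0, t1, IMG, t2] -> image_token_index = [-1, 2, 4] -> pre-prompt length = 2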
+            image_token_index = (
+                [-1]
+                + torch.where(seq == IMAGE_TOKEN_INDEX)[0].tolist()
+                + [seq.shape[0]]
+            )
+            pre_prompt_length_list.append(image_token_index[1])
+
         pruning_paras['pre_prompt_length_list'] = pre_prompt_length_list
 
         outputs = fn(*args, **kwargs)
 
-        token_length_list = []
-        for cur_attention_mask in outputs[2]:
-            token_length_list.append(cur_attention_mask.sum().item())
-        pruning_paras['token_length_list'] = token_length_list
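+        # Per-sample count of valid (unpadded) tokens, taken from the attention mask returned by fn.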
+        pruning_paras['token_length_list'] = outputs[2].sum(dim=1).tolist()
 
         return outputs
     return wrapper
@@ -128,14 +134,90 @@ def update_output_attentions_hook(module, args, kwargs, pruning_pars, layer_idx)
     kwargs['position_embeddings'] = pruning_pars['position_embeddings']
     return args, kwargs
 
+def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):
+
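+    # Recompute this layer's attention probabilities from the q/k projections and stash them in
+    # pruning_pars['attn_logits'] for decoder_attn_hook (used on the Llava path).
+    # Single-token (decode) steps have nothing to prune, so they pass through untouched.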
+    if len(kwargs['position_ids'][0]) == 1:
+        return args, kwargs
+
+    from transformers.models.llama.modeling_llama import \
+        apply_rotary_pos_emb
+
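+    # For every pruning location after the first, reuse the position ids / cache positions /
+    # rotary embeddings that the previous pruning step stored in pruning_pars.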
+    if layer_idx != self.pruning_loc[0]:
+        kwargs['position_ids'] = pruning_pars['position_ids']
+        kwargs['cache_position'] = pruning_pars['cache_position']
+        kwargs['position_embeddings'] = pruning_pars['position_embeddings']
+
+    hidden_states = kwargs['hidden_states']
+    position_embeddings = kwargs['position_embeddings']
+    position_ids = kwargs['position_ids']
+    past_key_value = kwargs['past_key_value']
+    cache_position = kwargs['cache_position']
+    attention_mask = kwargs['attention_mask']
+
+    t_token_idx = pruning_pars['t_token_idx']
+    v_token_start = pruning_pars['v_token_start']
+    v_token_num = pruning_pars['v_token_num']
+
+    bsz, q_len, _ = hidden_states.size()
+    query_states = module.q_proj(hidden_states)
+    key_states = module.k_proj(hidden_states)
+    value_states = module.v_proj(hidden_states)
+    query_states = query_states.view(
+        bsz, q_len, module.num_heads, module.head_dim
+    ).transpose(1, 2)
+    key_states = key_states.view(
+        bsz, q_len, module.num_key_value_heads, module.head_dim
+    ).transpose(1, 2)
+    value_states = value_states.view(
+        bsz, q_len, module.num_key_value_heads, module.head_dim
+    ).transpose(1, 2)
+
+    if position_embeddings is None:
+        cos, sin = module.rotary_emb(value_states, position_ids)
+    else:
+        cos, sin = position_embeddings
+
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
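+    # Advance a deep copy of the KV cache so this probe does not mutate the cache that the real
+    # attention forward will update right after this pre-hook.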
+    if past_key_value is not None:
+        temp_cache = copy.deepcopy(past_key_value)
+        cache_kwargs = {'sin': sin, 'cos': cos, 'cache_position': cache_position}
+        key_states, value_states = temp_cache.update(
+            key_states, value_states,
+            layer_idx, cache_kwargs
+        )
+    t_token_idx = t_token_idx[1] + v_token_start + v_token_num
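+    # Eager attention: scaled QK^T plus a causal additive bias, softmaxed into probabilities.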
+    L, S = query_states.size(-2), key_states.size(-2)
+    scale_factor = 1 / math.sqrt(query_states.size(-1))
+    attn_bias = torch.zeros(L, S, dtype=query_states.dtype)
+    if module.is_causal:
+        assert attention_mask is None
+        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+        attn_bias.masked_fill_(temp_mask.logical_not(), float('-inf'))
+        attn_bias = attn_bias.to(query_states.dtype)
+
+    attn_logits = query_states @ key_states.transpose(2, 3) * scale_factor
+    attn_logits += attn_bias.to(query_states.device)
+    attn_logits = torch.softmax(attn_logits, dim=-1)
+
+    pruning_pars['attn_logits'] = attn_logits
+
+    return args, kwargs
+
 @prefill_wrapper
 def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):
 
-    attn_logits = layer_outputs[1]
+    # pruning_pars['attn_logits'] is buggy when running LlavaHf;
+    # layer_outputs[1] runs fine for LlavaHf, but its accuracy does not match.
+    # Llava:   attn_logits = pruning_pars['attn_logits']
+    # LlavaHf: attn_logits = layer_outputs[1]
+    if 'attn_logits' not in pruning_pars:
+        attn_logits = layer_outputs[1]
+    else:
+        attn_logits = pruning_pars['attn_logits']
     v_token_start = pruning_pars['v_token_start']
+    v_token_num = pruning_pars['v_token_num']
     text_token_start = pruning_pars['text_token_start']
     t_token_idx = pruning_pars['t_token_idx']
-    v_token_num = pruning_pars['v_token_num']
     retained_tokens = pruning_pars['retained_tokens']
     B = pruning_pars['B']
     pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
@@ -145,10 +227,7 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer
         pruning_pars['position_ids'] = position_ids
     else:
         position_ids = pruning_pars['position_ids']
-
     hidden_states = inputs[0]  # [B, L, D]
-    pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
-    image_shape = pruning_pars['image_shape']
 
     pred_score_vis, s_flag, relation_vis_text = attn_postprocess_topk(
         attn_logits,
@@ -177,7 +256,6 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer
 
     # merge and cluster
     if s_flag and total_sparse_token_idx.shape[1] > 0:
-        total_sparse_token_idx = torch.where(policy == 0)[1].unsqueeze(0)
         total_sparse_token = batch_index_select(layer_outputs[0], total_sparse_token_idx)
 
         merge_token_idx_stage1 = torch.where(pred_score_vis == 0)[1]
@@ -208,20 +286,17 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer
         )
         layer_outputs = (select_and_merge_token, layer_outputs[1])
         position_ids = position_ids[:, :len(select_token_idx[0]) + cluster_num]
-        # prev_decision = policy
         v_token_num = pred_score_vis.sum() + cluster_num
         text_token_start = v_token_start + v_token_num
     else:
         select_token_idx = torch.where(policy == 1)[1].unsqueeze(0)
         layer_outputs = (batch_index_select(layer_outputs[0], select_token_idx),
                          layer_outputs[1])
         position_ids = position_ids[:, :len(select_token_idx[0])]
-        # prev_decision = policy
         v_token_num = pred_score_vis.sum()
         text_token_start = v_token_start + v_token_num
 
     new_output = layer_outputs
-    # hidden_states = layer_outputs[0]
     cache_position = position_ids.detach().clone()
 
     pruning_pars['v_token_num'] = v_token_num
@@ -273,14 +348,24 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
 
 for block_idx in range(sorted_pruning_locs[0], total_layers):
     if block_idx in sorted_pruning_locs:
-        self.blocks[block_idx].register_forward_pre_hook(
-            functools.partial(
-                update_output_attentions_hook,
-                pruning_pars=self.pruning_paras,
-                layer_idx=block_idx,
-            ),
-            with_kwargs=True
-        )
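+        # LlavaHf: the decoder layer exposes attention weights via layer_outputs[1], so the
+        # block-level pre-hook (update_output_attentions_hook) is enough; Llava: attention
+        # probabilities are recomputed by get_attn_logits_hook registered on self_attn
+        # (see the note in decoder_attn_hook).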
+        if self.model.__class__.__name__ == 'LlavaHf':
+            self.blocks[block_idx].register_forward_pre_hook(
+                functools.partial(
+                    update_output_attentions_hook,
+                    pruning_pars=self.pruning_paras,
+                    layer_idx=block_idx,
+                ),
+                with_kwargs=True
+            )
+        elif self.model.__class__.__name__ == 'Llava':
+            self.blocks[block_idx].self_attn.register_forward_pre_hook(
+                functools.partial(
+                    get_attn_logits_hook,
+                    pruning_pars=self.pruning_paras,
+                    layer_idx=block_idx,
+                ),
+                with_kwargs=True
+            )
         self.blocks[block_idx].register_forward_hook(
             functools.partial(
                 decoder_attn_hook,