 from .token_reduction_module import TokenReductionModule
 from .utils import prefill_wrapper, prefill_wrapper_model

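+# module-level lookup of pruning-layer order, populated in add_sparse_config()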
+layer_dict = {}
+

 @TOKEN_REDUCTION_REGISTRY.register('SparseVLM')
 class SparseVLM(TokenReductionModule):
@@ -24,6 +26,8 @@ def add_sparse_config(self):
         special_config = self.config.get('special', {})

         self.pruning_loc = special_config.get('pruning_loc', [2, 6, 15])
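+        # map each pruning layer index to its position in self.pruning_loc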
+        global layer_dict
+        layer_dict = {layer: idx for idx, layer in enumerate(self.pruning_loc)}
         special_config['retained_tokens'] = special_config.get('retained_tokens', 192)
         special_config['init_token_total_shape'] = special_config.get('init_token_total_shape', 668)
         special_config['generate_process_count'] = 0
@@ -44,7 +48,8 @@ def input_hook(module, input_args, pruning_pars):
             # find the position of the first image token
             for seq in input_ids:
                 image_token_index = (
-                    seq == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0]
+                    seq == IMAGE_TOKEN_INDEX
+                ).nonzero(as_tuple=True)[0]
                 if len(image_token_index) > 0:
                     pre_prompt_length_list.append(image_token_index[0].item())
                 else:
@@ -95,33 +100,31 @@ def wrapper(self, *args, **kwargs):
         @prefill_wrapper_model
         def register_module_pars(module, args, kwargs, pruning_pars):
             pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
-            inputs_embeds = kwargs['inputs_embeds']
-            if inputs_embeds is None:
-                inputs_embeds = module.embed_tokens(kwargs['input_ids'])
-            hidden_states = inputs_embeds  # shape: (B, L, C)
+            hidden_states = kwargs['inputs_embeds']
+            if hidden_states is None:
+                hidden_states = module.embed_tokens(kwargs['input_ids'])

             B, L, _ = hidden_states.shape
             pruning_pars['B'] = B
             init_n = pruning_pars['init_token_total_shape'] + \
-                pruning_pars['generate_process_count'] # 668
+                pruning_pars['generate_process_count']  # 668
             pruning_pars['prev_decision'] = torch.ones(
                 B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)
             pruning_pars['policy'] = torch.ones(
                 B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)

-            pruning_pars['v_token_start'] = pre_prompt_length_list[0] if len(
-                pre_prompt_length_list) != 0 else 0  # 35
-            v_token_start = pruning_pars['v_token_start']
-            pruning_pars['text_token_start'] = pruning_pars['v_token_start'] + \
-                pruning_pars['image_shape']  # 35 + 576 = 611
-            text_token_start = pruning_pars['text_token_start']
+            v_token_start = pre_prompt_length_list[0] if len(
+                pre_prompt_length_list) != 0 else 0
+            text_token_start = v_token_start + pruning_pars['image_shape']
+            pruning_pars['v_token_start'] = v_token_start  # 35
+            pruning_pars['text_token_start'] = text_token_start  # 611
             pruning_pars['v_token_num'] = pruning_pars['image_shape']  # 576

             if (len(pre_prompt_length_list) != 0 and hidden_states.shape[1] != 1):
                 v_t = hidden_states[:, v_token_start: text_token_start, :]
                 t_t = hidden_states[:, text_token_start:, :]
-                m_v_t = v_t @ t_t.transpose(1, 2)  # [1, 576, 53] # 52?
-                m_v_t = m_v_t.softmax(2).mean(1)  # [1, 53]
+                m_v_t = v_t @ t_t.transpose(1, 2)  # [1, 576, 52]
+                m_v_t = m_v_t.softmax(2).mean(1)  # [1, 52]
                 pruning_pars['t_token_idx'] = torch.where(m_v_t > m_v_t.mean())

             return args, kwargs
@@ -134,10 +137,20 @@ def update_output_attentions_hook(module, args, kwargs, pruning_pars, layer_idx)
             kwargs['position_embeddings'] = pruning_pars['position_embeddings']
             return args, kwargs

-        def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):
+        def update_kwargs_hook(module, args, kwargs, pruning_pars, layer_idx):

             if len(kwargs['position_ids'][0]) == 1:
                 return args, kwargs
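+            # for every pruning layer except the first, overwrite kwargs with the
+            # position info stored in pruning_pars after the previous pruning step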
+            if layer_idx != self.pruning_loc[0]:
+                kwargs['position_ids'] = pruning_pars['position_ids']
+                kwargs['cache_position'] = pruning_pars['cache_position']
+                kwargs['position_embeddings'] = pruning_pars['position_embeddings']
+            return args, kwargs
+
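+        # registered as a forward hook (see the registration below), so the layer's
+        # KV cache entry already holds this step's keys/values when it runs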
+        def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_idx):
+
+            if len(kwargs['position_ids'][0]) == 1:
+                return layer_outs

             from transformers.models.llama.modeling_llama import \
                 apply_rotary_pos_emb
@@ -150,8 +163,7 @@ def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):
             hidden_states = kwargs['hidden_states']
             position_embeddings = kwargs['position_embeddings']
             position_ids = kwargs['position_ids']
-            past_key_value = kwargs['past_key_value']
-            cache_position = kwargs['cache_position']
+            past_key_value = layer_outs[2]
             attention_mask = kwargs['attention_mask']

             t_token_idx = pruning_pars['t_token_idx']
@@ -179,12 +191,8 @@ def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):

             query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
             if past_key_value is not None:
-                temp_cache = copy.deepcopy(past_key_value)
-                cache_kwargs = {'sin': sin, 'cos': cos, 'cache_position': cache_position}
-                key_states, value_states = temp_cache.update(
-                    key_states, value_states,
-                    layer_idx, cache_kwargs
-                )
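+                # reuse the keys/values this layer's attention already wrote to the
+                # cache instead of deep-copying the cache and updating it again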
+                key_states = past_key_value.key_cache[layer_idx]
+                value_states = past_key_value.value_cache[layer_idx]
             t_token_idx = t_token_idx[1] + v_token_start + v_token_num
             L, S = query_states.size(-2), key_states.size(-2)
             scale_factor = 1 / math.sqrt(query_states.size(-1))
@@ -201,19 +209,16 @@ def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):

             pruning_pars['attn_logits'] = attn_logits

-            return args, kwargs
+            return layer_outs

         @prefill_wrapper
         def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):

-            # pruning_pars['attn_logits'] has a bug when running with llavaHf;
-            # running llavaHf with layer_outputs[1] works, but the accuracy does not match.
-            # llava: attn_logits = pruning_pars['attn_logits']
-            # llavahf: attn_logits = layer_outputs[1]
             if 'attn_logits' not in pruning_pars:
                 attn_logits = layer_outputs[1]
             else:
                 attn_logits = pruning_pars['attn_logits']
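+            # merge_flag gates the merge-and-cluster branch below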
+            merge_flag = pruning_pars['merge_flag']
             v_token_start = pruning_pars['v_token_start']
             v_token_num = pruning_pars['v_token_num']
             text_token_start = pruning_pars['text_token_start']
@@ -255,7 +260,7 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer
             total_sparse_token_idx = torch.where(policy == 0)[1].unsqueeze(0)

             # merge and cluster
-            if s_flag and total_sparse_token_idx.shape[1] > 0:
+            if s_flag and merge_flag and total_sparse_token_idx.shape[1] > 0:
                 total_sparse_token = batch_index_select(layer_outputs[0], total_sparse_token_idx)

                 merge_token_idx_stage1 = torch.where(pred_score_vis == 0)[1]
@@ -359,6 +364,14 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                 )
             elif self.model.__class__.__name__ == 'Llava':
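+                # Llava: a pre-hook keeps kwargs in sync across pruning layers, then a
+                # forward hook computes the attention logits consumed by decoder_attn_hook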
                 self.blocks[block_idx].self_attn.register_forward_pre_hook(
+                    functools.partial(
+                        update_kwargs_hook,
+                        pruning_pars=self.pruning_paras,
+                        layer_idx=block_idx,
+                    ),
+                    with_kwargs=True
+                )
+                self.blocks[block_idx].self_attn.register_forward_hook(
                     functools.partial(
                         get_attn_logits_hook,
                         pruning_pars=self.pruning_paras,