@@ -13,6 +13,12 @@
 from .utils import prefill_wrapper, prefill_wrapper_model
 
 layer_dict = {}
+prune_flag = True
+merge_flag = True
+sparse_token_list_192 = []
+sparse_token_list_128 = []
+sparse_token_list_64 = []
+sparse_token_dict = {}
 
 
 @TOKEN_REDUCTION_REGISTRY.register('SparseVLM')
@@ -26,13 +32,13 @@ def add_sparse_config(self):
         special_config = self.config.get('special', {})
 
         self.pruning_loc = special_config.get('pruning_loc', [2, 6, 15])
-        global layer_dict
+        global layer_dict, prune_flag, merge_flag
         layer_dict = {layer: idx for idx, layer in enumerate(self.pruning_loc)}
+        prune_flag = special_config.get('prune_flag', True)
+        merge_flag = special_config.get('merge_flag', True)
+        update_list()
         special_config['retained_tokens'] = special_config.get('retained_tokens', 192)
-        special_config['init_token_total_shape'] = special_config.get('init_token_total_shape', 668)
-        special_config['generate_process_count'] = 0
         special_config['pre_prompt_length_list'] = []
-        special_config['token_length_list'] = []
         special_config['image_shape'] = self.model.pruning_config['image_token_length']
         special_config['image_token_index'] = self.model.pruning_config['image_token_index']
         self.pruning_paras = special_config
@@ -42,7 +48,6 @@ def register_reduction_modules(self):
         def input_hook(module, input_args, pruning_pars):
             input_ids = input_args[0]
             pre_prompt_length_list = []
-            token_length_list = []
             IMAGE_TOKEN_INDEX = pruning_pars['image_token_index']
 
             # find the position of the first image token
@@ -54,10 +59,7 @@ def input_hook(module, input_args, pruning_pars):
                     pre_prompt_length_list.append(image_token_index[0].item())
                 else:
                     pre_prompt_length_list.append(0)
-                token_length_list.append(seq.shape[0])
-
             pruning_pars['pre_prompt_length_list'] = pre_prompt_length_list
 
             return input_args
 
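Note: `input_hook` now records only where each prompt's image tokens begin; the per-sequence `token_length_list` bookkeeping is dropped everywhere. A minimal sketch of the scan it performs (self-contained; `IMAGE_TOKEN_INDEX = -200` is LLaVA's usual placeholder id, assumed here for illustration):

    import torch

    IMAGE_TOKEN_INDEX = -200  # assumed LLaVA-style image placeholder id

    def first_image_token_positions(input_ids):
        # per sequence: prefix length before the first image token, else 0
        positions = []
        for seq in input_ids:
            hits = (seq == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0]
            positions.append(hits[0].item() if len(hits) > 0 else 0)
        return positions

    # first_image_token_positions(torch.tensor([[1, 2, -200, -200, 5]])) -> [2]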
@@ -90,11 +92,7 @@ def wrapper(self, *args, **kwargs):
 
                 pruning_paras['pre_prompt_length_list'] = pre_prompt_length_list
 
-                outputs = fn(*args, **kwargs)
-
-                pruning_paras['token_length_list'] = outputs[2].sum(dim=1).tolist()
-
-                return outputs
+                return fn(*args, **kwargs)
             return wrapper
 
         @prefill_wrapper_model
@@ -106,12 +104,6 @@ def register_module_pars(module, args, kwargs, pruning_pars):
 
             B, L, _ = hidden_states.shape
             pruning_pars['B'] = B
-            init_n = pruning_pars['init_token_total_shape'] + \
-                pruning_pars['generate_process_count']  # 668
-            pruning_pars['prev_decision'] = torch.ones(
-                B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)
-            pruning_pars['policy'] = torch.ones(
-                B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)
 
             v_token_start = pre_prompt_length_list[0] if len(
                 pre_prompt_length_list) != 0 else 0
@@ -123,8 +115,8 @@ def register_module_pars(module, args, kwargs, pruning_pars):
             if (len(pre_prompt_length_list) != 0 and hidden_states.shape[1] != 1):
                 v_t = hidden_states[:, v_token_start: text_token_start, :]
                 t_t = hidden_states[:, text_token_start:, :]
-                m_v_t = v_t @ t_t.transpose(1, 2)  # [1, 576, 52]
-                m_v_t = m_v_t.softmax(2).mean(1)  # [1, 52]
+                m_v_t = v_t @ t_t.transpose(1, 2)
+                m_v_t = m_v_t.softmax(2).mean(1)
                 pruning_pars['t_token_idx'] = torch.where(m_v_t > m_v_t.mean())
 
             return args, kwargs
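The `m_v_t` block ranks text tokens by how much attention mass the vision tokens place on them: the matmul yields a [B, n_vision, n_text] similarity map, `softmax(2)` normalizes across text tokens, `mean(1)` averages across vision tokens, and `t_token_idx` keeps the text tokens scoring above the mean. A self-contained sketch with toy shapes (576 vision and 52 text tokens, matching the shape comments this commit removes):

    import torch

    B, V, T, D = 1, 576, 52, 64            # toy sizes; D is the hidden width
    hidden_states = torch.randn(B, V + T, D)
    v_t = hidden_states[:, :V, :]          # vision tokens
    t_t = hidden_states[:, V:, :]          # text tokens

    m_v_t = v_t @ t_t.transpose(1, 2)      # [B, V, T] vision-to-text similarity
    m_v_t = m_v_t.softmax(2).mean(1)       # [B, T] score per text token
    t_token_idx = torch.where(m_v_t > m_v_t.mean())  # salient text tokens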
@@ -133,6 +125,7 @@ def update_output_attentions_hook(module, args, kwargs, pruning_pars, layer_idx):
             kwargs['output_attentions'] = True
             if layer_idx != self.pruning_loc[0]:
                 kwargs['position_ids'] = pruning_pars['position_ids']
+                kwargs['attention_mask'] = pruning_pars['attention_mask']
                 kwargs['cache_position'] = pruning_pars['cache_position']
                 kwargs['position_embeddings'] = pruning_pars['position_embeddings']
             return args, kwargs
@@ -143,8 +136,14 @@ def update_kwargs_hook(module, args, kwargs, pruning_pars, layer_idx):
                 return args, kwargs
             if layer_idx != self.pruning_loc[0]:
                 kwargs['position_ids'] = pruning_pars['position_ids']
+                kwargs['attention_mask'] = pruning_pars['attention_mask']
                 kwargs['cache_position'] = pruning_pars['cache_position']
                 kwargs['position_embeddings'] = pruning_pars['position_embeddings']
+            else:
+                pruning_pars['position_ids'] = kwargs['position_ids']
+                pruning_pars['attention_mask'] = kwargs['attention_mask']
+                pruning_pars['cache_position'] = kwargs['cache_position']
+                pruning_pars['position_embeddings'] = kwargs['position_embeddings']
             return args, kwargs
 
         def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_idx):
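The new `else` branch makes the first pruning layer the writer and every later pruning layer a reader: position ids, attention mask, cache position, and rotary embeddings are captured into `pruning_pars` once, and after `decoder_attn_hook` rewrites them to the pruned geometry, later layers replay the updated values. The contract, as a standalone sketch (names are illustrative):

    KEYS = ('position_ids', 'attention_mask', 'cache_position', 'position_embeddings')

    def sync_kwargs(kwargs, pruning_pars, is_first_pruning_layer):
        # writer at the first pruning layer, reader everywhere after it
        if is_first_pruning_layer:
            for key in KEYS:
                pruning_pars[key] = kwargs[key]
        else:
            for key in KEYS:
                kwargs[key] = pruning_pars[key]
        return kwargs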
@@ -155,11 +154,6 @@ def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_idx):
             from transformers.models.llama.modeling_llama import \
                 apply_rotary_pos_emb
 
-            if layer_idx != self.pruning_loc[0]:
-                kwargs['position_ids'] = pruning_pars['position_ids']
-                kwargs['cache_position'] = pruning_pars['cache_position']
-                kwargs['position_embeddings'] = pruning_pars['position_embeddings']
-
             hidden_states = kwargs['hidden_states']
             position_embeddings = kwargs['position_embeddings']
             position_ids = kwargs['position_ids']
@@ -215,9 +209,10 @@ def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_idx):
         def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):
 
             if 'attn_logits' not in pruning_pars:
-                attn_logits = layer_outputs[1]
+                attn_logits = layer_outputs[1]  # for LlavaHf
             else:
                 attn_logits = pruning_pars['attn_logits']
+            prune_flag = pruning_pars.get('prune_flag', True)
             merge_flag = pruning_pars['merge_flag']
             v_token_start = pruning_pars['v_token_start']
             v_token_num = pruning_pars['v_token_num']
@@ -227,13 +222,11 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):
             B = pruning_pars['B']
             pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
             image_shape = pruning_pars['image_shape']
-            if layer_idx == self.pruning_loc[0]:
-                position_ids = kwargs['position_ids']
-                pruning_pars['position_ids'] = position_ids
-            else:
-                position_ids = pruning_pars['position_ids']
-            hidden_states = inputs[0]  # [B, L, D]
 
+            attention_mask = kwargs['attention_mask']
+            position_embeddings = kwargs['position_embeddings']
+
+            hidden_states = inputs[0]  # [B, L, D]
             pred_score_vis, s_flag, relation_vis_text = attn_postprocess_topk(
                 attn_logits,
                 v_token_start,
@@ -243,7 +236,8 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):
                 layer_idx,
                 retained_tokens
             )
-
+            if not prune_flag:
+                pred_score_vis = torch.zeros_like(relation_vis_text, dtype=bool)
             policy = torch.ones(B, hidden_states.shape[1], dtype=hidden_states.dtype,
                                 device=hidden_states.device)
             policy[:, v_token_start:text_token_start] = \
@@ -261,60 +255,91 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):
 
             # merge and cluster
             if s_flag and merge_flag and total_sparse_token_idx.shape[1] > 0:
-                total_sparse_token = batch_index_select(layer_outputs[0], total_sparse_token_idx)
+                total_sparse_token = batch_index_select(
+                    layer_outputs[0], total_sparse_token_idx
+                )
 
                 merge_token_idx_stage1 = torch.where(pred_score_vis == 0)[1]
                 merge_token_stage1 = relation_vis_text[0][merge_token_idx_stage1]
-                merge_token_num_stage1 = int(merge_token_idx_stage1.shape[0] * 0.3) + 1  # Top 30%
+                if prune_flag:
+                    merge_token_num_stage1 = int(merge_token_idx_stage1.shape[0] * 0.3) + 1
+                else:
+                    merge_token_num_stage1 = (
+                        merge_token_idx_stage1.shape[0]
+                        - sparse_token_dict[retained_tokens][layer_dict[layer_idx]]
+                    )
                 merge_token_stage2_idx = merge_token_stage1.topk(merge_token_num_stage1)[1]
+                if not prune_flag:
+                    all_idx = torch.arange(
+                        merge_token_stage1.size(0),
+                        device=merge_token_stage1.device
+                    )
+                    non_topk_idx = all_idx[~torch.isin(all_idx, merge_token_stage2_idx)]
+                    pred_score_vis[0][non_topk_idx] = 1
+                    policy[:, v_token_start:text_token_start] = \
+                        pred_score_vis.type(dtype=hidden_states.dtype)
 
                 merge_token_stage2 = total_sparse_token[:, merge_token_stage2_idx, :]
                 cluster_num = int(merge_token_stage2.shape[1] / 10) + 1
                 if cluster_num == 0:
                     cluster_num = merge_token_stage2.shape[1]
+                merge_sparse_token, index_down = cluster_and_merge(merge_token_stage2, cluster_num)
 
-                merge_sparse_token = cluster_and_merge(merge_token_stage2, cluster_num)
-
+                cluster_idx = total_sparse_token_idx.squeeze(0)[merge_token_stage2_idx[index_down]]
+                cluster_idx = cluster_idx.squeeze(0)
                 select_token_idx = torch.where(policy == 1)[1].unsqueeze(0)
                 select_token = batch_index_select(layer_outputs[0], select_token_idx)
                 select_vis_token_num = pred_score_vis.sum()
-
+                keep_indexs = torch.cat(
+                    (
+                        select_token_idx.squeeze(0)[:v_token_start + select_vis_token_num],
+                        cluster_idx,
+                        select_token_idx.squeeze(0)[v_token_start + select_vis_token_num:]
+                    )
+                )
                 select_and_merge_token = torch.cat(
                     (
-                        select_token[:, :v_token_start +
-                                     select_vis_token_num, :],
+                        select_token[:, :v_token_start + select_vis_token_num, :],
                         merge_sparse_token,
-                        select_token[:, v_token_start +
-                                     select_vis_token_num:, :]
+                        select_token[:, v_token_start + select_vis_token_num:, :]
                     ),
                     dim=1
                 )
                 layer_outputs = (select_and_merge_token, layer_outputs[1])
-                position_ids = position_ids[:, :len(select_token_idx[0]) + cluster_num]
                 v_token_num = pred_score_vis.sum() + cluster_num
-                text_token_start = v_token_start + v_token_num
+
             else:
-                select_token_idx = torch.where(policy == 1)[1].unsqueeze(0)
+                keep_indexs = torch.where(policy == 1)[1]
+                select_token_idx = keep_indexs.unsqueeze(0)
                 layer_outputs = (batch_index_select(layer_outputs[0], select_token_idx),
                                  layer_outputs[1])
-                position_ids = position_ids[:, :len(select_token_idx[0])]
                 v_token_num = pred_score_vis.sum()
-                text_token_start = v_token_start + v_token_num
 
+            text_token_start = v_token_start + v_token_num
+            position_ids = keep_indexs.unsqueeze(0)
             new_output = layer_outputs
-            cache_position = position_ids.detach().clone()
+            cache_position = position_ids.squeeze(0)
+
+            if attention_mask is not None:
+                attention_mask = attention_mask[:, :, keep_indexs, :][:, :, :, keep_indexs]
+            new_pe0 = position_embeddings[0][:, keep_indexs, :].clone()
+            new_pe1 = position_embeddings[1][:, keep_indexs, :].clone()
+            position_embeddings = (new_pe0, new_pe1)
 
             pruning_pars['v_token_num'] = v_token_num
             pruning_pars['text_token_start'] = text_token_start
+
             pruning_pars['position_ids'] = position_ids
             pruning_pars['cache_position'] = cache_position
-            pruning_pars['position_embeddings'] = None
+            pruning_pars['position_embeddings'] = position_embeddings
+            pruning_pars['attention_mask'] = attention_mask
 
             return new_output
 
         @prefill_wrapper
         def read_parameter_hook(module, args, kwargs, pruning_pars):
             kwargs['position_ids'] = pruning_pars['position_ids']
+            kwargs['attention_mask'] = pruning_pars['attention_mask']
             kwargs['cache_position'] = pruning_pars['cache_position']
             kwargs['position_embeddings'] = pruning_pars['position_embeddings']
 
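`batch_index_select` is defined elsewhere in this file; the hooks rely on it gathering rows along the sequence dimension, one index set per batch element. A minimal equivalent under that assumption:

    import torch

    def batch_index_select(x, idx):
        # x: [B, L, D], idx: [B, K] -> [B, K, D], rows gathered per batch element
        B, L, D = x.shape
        return torch.gather(x, 1, idx.unsqueeze(-1).expand(B, idx.shape[1], D))

`keep_indexs` orders the surviving positions as pre-prompt plus retained vision tokens, then cluster centers, then text tokens, so the attention mask and both rotary-embedding tensors can be sliced with the same index. Note the mask is gathered rows-then-columns (`[:, :, keep_indexs, :][:, :, :, keep_indexs]`); a single `[:, :, keep_indexs, keep_indexs]` would broadcast the two index tensors and keep only the diagonal.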
@@ -363,7 +388,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                 with_kwargs=True
             )
         elif self.model.__class__.__name__ == 'Llava':
-            self.blocks[block_idx].self_attn.register_forward_pre_hook(
+            self.blocks[block_idx].register_forward_pre_hook(
                 functools.partial(
                     update_kwargs_hook,
                     pruning_pars=self.pruning_paras,
@@ -383,7 +408,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                 functools.partial(
                     decoder_attn_hook,
                     pruning_pars=self.pruning_paras,
-                    layer_idx=block_idx,
+                    layer_idx=block_idx
                 ),
                 with_kwargs=True
             )
@@ -397,17 +422,37 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
             )
 
 
-layer_dict = {2: 0, 6: 1, 15: 2}
-
-sparse_token_list_192 = [300, 200, 110]  # 2*576  4*300  10*200  16*110
-sparse_token_list_128 = [303, 110, 36]
-sparse_token_list_64 = [66, 30, 17]
+def update_list():
+    global sparse_token_list_192, sparse_token_list_128, sparse_token_list_64
+    global prune_flag, merge_flag, sparse_token_dict
+
+    if layer_dict == {2: 0, 6: 1, 15: 2}:  # 2*576  4*300  10*200  16*110
+        sparse_token_list_192 = [300, 200, 110]
+        sparse_token_list_128 = [303, 110, 36]
+        sparse_token_list_64 = [66, 30, 17]
+        prune_flag, merge_flag = True, True
+    elif prune_flag and merge_flag:
+        sparse_token_list_192 = [180]
+        sparse_token_list_128 = [114]
+        sparse_token_list_64 = [48]
+    elif prune_flag:
+        sparse_token_list_192 = [192]
+        sparse_token_list_128 = [128]
+        sparse_token_list_64 = [64]
+    elif merge_flag:
+        sparse_token_list_192 = [149]
+        sparse_token_list_128 = [78]
+        sparse_token_list_64 = [7]
+    else:
+        raise RuntimeError(
+            'Both prune_flag and merge_flag are False; SparseVLM is inactive.'
+        )
 
-sparse_token_dict = {
-    192: sparse_token_list_192,
-    128: sparse_token_list_128,
-    64: sparse_token_list_64
-}
+    sparse_token_dict = {
+        192: sparse_token_list_192,
+        128: sparse_token_list_128,
+        64: sparse_token_list_64
+    }
 
 
 def attn_postprocess_topk(
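`update_list()` resolves the per-layer retention budgets once the flags and `pruning_loc` are known: the original three-layer schedule ({2: 0, 6: 1, 15: 2}) keeps the published budgets, while any other schedule falls back to a single-entry list chosen by which of the two flags is set. A hypothetical `special` config for a merge-only run (key names follow this file; the values are examples only):

    special = {
        'pruning_loc': [8],        # layer_dict becomes {8: 0}
        'prune_flag': False,       # no hard pruning ...
        'merge_flag': True,        # ... merging only
        'retained_tokens': 64,     # update_list() then sets sparse_token_list_64 = [7]
    }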
@@ -567,4 +612,4 @@ def cluster_and_merge(x, cluster_num):
                         source=source.reshape(B * N, C).type(x.dtype))
     x_merged = x_merged.reshape(B, cluster_num, C)
 
-    return x_merged
+    return x_merged, index_down
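Returning `index_down` alongside the merged tokens is what lets `decoder_attn_hook` recover absolute sequence positions for the cluster centers. The index chain, annotated (shapes follow the hook above):

    # total_sparse_token_idx: [1, S]  absolute positions of the sparse tokens
    # merge_token_stage2_idx: [M]     picks within the sparse set chosen for merging
    # index_down:             [1, C]  cluster-center picks within the merge set
    cluster_idx = total_sparse_token_idx.squeeze(0)[merge_token_stage2_idx[index_down]]
    cluster_idx = cluster_idx.squeeze(0)  # [C] absolute positions, spliced into keep_indexs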