@@ -15,6 +15,7 @@ def __init__(self, config, model, blocks):
1515 self .register_reduction_modules ()
1616
1717 def add_sparse_config (self ):
18+ self .pruning_loc = self .special_config ['pruning_loc' ]
1819 self .pruning_paras = self .special_config
1920
2021 def register_reduction_modules (self ):
@@ -30,6 +31,7 @@ def conditional_pooling(
3031 feat : torch .Tensor ,
3132 threshold : float ,
3233 window_size : Tuple [int , int ],
34+ fix_r : int = 0 ,
3335 ) -> Tuple [Callable , Callable ]:
3436
3537 with torch .no_grad ():
@@ -91,7 +93,8 @@ def conditional_pooling(
9193 node_mean = node_mean .repeat (1 , n_H )
9294 r = torch .ge (similarity_map , node_mean ).sum (dim = 1 ).min ()
9395 # -------------#
94-
96+ if fix_r != 0 :
97+ r = fix_r
9598 # get top k similar super patches
9699 _ , sim_super_patch_idxs = similarity_map .topk (r , dim = - 1 )
97100
@@ -184,17 +187,20 @@ def merge_wavg(
184187
185188 return x , size
186189
187- def spatial_merge_hook (module , args , kwargs , pruning_paras ):
190+ def spatial_merge_hook (module , args , kwargs , layer_outs , pruning_paras ):
188191 spatial_threshold = pruning_paras ['spatial_threshold' ]
189192 window_size = pruning_paras ['window_size' ]
190- hidden_states = args [0 ]
191- merge = conditional_pooling (hidden_states , spatial_threshold , window_size )
193+ hidden_states = layer_outs [0 ]
194+ fix_r = 0
195+ if pruning_paras .get ('retained_tokens' , None ) is not None :
196+ retained_tokens = pruning_paras ['retained_tokens' ]
197+ fix_r = (pruning_paras ['vision_token_length' ] - retained_tokens ) \
198+ // (window_size [0 ] * window_size [1 ] - 1 )
199+ merge = conditional_pooling (hidden_states , spatial_threshold , window_size , fix_r )
192200 hidden_states , size = merge_wavg (merge , hidden_states , None )
193- return (hidden_states ,) + args [ 1 :], kwargs
201+ return (hidden_states ,)
194202
195- self .model .set_modality ('vision' )
196- self .model .find_blocks ()
197- self .model .blocks [1 ].register_forward_pre_hook (
203+ self .blocks [self .pruning_loc - 1 ].register_forward_hook (
198204 functools .partial (spatial_merge_hook , pruning_paras = self .pruning_paras ),
199205 with_kwargs = True ,
200206 )
0 commit comments