@@ -1,4 +1,5 @@
 import functools
+
 import torch
 
 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
@@ -17,40 +18,51 @@ def add_sparse_config(self):
         self.pruning_paras = self.special_config
 
     def register_reduction_modules(self):
-
+
         import math
         from typing import Callable, Tuple
 
+        import numpy as np
         import torch.nn.functional as F
         from einops import rearrange
-        import numpy as np
 
         def conditional_pooling(
             feat: torch.Tensor,
-            threshold:float,
+            threshold: float,
             window_size: Tuple[int, int],
         ) -> Tuple[Callable, Callable]:
-
+
             with torch.no_grad():
-
-                ws_h, ws_w = int(window_size[0]), int(window_size[1])  #window size, 2x2
+
+                ws_h, ws_w = int(window_size[0]), int(window_size[1])  # window size, 2x2
                 stride_h, stride_w = ws_h, ws_w
-                num_token_window = stride_h * stride_w  #number of tokens per window, 4
-
-                x_cls, feat = feat[:, :1, :], feat[:, 1:, :]  # all tokens except the cls token, 576 vision tokens in total
+                num_token_window = stride_h * stride_w  # number of tokens per window, 4
+
+                _, feat = (
+                    feat[:, :1, :],
+                    feat[:, 1:, :],
+                )  # all tokens except the cls token, 576 vision tokens in total
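+                # The cls token is dropped here; only the patch-grid tokens are
+                # scored for merging, and merge() re-attaches the cls token later.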
                 B, N, D = feat.size()
                 base_grid_H = int(math.sqrt(N))
                 base_grid_W = base_grid_H
-                assert base_grid_H * base_grid_W == N and base_grid_H % ws_h == 0 and base_grid_W % ws_w == 0
-
-                feat = rearrange(feat, "b (h w) c -> b c h w", h=base_grid_H)
-
-                feat = rearrange(feat, 'b c (gh ps_h) (gw ps_w) -> b gh gw c ps_h ps_w', gh=base_grid_H // ws_h, gw=base_grid_W // ws_w)
+                assert (
+                    base_grid_H * base_grid_W == N
+                    and base_grid_H % ws_h == 0
+                    and base_grid_W % ws_w == 0
+                )
+
+                feat = rearrange(feat, 'b (h w) c -> b c h w', h=base_grid_H)
+
+                feat = rearrange(
+                    feat,
+                    'b c (gh ps_h) (gw ps_w) -> b gh gw c ps_h ps_w',
+                    gh=base_grid_H // ws_h,
+                    gw=base_grid_W // ws_w,
+                )
                 b, gh, gw, c, ps_h, ps_w = feat.shape
 
                 # Flatten mxm window for pairwise operations
                 tensor_flattened = feat.reshape(b, gh, gw, c, -1)
-
 
                 # Expand dims for pairwise operations
                 tensor_1 = tensor_flattened.unsqueeze(-1)
@@ -64,65 +76,95 @@ def conditional_pooling(
                 sims = sims * sims_mask
 
                 # Average similarities (excluding the self-similarity)
-                similarity_map = sims.sum(-1).sum(-1) / ((ps_h * ps_w) * (ps_h * ps_w - 1))
-
-                similarity_map = rearrange(similarity_map.unsqueeze(1), 'b c h w-> b (c h w)')
-
-                #--- adaptive section ---#
-
+                similarity_map = sims.sum(-1).sum(-1) / (
+                    (ps_h * ps_w) * (ps_h * ps_w - 1)
+                )
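+                # A window of ps_h * ps_w tokens has ps_h*ps_w * (ps_h*ps_w - 1)
+                # ordered off-diagonal pairs, which is the divisor used above.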
+
+                similarity_map = rearrange(
+                    similarity_map.unsqueeze(1), 'b c h w-> b (c h w)'
+                )
+
+                # --- adaptive section ---#
+
                 n_B, n_H = similarity_map.shape
                 node_mean = torch.tensor(threshold).cuda(sims.device)
-                node_mean = node_mean.repeat(1,n_H)
+                node_mean = node_mean.repeat(1, n_H)
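+                # r = number of windows whose mean similarity clears the threshold,
+                # taken as the minimum over the batch so every sample merges the
+                # same number of windows.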
                 r = torch.ge(similarity_map, node_mean).sum(dim=1).min()
-                # -------------#
-
-                # get top k similar super patches
-                _, sim_super_patch_idxs = similarity_map.topk(r,dim=-1)
-
-                # --- creating the mergeable and unmergeable super pathes
-                tensor = torch.arange(base_grid_H * base_grid_W).reshape(base_grid_H, base_grid_W).to(feat.device)
+                # -------------#
+
+                # get top k similar super patches
+                _, sim_super_patch_idxs = similarity_map.topk(r, dim=-1)
+
+                # --- creating the mergeable and unmergeable super patches
+                tensor = (
+                    torch.arange(base_grid_H * base_grid_W)
+                    .reshape(base_grid_H, base_grid_W)
+                    .to(feat.device)
+                )
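+                # Flat token ids laid out on the base grid; unfolding this grid
+                # yields, per window, the ids of its member tokens.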
 
                 # Repeat the tensor to create a batch of size B
                 tensor = tensor.unsqueeze(0).repeat(B, 1, 1)
-
 
                 # Apply unfold operation on last two dimensions to create the sliding window
-                windowed_tensor = tensor.unfold(1, ws_h, stride_h).unfold(2, ws_w, stride_w)
+                windowed_tensor = tensor.unfold(1, ws_h, stride_h).unfold(
+                    2, ws_w, stride_w
+                )
 
-                # Reshape the tensor to the desired shape
+                # Reshape the tensor to the desired shape
                 windowed_tensor = windowed_tensor.reshape(B, -1, num_token_window)
-
-                # Use torch.gather to collect the desired elements
-                gathered_tensor = torch.gather(windowed_tensor, 1, sim_super_patch_idxs.unsqueeze(-1).expand(-1, -1, num_token_window))
 
+                # Use torch.gather to collect the desired elements
+                gathered_tensor = torch.gather(
+                    windowed_tensor,
+                    1,
+                    sim_super_patch_idxs.unsqueeze(-1).expand(-1, -1, num_token_window),
+                )
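+                # gathered_tensor: (B, r, num_token_window) token ids of the
+                # r most self-similar windows selected for merging.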
 
                 # Create a mask for all indices, for each batch
-                mask = torch.ones((B, windowed_tensor.shape[1]), dtype=bool).to(feat.device)
+                mask = torch.ones((B, windowed_tensor.shape[1]), dtype=bool).to(
+                    feat.device
+                )
 
                 # Create a tensor that matches the shape of indices and fill it with False
-                mask_values = torch.zeros_like(sim_super_patch_idxs, dtype=torch.bool).to(feat.device)
+                mask_values = torch.zeros_like(
+                    sim_super_patch_idxs, dtype=torch.bool
+                ).to(feat.device)
 
-                # Use scatter_ to update the mask. This will set mask[b, indices[b]] = False for all b
+                # Use scatter_ to update the mask.
+                # This will set mask[b, indices[b]] = False for all b
                 mask.scatter_(1, sim_super_patch_idxs, mask_values)
 
                 # Get the remaining tensor
-                remaining_tensor = windowed_tensor[mask.unsqueeze(-1).expand(-1, -1, num_token_window)].reshape(B, -1, num_token_window)
-                unm_idx = remaining_tensor.reshape(B, -1).sort(dim=-1).values.unsqueeze(-1)
-                dim_index = (num_token_window)-1
-                src_idx = gathered_tensor[:, :, :dim_index].reshape(B, -1).unsqueeze(-1)
-                dst_idx = gathered_tensor[:, :, dim_index].reshape(B, -1).unsqueeze(-1)
-                merge_idx = torch.arange(src_idx.shape[1]//dim_index).repeat_interleave(dim_index).repeat(B, 1).unsqueeze(-1).to(feat.device)
-
-
-            def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
-                # TODO: num_token_window can be undefined
-
-                x_cls, x_feat = x[:, :1, :], x[:, 1:, :]
+                remaining_tensor = windowed_tensor[
+                    mask.unsqueeze(-1).expand(-1, -1, num_token_window)
+                ].reshape(B, -1, num_token_window)
+                unm_idx = (
+                    remaining_tensor.reshape(B, -1).sort(dim=-1).values.unsqueeze(-1)
+                )
+                dim_index = (num_token_window) - 1
+                src_idx = gathered_tensor[:, :, :dim_index].reshape(B, -1).unsqueeze(-1)
+                dst_idx = gathered_tensor[:, :, dim_index].reshape(B, -1).unsqueeze(-1)
+                merge_idx = (
+                    torch.arange(src_idx.shape[1] // dim_index)
+                    .repeat_interleave(dim_index)
+                    .repeat(B, 1)
+                    .unsqueeze(-1)
+                    .to(feat.device)
+                )
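+                # Within each selected window, the first num_token_window - 1
+                # token ids are merge sources; the last id is the destination.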
+
+            def merge(x: torch.Tensor, mode='mean') -> torch.Tensor:
+                # TODO: num_token_window can be undefined
+
+                x_cls, x_feat = x[:, :1, :], x[:, 1:, :]
                 n, t1, c = x_feat.shape
-                src = x_feat.gather(dim=-2, index=src_idx.expand(n, r * dim_index, c))
+                src = x_feat.gather(dim=-2, index=src_idx.expand(n, r * dim_index, c))
                 dst = x_feat.gather(dim=-2, index=dst_idx.expand(n, r, c))
-                unm = x_feat.gather(dim=-2, index=unm_idx.expand(n, t1 - (r * num_token_window), c))
-                dst = dst.scatter_reduce(-2, merge_idx.expand(n,r * dim_index, c), src, reduce=mode)
+                unm = x_feat.gather(
+                    dim=-2, index=unm_idx.expand(n, t1 - (r * num_token_window), c)
+                )
+                dst = dst.scatter_reduce(
+                    -2, merge_idx.expand(n, r * dim_index, c), src, reduce=mode
+                )
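+                # scatter_reduce folds every source token into its window's
+                # destination token using the requested reduction (mean or sum).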
                 x = torch.cat([dst, unm], dim=1)
                 x = torch.cat((x_cls, x), dim=1)
                 return x
@@ -132,27 +174,27 @@ def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
         def merge_wavg(
             merge: Callable, x: torch.Tensor, size: torch.Tensor = None
         ) -> Tuple[torch.Tensor, torch.Tensor]:
-
+
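+            # size counts how many original tokens each kept token represents,
+            # so merging computes a size-weighted average rather than a plain mean.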
             if size is None:
                 size = torch.ones_like(x[..., 0, None])
 
-            x = merge(x * size, mode="sum")
-            size = merge(size, mode="sum")
+            x = merge(x * size, mode='sum')
+            size = merge(size, mode='sum')
             x = x / size
-
+
             return x, size
-
+
         def spatial_merge_hook(module, args, kwargs, pruning_paras):
             spatial_threshold = pruning_paras['spatial_threshold']
             window_size = pruning_paras['window_size']
             hidden_states = args[0]
             merge = conditional_pooling(hidden_states, spatial_threshold, window_size)
-            hidden_states, size = merge_wavg(merge, hidden_states, None)
+            hidden_states, size = merge_wavg(merge, hidden_states, None)
             return (hidden_states,) + args[1:], kwargs
-
+
         self.model.set_modality('vision')
         self.model.find_blocks()
         self.model.blocks[1].register_forward_pre_hook(
             functools.partial(spatial_merge_hook, pruning_paras=self.pruning_paras),
-            with_kwargs=True
+            with_kwargs=True,
         )
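
As a quick sanity check on what this hook does to the sequence length, the sketch below walks the token-count arithmetic by hand. The 576-token grid and the 2x2 window come from the comments in the diff; the value of `r` is batch-dependent at runtime, so the number used here is purely illustrative:

# Minimal standalone sketch; r is a hypothetical value for illustration.
N = 576                  # vision tokens on the 24 x 24 grid (cls token excluded)
num_token_window = 4     # tokens per 2 x 2 window
r = 100                  # windows whose mean similarity clears the threshold

# merge() keeps one dst token per selected window plus every token of the
# unselected windows, so each merged window sheds num_token_window - 1 tokens.
tokens_after = N - r * (num_token_window - 1)   # 576 - 300 = 276
print(tokens_after + 1)  # re-attach the cls token -> 277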