Commit 439b78c

mustdrop
1 parent 2d3b736 commit 439b78c

1 file changed: 103 additions & 61 deletions
@@ -1,4 +1,5 @@
 import functools
+
 import torch
 
 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
@@ -17,40 +18,51 @@ def add_sparse_config(self):
         self.pruning_paras = self.special_config
 
     def register_reduction_modules(self):
-
+
         import math
         from typing import Callable, Tuple
 
+        import numpy as np
         import torch.nn.functional as F
         from einops import rearrange
-        import numpy as np
 
         def conditional_pooling(
             feat: torch.Tensor,
-            threshold:float,
+            threshold: float,
             window_size: Tuple[int, int],
         ) -> Tuple[Callable, Callable]:
-
+
             with torch.no_grad():
-
-                ws_h, ws_w = int(window_size[0]), int(window_size[1]) # window size, 2*2
+
+                ws_h, ws_w = int(window_size[0]), int(window_size[1])  # window size, 2*2
                 stride_h, stride_w = ws_h, ws_w
-                num_token_window = stride_h * stride_w # tokens per window, 4
-
-                x_cls, feat = feat[:, :1, :], feat[:, 1:, :] # all tokens except the cls token, 576 vision tokens in total
+                num_token_window = stride_h * stride_w  # tokens per window, 4
+
+                _, feat = (
+                    feat[:, :1, :],
+                    feat[:, 1:, :],
+                )  # all tokens except the cls token, 576 vision tokens in total
                 B, N, D = feat.size()
                 base_grid_H = int(math.sqrt(N))
                 base_grid_W = base_grid_H
-                assert base_grid_H * base_grid_W == N and base_grid_H % ws_h == 0 and base_grid_W % ws_w == 0
-
-                feat = rearrange(feat, "b (h w) c -> b c h w", h=base_grid_H)
-
-                feat = rearrange(feat, 'b c (gh ps_h) (gw ps_w) -> b gh gw c ps_h ps_w', gh=base_grid_H//ws_h, gw=base_grid_W//ws_w)
+                assert (
+                    base_grid_H * base_grid_W == N
+                    and base_grid_H % ws_h == 0
+                    and base_grid_W % ws_w == 0
+                )
+
+                feat = rearrange(feat, 'b (h w) c -> b c h w', h=base_grid_H)
+
+                feat = rearrange(
+                    feat,
+                    'b c (gh ps_h) (gw ps_w) -> b gh gw c ps_h ps_w',
+                    gh=base_grid_H // ws_h,
+                    gw=base_grid_W // ws_w,
+                )
                 b, gh, gw, c, ps_h, ps_w = feat.shape
 
                 # Flatten mxm window for pairwise operations
                 tensor_flattened = feat.reshape(b, gh, gw, c, -1)
-
 
                 # Expand dims for pairwise operations
                 tensor_1 = tensor_flattened.unsqueeze(-1)
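
Note: the hunk above reshapes the 576 vision tokens into a 24x24 grid, cuts it into 2x2 windows, and scores each window by its mean intra-window token similarity. A minimal standalone sketch of that scoring step (illustrative only, not the commit's code; the random `feat`, the `ws`/`H` names, and the cosine-similarity choice are assumptions, since the actual `sims` computation falls outside the changed lines):

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    B, N, D = 2, 576, 1024   # 576 vision tokens -> a 24x24 grid
    ws = 2                   # window size, as in the diff
    feat = torch.randn(B, N, D)

    H = int(N ** 0.5)        # 24
    x = rearrange(feat, 'b (h w) c -> b c h w', h=H)
    x = rearrange(x, 'b c (gh p) (gw q) -> b gh gw c (p q)', p=ws, q=ws)

    x = F.normalize(x, dim=3)                        # unit-norm channels -> cosine sims
    sims = torch.einsum('bghci,bghcj->bghij', x, x)  # (B, 12, 12, 4, 4) pairwise sims
    k = ws * ws
    self_sim = sims.diagonal(dim1=-2, dim2=-1).sum(-1)
    similarity_map = (sims.sum((-1, -2)) - self_sim) / (k * (k - 1))
    print(similarity_map.shape)  # torch.Size([2, 12, 12]); one redundancy score per window

Windows whose mean similarity clears `spatial_threshold` become the "super patches" that the code below selects for merging.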
@@ -64,65 +76,95 @@ def conditional_pooling(
                 sims = sims * sims_mask
 
                 # Average similarities (excluding the self-similarity)
-                similarity_map = sims.sum(-1).sum(-1) / ((ps_h * ps_w) * (ps_h * ps_w - 1))
-
-                similarity_map = rearrange(similarity_map.unsqueeze(1), 'b c h w-> b (c h w)')
-
-                #--- adaptive section ---#
-
+                similarity_map = sims.sum(-1).sum(-1) / (
+                    (ps_h * ps_w) * (ps_h * ps_w - 1)
+                )
+
+                similarity_map = rearrange(
+                    similarity_map.unsqueeze(1), 'b c h w-> b (c h w)'
+                )
+
+                # --- adaptive section ---#
+
                 n_B, n_H = similarity_map.shape
                 node_mean = torch.tensor(threshold).cuda(sims.device)
-                node_mean=node_mean.repeat(1,n_H)
+                node_mean = node_mean.repeat(1, n_H)
                 r = torch.ge(similarity_map, node_mean).sum(dim=1).min()
-                # -------------#
-
-                # get top k similar super patches
-                _, sim_super_patch_idxs = similarity_map.topk(r,dim=-1)
-
-                # --- creating the mergabel and unmergable super pathes
-                tensor = torch.arange(base_grid_H * base_grid_W).reshape(base_grid_H, base_grid_W).to(feat.device)
+                # -------------#
+
+                # get top k similar super patches
+                _, sim_super_patch_idxs = similarity_map.topk(r, dim=-1)
+
+                # --- creating the mergeable and unmergeable super patches
+                tensor = (
+                    torch.arange(base_grid_H * base_grid_W)
+                    .reshape(base_grid_H, base_grid_W)
+                    .to(feat.device)
+                )
 
                 # Repeat the tensor to create a batch of size 2
                 tensor = tensor.unsqueeze(0).repeat(B, 1, 1)
-
 
                 # Apply unfold operation on last two dimensions to create the sliding window
-                windowed_tensor = tensor.unfold(1, ws_h, stride_h).unfold(2, ws_w, stride_w)
+                windowed_tensor = tensor.unfold(1, ws_h, stride_h).unfold(
+                    2, ws_w, stride_w
+                )
 
-                # Reshape the tensor to the desired shape
+                # Reshape the tensor to the desired shape
                 windowed_tensor = windowed_tensor.reshape(B, -1, num_token_window)
-
-                # Use torch.gather to collect the desired elements
-                gathered_tensor = torch.gather(windowed_tensor, 1, sim_super_patch_idxs.unsqueeze(-1).expand(-1, -1, num_token_window))
 
+                # Use torch.gather to collect the desired elements
+                gathered_tensor = torch.gather(
+                    windowed_tensor,
+                    1,
+                    sim_super_patch_idxs.unsqueeze(-1).expand(-1, -1, num_token_window),
+                )
 
                 # Create a mask for all indices, for each batch
-                mask = torch.ones((B, windowed_tensor.shape[1]), dtype=bool).to(feat.device)
+                mask = torch.ones((B, windowed_tensor.shape[1]), dtype=bool).to(
+                    feat.device
+                )
 
                 # Create a tensor that matches the shape of indices and fill it with False
-                mask_values = torch.zeros_like(sim_super_patch_idxs, dtype=torch.bool).to(feat.device)
+                mask_values = torch.zeros_like(
+                    sim_super_patch_idxs, dtype=torch.bool
+                ).to(feat.device)
 
-                # Use scatter_ to update the mask. This will set mask[b, indices[b]] = False for all b
+                # Use scatter_ to update the mask.
+                # This will set mask[b, indices[b]] = False for all b
                 mask.scatter_(1, sim_super_patch_idxs, mask_values)
 
                 # Get the remaining tensor
-                remaining_tensor = windowed_tensor[mask.unsqueeze(-1).expand(-1, -1, num_token_window)].reshape(B, -1, num_token_window)
-                unm_idx = remaining_tensor.reshape(B, -1).sort(dim=-1).values.unsqueeze(-1)
-                dim_index = (num_token_window)- 1
-                src_idx= gathered_tensor[:, :, :dim_index].reshape(B, -1).unsqueeze(-1)
-                dst_idx= gathered_tensor[:, :, dim_index].reshape(B, -1).unsqueeze(-1)
-                merge_idx = torch.arange(src_idx.shape[1]//dim_index).repeat_interleave(dim_index).repeat(B, 1).unsqueeze(-1).to(feat.device)
-
-
-            def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
-                # TODO: num_token_window can be undefined
-
-                x_cls , x_feat = x[:, :1, :], x[:, 1:, :]
+                remaining_tensor = windowed_tensor[
+                    mask.unsqueeze(-1).expand(-1, -1, num_token_window)
+                ].reshape(B, -1, num_token_window)
+                unm_idx = (
+                    remaining_tensor.reshape(B, -1).sort(dim=-1).values.unsqueeze(-1)
+                )
+                dim_index = (num_token_window) - 1
+                src_idx = gathered_tensor[:, :, :dim_index].reshape(B, -1).unsqueeze(-1)
+                dst_idx = gathered_tensor[:, :, dim_index].reshape(B, -1).unsqueeze(-1)
+                merge_idx = (
+                    torch.arange(src_idx.shape[1] // dim_index)
+                    .repeat_interleave(dim_index)
+                    .repeat(B, 1)
+                    .unsqueeze(-1)
+                    .to(feat.device)
+                )
+
+            def merge(x: torch.Tensor, mode='mean') -> torch.Tensor:
+                # TODO: num_token_window can be undefined
+
+                x_cls, x_feat = x[:, :1, :], x[:, 1:, :]
                 n, t1, c = x_feat.shape
-                src = x_feat.gather(dim=-2, index=src_idx.expand(n, r*dim_index, c))
+                src = x_feat.gather(dim=-2, index=src_idx.expand(n, r * dim_index, c))
                 dst = x_feat.gather(dim=-2, index=dst_idx.expand(n, r, c))
-                unm = x_feat.gather(dim=-2, index=unm_idx.expand(n, t1 - (r*num_token_window), c))
-                dst = dst.scatter_reduce(-2, merge_idx.expand(n,r*dim_index, c), src, reduce=mode)
+                unm = x_feat.gather(
+                    dim=-2, index=unm_idx.expand(n, t1 - (r * num_token_window), c)
+                )
+                dst = dst.scatter_reduce(
+                    -2, merge_idx.expand(n, r * dim_index, c), src, reduce=mode
+                )
                 x = torch.cat([dst, unm], dim=1)
                 x = torch.cat((x_cls, x), dim=1)
                 return x
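
In the hunk above, `merge()` gathers the first `num_token_window - 1` tokens of each selected window as sources (`src`) and the window's last token as the destination (`dst`), folds the sources into the destinations with `scatter_reduce`, and re-attaches the untouched tokens (`unm`). A minimal sketch of that `scatter_reduce` primitive (toy values, not the commit's code):

    import torch

    # Four tokens, one channel: tokens 1 and 2 are destinations, 3 and 4 are sources.
    x = torch.tensor([[[1.0], [2.0], [3.0], [4.0]]])  # (n=1, t=4, c=1)
    dst = x[:, :2, :].clone()
    src = x[:, 2:, :]
    merge_idx = torch.tensor([[[0], [1]]])            # source row i -> destination row idx[i]

    # reduce='mean' averages each destination with the sources scattered into it,
    # which is what merge(x, mode='mean') does for the selected windows
    merged = dst.scatter_reduce(-2, merge_idx, src, reduce='mean')
    print(merged)  # tensor([[[2.], [3.]]]) -> (1+3)/2 and (2+4)/2

`merge_wavg` in the next hunk wraps `merge()` in a size-weighted average (in the style of ToMe's `merge_wavg`): features are merged with `mode='sum'` alongside a running token size, then divided by the merged size.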
@@ -132,27 +174,27 @@ def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
         def merge_wavg(
             merge: Callable, x: torch.Tensor, size: torch.Tensor = None
         ) -> Tuple[torch.Tensor, torch.Tensor]:
-
+
             if size is None:
                 size = torch.ones_like(x[..., 0, None])
 
-            x = merge(x * size, mode="sum")
-            size = merge(size, mode="sum")
+            x = merge(x * size, mode='sum')
+            size = merge(size, mode='sum')
             x = x / size
-
+
             return x, size
-
+
         def spatial_merge_hook(module, args, kwargs, pruning_paras):
             spatial_threshold = pruning_paras['spatial_threshold']
             window_size = pruning_paras['window_size']
             hidden_states = args[0]
             merge = conditional_pooling(hidden_states, spatial_threshold, window_size)
-            hidden_states, size =merge_wavg(merge, hidden_states, None)
+            hidden_states, size = merge_wavg(merge, hidden_states, None)
             return (hidden_states,) + args[1:], kwargs
-
+
         self.model.set_modality('vision')
         self.model.find_blocks()
         self.model.blocks[1].register_forward_pre_hook(
             functools.partial(spatial_merge_hook, pruning_paras=self.pruning_paras),
-            with_kwargs=True
+            with_kwargs=True,
         )
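
The wiring at the end registers `spatial_merge_hook` as a forward pre-hook on the second vision block, so tokens are merged before that block runs. A minimal sketch of the mechanism (the module and the toy reduction are stand-ins, not the commit's code; `register_forward_pre_hook(..., with_kwargs=True)` is the actual PyTorch API):

    import functools
    import torch
    import torch.nn as nn

    def drop_half_hook(module, args, kwargs, pruning_paras):
        hidden_states = args[0]
        keep = hidden_states.shape[1] // 2           # toy reduction: keep half the tokens
        hidden_states = hidden_states[:, :keep, :]
        return (hidden_states,) + args[1:], kwargs   # replaces the block's (args, kwargs)

    block = nn.Linear(8, 8)                          # stand-in for self.model.blocks[1]
    block.register_forward_pre_hook(
        functools.partial(drop_half_hook, pruning_paras={}),
        with_kwargs=True,
    )
    print(block(torch.randn(2, 10, 8)).shape)        # torch.Size([2, 5, 8])

Returning an `(args, kwargs)` pair from a `with_kwargs=True` pre-hook replaces the block's inputs, which is how `spatial_merge_hook` swaps in the merged `hidden_states`.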
