
Commit 5449ca3

Merge pull request #377 from ModelTC/vlm
add pyramiddrop
1 parent d3ffb37 commit 5449ca3

File tree

4 files changed: +377 -3 lines changed

Lines changed: 24 additions & 0 deletions

base:
    seed: &seed 42
model:
    type: Llava
    path: model path
    torch_dtype: auto
eval:
    eval_pos: [transformed]
    type: vqa
    name: [mme]
    download: False
    path: MME dataset path
    bs: 1
    inference_per_block: False
sparse:
    method: TokenReduction
    special:
        method: PyramidDrop
        image_token_ratio_list: [0.5, 0.25, 0.125]
        layer_list: [8, 16, 24]
save:
    save_trans: False
    save_fake: False
    save_path: /path/to/save/
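
To make the schedule concrete, the sketch below (not part of the commit) mirrors the keep_length arithmetic in pyramiddrop.py further down, assuming a LLaVA-style model that feeds 576 image tokens into the language model; the actual count depends on the vision tower and image resolution.

# Illustrative only: how the layer_list / image_token_ratio_list schedule above
# translates into retained image tokens, assuming 576 visual tokens (a 24x24 grid).
image_tokens = 576
layer_list = [8, 16, 24]
image_token_ratio_list = [1.0, 0.5, 0.25, 0.125]  # add_sparse_config prepends the 1.0

for stage, layer in enumerate(layer_list):
    tokens_in = int(image_tokens * image_token_ratio_list[stage])
    tokens_kept = int(image_tokens * image_token_ratio_list[stage + 1])
    print(f'layer {layer}: {tokens_in} image tokens in, top {tokens_kept} kept')

# layer 8: 576 image tokens in, top 288 kept
# layer 16: 288 image tokens in, top 144 kept
# layer 24: 144 image tokens in, top 72 kept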
Lines changed: 1 addition & 0 deletions

  from .base_blockwise_token_reduction import TokenReduction
  from .fastervlm import FasterVLM
  from .fastv import FastV
+ from .pyramiddrop import PyramidDrop
  from .sparsevlm import SparseVLM
  from .tome import ToMe
  from .visionzip import VisionZip
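
Importing the new module in this package __init__ matters because the @TOKEN_REDUCTION_REGISTRY.register('PyramidDrop') decorator in pyramiddrop.py (shown below) only runs at import time; that is what lets method: PyramidDrop in the config resolve to the class. A minimal sketch of that decorator-registry pattern follows; it is illustrative only, and the internals and lookup API of llmc's actual TOKEN_REDUCTION_REGISTRY are assumptions here.

# Minimal registry sketch (hypothetical); llmc's TOKEN_REDUCTION_REGISTRY follows
# this general pattern, but its exact implementation and lookup API may differ.
class Registry:
    def __init__(self):
        self._classes = {}

    def register(self, name):
        def decorator(cls):
            self._classes[name] = cls   # map the config string to the class
            return cls
        return decorator

    def get(self, name):
        return self._classes[name]

TOKEN_REDUCTION_REGISTRY = Registry()

@TOKEN_REDUCTION_REGISTRY.register('PyramidDrop')
class PyramidDrop:
    pass

assert TOKEN_REDUCTION_REGISTRY.get('PyramidDrop') is PyramidDrop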
Lines changed: 349 additions & 0 deletions

import functools
import math

import torch
from torch import nn
from transformers.modeling_attn_mask_utils import \
    _prepare_4d_causal_attention_mask
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

from .token_reduction_module import TokenReductionModule


@TOKEN_REDUCTION_REGISTRY.register('PyramidDrop')
class PyramidDrop(TokenReductionModule):
    def __init__(self, config, model, blocks):
        super().__init__(config, model, blocks)
        self.add_sparse_config()
        self.register_reduction_modules()

    def add_sparse_config(self):
        special_config = self.config.get('special', {})
        self.pruning_loc = special_config['layer_list']
        image_token_ratio_list = special_config['image_token_ratio_list']
        image_token_ratio_list.insert(0, 1.0)
        special_config['image_token_ratio_list'] = image_token_ratio_list
        special_config['tokenizer_padding_side'] = getattr(
            self.model.vlm_model.language_model.model.config,
            'tokenizer_padding_side',
            'right',
        )
        special_config['image_token_index'] = self.model.pruning_config[
            'image_token_index'
        ]
        self.model.model.parameters = special_config

    def register_reduction_modules(self):

        def pruning_hook(module, args, kwargs, pruning_pars, cur_num, layer_idx):

            if layer_idx == self.pruning_loc[0]:
                position_ids = kwargs['position_ids']
                attention_mask = kwargs['attention_mask']
                position_embeddings = kwargs['position_embeddings']
            else:
                attention_mask = pruning_pars['attention_mask']
                position_ids = pruning_pars['position_ids']
                position_embeddings = pruning_pars['position_embeddings']

            features = args[0]
            _position_ids = position_ids
            _attention_mask = attention_mask
            prompt_len = pruning_pars['prompt_len']
            image_tokens_list = pruning_pars['image_tokens']
            image_token_posi = pruning_pars['image_token_posi']
            image_token_ratio_list = pruning_pars['image_token_ratio_list']

            if position_ids is None:
                position_ids = torch.arange(
                    0, features.shape[1], dtype=torch.long, device=features.device
                ).unsqueeze(0)

            if pruning_pars['tokenizer_padding_side'] == 'right':

                batch_size = features.shape[0]
                image_tokens = [
                    int(cur_image_token * image_token_ratio_list[cur_num])
                    for cur_image_token in image_tokens_list
                ]
                keep_length = [
                    int(cur_image_token * image_token_ratio_list[cur_num + 1])
                    for cur_image_token in image_tokens_list
                ]

                features_list = []
                attention_mask_list = []

                if attention_mask is None:
                    attention_mask = torch.ones(
                        (batch_size, features.shape[1]),
                        dtype=torch.bool,
                        device=features.device,
                    )
                else:
                    attention_mask = attention_mask.bool()

                # obtain query_states and key_states to calculate attention map
                hidden_states = features.clone().detach()
                self_attn = module.self_attn
                hidden_states = module.input_layernorm(hidden_states)

                num_heads = self_attn.num_heads
                num_key_value_heads = self_attn.num_key_value_heads
                head_dim = self_attn.head_dim

                bsz, q_len, _ = hidden_states.size()

                query_states = self_attn.q_proj(hidden_states)
                key_states = self_attn.k_proj(hidden_states)
                value_states = self_attn.v_proj(hidden_states)

                query_states = query_states.view(
                    bsz, q_len, num_heads, head_dim
                ).transpose(1, 2)
                key_states = key_states.view(
                    bsz, q_len, num_key_value_heads, head_dim
                ).transpose(1, 2)
                value_states = value_states.view(
                    bsz, q_len, num_key_value_heads, head_dim
                ).transpose(1, 2)

                if position_embeddings is None:
                    cos, sin = self_attn.rotary_emb(value_states, position_ids)
                else:
                    cos, sin = position_embeddings

                query_states, key_states = apply_rotary_pos_emb(
                    query_states, key_states, cos, sin
                )

                # attention_mask
                eager_attention_mask = _prepare_4d_causal_attention_mask(
                    attention_mask,
                    (batch_size, q_len),
                    hidden_states,
                    past_key_values_length=0,
                ).to(device=query_states.device)

                # take valid features
                features = [
                    cur_features[cur_attention_mask]
                    for cur_features, cur_attention_mask in zip(
                        features, attention_mask
                    )
                ]
                attention_mask = [
                    cur_attention_mask[cur_attention_mask]
                    for cur_attention_mask in attention_mask
                ]

                # rank & drop
                for i in range(batch_size):
                    image_index = image_token_posi[i]
                    if image_index == -1:
                        cur_input_embeds = features[i]
                        features_list.append(cur_input_embeds)
                        attention_mask_list.append(attention_mask[i])
                        continue

                    # obtain current states
                    cur_key_states = key_states[i]
                    cur_query_states = query_states[i]
                    cur_eager_attention_mask = eager_attention_mask[i]

                    prompt_total_len = prompt_len[i] + image_tokens[i]
                    text_query_states = cur_query_states[
                        :, prompt_total_len - 1, :
                    ].unsqueeze(1)
                    text_eager_attention_mask = cur_eager_attention_mask[
                        :, prompt_total_len - 1, :
                    ].unsqueeze(1)

                    # calculate attention map
                    attn_weights = torch.matmul(
                        text_query_states, cur_key_states.transpose(1, 2)
                    ) / math.sqrt(
                        head_dim
                    )  # (num_head, text_token, seq_len)
                    attn_weights = attn_weights + text_eager_attention_mask
                    attn_weights = nn.functional.softmax(
                        attn_weights, dim=-1, dtype=torch.float32
                    ).to(
                        query_states.dtype
                    )  # (num_head, text_token, seq_len)

                    attention_avg_head = torch.mean(
                        attn_weights, dim=0
                    )  # average across heads
                    attention_avg_head = attention_avg_head[
                        :, image_index: image_index + image_tokens[i]
                    ]  # select image tokens as keys
                    attention_avg_text = torch.mean(attention_avg_head, dim=0)  # (576)

                    # rank and drop by attention score
                    top_rank_index = attention_avg_text.topk(keep_length[i]).indices
                    top_rank_index = top_rank_index + image_index
                    top_rank_index = top_rank_index.sort().values

                    start_index = image_index + image_tokens[i]
                    new_input_embeds = torch.cat(
                        [
                            features[i][:image_index, :],
                            features[i][top_rank_index, :],
                            features[i][start_index:, :],
                        ],
                        dim=0,
                    )
                    new_attention_mask = torch.cat(
                        [
                            attention_mask[i][:image_index],
                            attention_mask[i][top_rank_index],
                            attention_mask[i][start_index:],
                        ],
                        dim=0,
                    )

                    features_list.append(new_input_embeds)
                    attention_mask_list.append(new_attention_mask)

                # Truncate sequences to max length as image embeddings can make the sequence longer
                tokenizer_model_max_length = getattr(
                    self.model.vlm_model.language_model.model.config,
                    'tokenizer_model_max_length',
                    2048,
                )
                if tokenizer_model_max_length is not None:
                    new_input_embeds = [
                        x[:tokenizer_model_max_length] for x in features_list
                    ]
                    new_attention_mask = [
                        x[:tokenizer_model_max_length] for x in attention_mask_list
                    ]

                max_len = max(x.shape[0] for x in new_input_embeds)

                # pad the sequences to form a batch
                embeds_padded = []
                attention_mask_padded = []
                position_ids = torch.zeros(
                    (batch_size, max_len),
                    dtype=position_ids.dtype,
                    device=position_ids.device,
                )
                for i, cur_new_embed in enumerate(new_input_embeds):
                    cur_len_emb = cur_new_embed.shape[0]
                    dif = max_len - cur_len_emb  # pad to the longest sequence

                    cur_new_embed = torch.cat(
                        [
                            cur_new_embed,
                            torch.zeros(
                                (dif, cur_new_embed.shape[1]),
                                dtype=cur_new_embed.dtype,
                                device=cur_new_embed.device,
                            ),
                        ],
                        dim=0,
                    )
                    cur_attention_mask = new_attention_mask[i]
                    cur_attention_mask = torch.cat(
                        [
                            cur_attention_mask,
                            torch.full(
                                (dif,),
                                False,
                                dtype=cur_attention_mask.dtype,
                                device=cur_attention_mask.device,
                            ),
                        ],
                        dim=0,
                    )

                    embeds_padded.append(cur_new_embed)

                    attention_mask_padded.append(cur_attention_mask)

                    cur_len = new_attention_mask[i].sum().item()
                    position_ids[i, :cur_len] = torch.arange(
                        0, cur_len, dtype=position_ids.dtype, device=position_ids.device
                    )

                new_input_embeds = torch.stack(embeds_padded, dim=0)
                new_input_embeds = new_input_embeds.to(features[0].dtype)

                new_attention_mask = torch.stack(attention_mask_padded, dim=0)

                if _position_ids is None:
                    position_ids = None

                if _attention_mask is None:
                    new_attention_mask = None
                else:
                    new_attention_mask = new_attention_mask.to(
                        dtype=_attention_mask.dtype
                    )

                kwargs['attention_mask'] = new_attention_mask
                kwargs['position_ids'] = position_ids
                kwargs['position_embeddings'] = None
                pruning_pars['attention_mask'] = new_attention_mask
                pruning_pars['position_ids'] = position_ids
                pruning_pars['position_embeddings'] = None

                return (new_input_embeds,), kwargs

        def input_hook(module, input_args, pruning_pars):
            input_ids = input_args[0]
            pre_prompt_length_list = []
            image_token_posi = []
            image_tokens = []
            IMAGE_TOKEN_INDEX = pruning_pars['image_token_index']

            # find the position of the first image token
            for seq in input_ids:
                image_token_idxs = (seq == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0]
                image_tokens.append(image_token_idxs.shape[0])
                image_token_posi.append(image_token_idxs[0].item())
                pre_prompt_length_list.append(seq.shape[0] - image_token_idxs.shape[0])

            pruning_pars['prompt_len'] = pre_prompt_length_list
            pruning_pars['image_token_posi'] = image_token_posi
            pruning_pars['image_tokens'] = image_tokens

            return input_args

        def read_parameter_hook(module, args, kwargs, pruning_pars):
            kwargs['attention_mask'] = pruning_pars['attention_mask']
            # kwargs['cache_position'] = pruning_pars['cache_position']
            kwargs['position_ids'] = pruning_pars['position_ids']
            kwargs['position_embeddings'] = pruning_pars['position_embeddings']

            return args, kwargs

        self.model.embed_tokens.register_forward_pre_hook(
            functools.partial(input_hook, pruning_pars=self.model.model.parameters)
        )

        for layer_idx in range(self.pruning_loc[0], len(self.blocks)):
            if layer_idx in self.pruning_loc:
                stage = self.pruning_loc.index(layer_idx)
                self.blocks[layer_idx].register_forward_pre_hook(
                    functools.partial(
                        pruning_hook,
                        pruning_pars=self.model.model.parameters,
                        cur_num=stage,
                        layer_idx=layer_idx,
                    ),
                    with_kwargs=True,
                )
            else:
                self.blocks[layer_idx].register_forward_pre_hook(
                    functools.partial(
                        read_parameter_hook, pruning_pars=self.model.model.parameters
                    ),
                    with_kwargs=True,
                )
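
The core of pruning_hook above is the rank-and-drop step: at each pruning layer the query of the last prompt token attends over all keys, the scores are averaged over heads, and only the highest-scoring image positions survive while text tokens are left untouched. The self-contained sketch below replays that selection on random tensors with made-up shapes; it is illustrative only and omits the rotary embeddings, causal mask, grouped-query heads, and batch re-padding that the hook handles.

# Standalone sketch of PyramidDrop's rank-and-drop selection (illustrative only;
# shapes and tensors are hypothetical, no model is involved).
import math
import torch

# hypothetical layout: 5 text tokens, then 576 image tokens, then the rest of the prompt
num_heads, head_dim, hidden = 32, 128, 4096
text_len, image_tokens, image_index = 40, 576, 5
seq_len = text_len + image_tokens
keep = int(image_tokens * 0.5)   # first-stage ratio from the config above

features = torch.randn(seq_len, hidden)
query_states = torch.randn(num_heads, seq_len, head_dim)
key_states = torch.randn(num_heads, seq_len, head_dim)

# score every key by the attention paid to it by the last prompt token
text_query = query_states[:, seq_len - 1, :].unsqueeze(1)            # (heads, 1, head_dim)
attn = torch.matmul(text_query, key_states.transpose(1, 2)) / math.sqrt(head_dim)
attn = attn.softmax(dim=-1)                                          # (heads, 1, seq_len)
scores = attn.mean(dim=0).squeeze(0)                                 # average over heads
scores = scores[image_index:image_index + image_tokens]              # image positions only

# keep the top-k image tokens (in their original order), drop the rest
top_rank_index = (scores.topk(keep).indices + image_index).sort().values
pruned = torch.cat(
    [
        features[:image_index],
        features[top_rank_index],
        features[image_index + image_tokens:],
    ],
    dim=0,
)
print(features.shape, '->', pruned.shape)   # (616, 4096) -> (328, 4096)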
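
input_hook, by contrast, is pure bookkeeping before the embedding layer: for each sequence it records how many image placeholder tokens there are, where the first one sits, and how many text tokens remain. A toy run is shown below with a hypothetical placeholder id of 32000; in the commit the real id comes from the model's pruning_config['image_token_index'].

# Toy run of the input_hook bookkeeping (hypothetical token ids).
import torch

IMAGE_TOKEN_INDEX = 32000     # placeholder id is model dependent; assumed here
# 3 text tokens, 4 image placeholders, 2 text tokens
input_ids = torch.tensor([[1, 15, 27, 32000, 32000, 32000, 32000, 8, 2]])

for seq in input_ids:
    image_token_idxs = (seq == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0]
    print('image_tokens    :', image_token_idxs.shape[0])                  # 4
    print('image_token_posi:', image_token_idxs[0].item())                 # 3
    print('prompt_len      :', seq.shape[0] - image_token_idxs.shape[0])   # 5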
