
Commit 695fbc3

update divprune,mustdrop for llava-next (#428)
1 parent 93afb24 commit 695fbc3

File tree: 6 files changed, +487 -227 lines changed

Lines changed: 23 additions & 0 deletions (new DivPrune config)

base:
    seed: &seed 42
model:
    type: Llava
    path: model path
    torch_dtype: auto
eval:
    eval_pos: [pretrain, transformed]
    type: vqa
    name: [mme]
    download: False
    path: MME dataset path
    bs: 1
    inference_per_block: False
sparse:
    method: TokenReduction
    special:
        method: DivPrune
        reduction_ratio: 0.9444 # 0.7778 0.8889 0.9444
save:
    save_trans: False
    save_fake: False
    save_path: /path/to/save/
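Here `reduction_ratio` is the fraction of vision tokens to drop: the DivPrune hook updated later in this commit calls `divprune(..., threshold_ratio=1 - rate)` and keeps `round(threshold_ratio * image_feature_length)` tokens. A minimal sketch of that arithmetic, assuming an illustrative budget of 576 vision tokens (the ratio is applied to whatever token count the model actually produces):

```python
# Sketch only: how many vision tokens survive a given reduction_ratio.
# n_vision_tokens=576 is an assumed example; LLaVA-NeXT yields a variable count.
def kept_tokens(reduction_ratio: float, n_vision_tokens: int = 576) -> int:
    return round((1 - reduction_ratio) * n_vision_tokens)

for ratio in (0.7778, 0.8889, 0.9444):  # the values suggested in the comment above
    print(ratio, kept_tokens(ratio))    # -> 128, 64, 32
```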
Lines changed: 25 additions & 0 deletions (new MustDrop config)

base:
    seed: &seed 42
model:
    type: Llava
    path: model path
    torch_dtype: auto
eval:
    eval_pos: [pretrain, transformed]
    type: vqa
    name: [mme]
    download: False
    path: MME dataset path
    bs: 1
sparse:
    vision:
        method: TokenReduction
        special:
            method: MustDrop
            spatial_threshold: 0.6
            window_size: [3, 3]
            retained_tokens: 128 # llava_next: 128, 64, 32 llava: 192, 128, 64
save:
    save_trans: False
    save_fake: False
    save_path: /path/to/save/
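`retained_tokens` and `window_size` interact in the `spatial_merge_hook` changed later in this commit: it computes `fix_r = (vtoken_length - retained_tokens) // (window_size[0] * window_size[1] - 1)`, roughly the number of full 3x3 window merges (each removing up to eight tokens) needed to reach the retained budget. A small sketch of that relation, assuming an illustrative 576-token input:

```python
# Sketch only: the fix_r relation used by the spatial-merge hook.
# vtoken_length=576 is an assumed example; the hook reads it from the hidden states.
def fix_r(vtoken_length: int, retained_tokens: int, window_size=(3, 3)) -> int:
    return (vtoken_length - retained_tokens) // (window_size[0] * window_size[1] - 1)

print(fix_r(576, 128))  # -> 56
print(fix_r(576, 64))   # -> 64
print(fix_r(576, 32))   # -> 68
```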
Lines changed: 31 additions & 55 deletions (DivPrune token-reduction module)

@@ -1,4 +1,3 @@
-import functools
 from functools import wraps
 from types import MethodType

@@ -7,7 +6,6 @@
 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

 from .token_reduction_module import TokenReductionModule
-from .utils import prefill_wrapper


 def pairwise_cosine_similarity(matrix):
@@ -22,7 +20,7 @@ def divprune(
     cosine_matrix=None,
     threshold_ratio=0.1,
 ):
-    threshold_terms = int(round(threshold_ratio * image_feature_length))
+    threshold_terms = round(threshold_ratio * image_feature_length)
     if cosine_matrix is None:
         cosine_matrix = 1.0 - (pairwise_cosine_similarity(visual_feature_vectors))

@@ -53,22 +51,16 @@ def divprune(
     return s, cosine_matrix


-def divprune_post_hook(
-    input_ids,
-    position_ids,
-    attention_mask,
-    past_key_values,
-    inputs_embeds,
-    labels,
-    pruning_paras=None,
-):
-    rate = pruning_paras['rate']
-    SYS_TOKEN_LEN = pruning_paras['image_token_start_index']
-    img_feature_len = pruning_paras['image_token_length']
+def divprune_post_hook(*args, pruning_paras=None):
+    args = list(args)
+    position_ids, attention_mask, inputs_embeds = args[1], args[2], args[4]
+    rate = pruning_paras['reduction_ratio']
+    SYS_TOKEN_LEN = pruning_paras['vision_token_start_index']
+    img_feature_len = pruning_paras['vision_token_length']
     device = inputs_embeds.device
     visual_tokens = inputs_embeds[0][SYS_TOKEN_LEN: SYS_TOKEN_LEN + img_feature_len]
     selected_visual_tokens, cosine_matrix = divprune(
-        visual_tokens, img_feature_len, None, threshold_ratio=rate
+        visual_tokens, img_feature_len, None, threshold_ratio=1 - rate
     )

     selected_visual_tokens += SYS_TOKEN_LEN
@@ -83,20 +75,13 @@ def divprune_post_hook(
     )
     keep_indexs = keep_indexs.sort().values

-    inputs_embeds = inputs_embeds[:, keep_indexs]
     if position_ids is not None:
-        position_ids = position_ids[:, keep_indexs, :]
+        args[1] = position_ids[:, keep_indexs, :]
     if attention_mask is not None:
-        attention_mask = attention_mask[:, keep_indexs]
-
-    return (
-        input_ids,
-        position_ids,
-        attention_mask,
-        past_key_values,
-        inputs_embeds,
-        labels,
-    )
+        args[2] = attention_mask[:, keep_indexs]
+    args[4] = inputs_embeds[:, keep_indexs]
+
+    return tuple(args)


 @TOKEN_REDUCTION_REGISTRY.register('DivPrune')
@@ -107,43 +92,34 @@ def __init__(self, config, model, blocks):
         self.register_reduction_modules()

     def add_sparse_config(self):
-        self.special_config['image_token_length'] = self.model.pruning_config[
-            'image_token_length'
-        ]
-
         self.pruning_paras = self.special_config

     def register_reduction_modules(self):

-        def input_hook_llava(fn, pruning_paras):
+        def input_hook_llava(fn, pruning_paras, llava_next):
             @wraps(fn)
             def wrapper(self, *args, **kwargs):
-                if len(args) == 0:
-                    return fn(*args, **kwargs)
-                input_args = args[0]
-                if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
+                if args[0].shape[1] == 1:
                     return fn(*args, **kwargs)
-
-                input_ids = args[0]
-                attention_mask = args[2]
-                token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX
-                pruning_paras['image_token_start_index'] = torch.where(token_indices)[
-                    0
-                ][0].item()
-
-                outputs = fn(*args, **kwargs)
-
-                return divprune_post_hook(*outputs, pruning_paras=pruning_paras)
-
+                outs = fn(*args, **kwargs)
+
+                if llava_next:
+                    message = (
+                        'To obtain the vision_token_length for LLaVA-1.6, you should append '
+                        '`image_features[0].shape[0]` to the return value of the function '
+                        '`prepare_inputs_labels_for_multimodal`, and modify the related code.'
+                    )
+                    assert len(outs) == 7, message
+                    pruning_paras['vision_token_length'] = outs[-1]
+                return divprune_post_hook(*outs, pruning_paras=pruning_paras)
             return wrapper

         if self.model.__class__.__name__ == 'Llava':
-            from llava.constants import IMAGE_TOKEN_INDEX

-            hook_fn = input_hook_llava(
-                self.model.vlm_model.prepare_inputs_labels_for_multimodal,
-                self.pruning_paras,
-            )
             self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
-                hook_fn, self.model.vlm_model
+                input_hook_llava(
+                    self.model.vlm_model.prepare_inputs_labels_for_multimodal,
+                    self.pruning_paras,
+                    llava_next=self.special_config['vision_token_length'] is None
+                ), self.model.vlm_model
             )
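The refactored `divprune_post_hook` above no longer takes named parameters; it relies on the positional layout of `prepare_inputs_labels_for_multimodal`'s return tuple, `(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels)` (plus a trailing `vision_token_length` element on the patched LLaVA-NeXT path), and rewrites entries 1, 2, and 4 in place. A self-contained sketch of that pattern, with placeholder tensors standing in for real model outputs:

```python
# Sketch only: pruning positions 1 (position_ids), 2 (attention_mask) and
# 4 (inputs_embeds) of a tuple shaped like prepare_inputs_labels_for_multimodal's
# output. All tensors below are placeholders, not real LLaVA outputs.
import torch

outs = (
    None,                         # [0] input_ids (None once embeddings are built)
    torch.arange(700)[None],      # [1] position_ids
    torch.ones(1, 700).bool(),    # [2] attention_mask
    None,                         # [3] past_key_values
    torch.randn(1, 700, 4096),    # [4] inputs_embeds
    None,                         # [5] labels
)

def toy_post_hook(*args, keep=None):
    args = list(args)
    if args[1] is not None:
        args[1] = args[1][:, keep]
    if args[2] is not None:
        args[2] = args[2][:, keep]
    args[4] = args[4][:, keep]
    return tuple(args)

keep = torch.arange(0, 700, 2)    # toy keep-index set (every other token)
pruned = toy_post_hook(*outs, keep=keep)
print(pruned[4].shape)            # torch.Size([1, 350, 4096])
```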

llmc/compression/token_reduction/mustdrop.py

Lines changed: 49 additions & 20 deletions

@@ -1,10 +1,16 @@
 import functools
+import math
+from types import MethodType
+from typing import Callable, Tuple

 import torch
+import torch.nn.functional as F
+from einops import rearrange

 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

 from .token_reduction_module import TokenReductionModule
+from .utils import prepare_inputs_labels_for_multimodal_with_index_masks


 @TOKEN_REDUCTION_REGISTRY.register('MustDrop')
@@ -15,18 +21,11 @@ def __init__(self, config, model, blocks):
         self.register_reduction_modules()

     def add_sparse_config(self):
-        self.pruning_loc = self.special_config['pruning_loc']
+        self.pruning_loc = self.model.pruning_config.get('select_layer', -1)
         self.pruning_paras = self.special_config

     def register_reduction_modules(self):

-        import math
-        from typing import Callable, Tuple
-
-        import numpy as np
-        import torch.nn.functional as F
-        from einops import rearrange
-
         def conditional_pooling(
             feat: torch.Tensor,
             threshold: float,
@@ -170,7 +169,14 @@ def merge(x: torch.Tensor, mode='mean') -> torch.Tensor:
                 )
                 x = torch.cat([dst, unm], dim=1)
                 x = torch.cat((x_cls, x), dim=1)
-                return x
+
+                index_masks = torch.zeros((n, t1), dtype=torch.bool, device=x_feat.device)
+                dst_flat = dst_idx.view(n, -1)
+                unm_flat = unm_idx.view(n, -1)
+                index_masks.scatter_(1, dst_flat, True)
+                index_masks.scatter_(1, unm_flat, True)
+
+                return x, index_masks

             return merge

@@ -181,26 +187,49 @@ def merge_wavg(
             if size is None:
                 size = torch.ones_like(x[..., 0, None])

-            x = merge(x * size, mode='sum')
-            size = merge(size, mode='sum')
+            x, index_masks = merge(x * size, mode='sum')
+            size, _ = merge(size, mode='sum')
             x = x / size

-            return x, size
+            return x, size, index_masks

-        def spatial_merge_hook(module, args, kwargs, layer_outs, pruning_paras):
+        def spatial_merge_hook(module, inps, outs, pruning_paras, llava_next):
             spatial_threshold = pruning_paras['spatial_threshold']
             window_size = pruning_paras['window_size']
-            hidden_states = layer_outs[0]
+            hidden_states = outs[0]
+            vtoken_length = hidden_states.shape[1]
             fix_r = 0
             if pruning_paras.get('retained_tokens', None) is not None:
                 retained_tokens = pruning_paras['retained_tokens']
-                fix_r = (pruning_paras['vision_token_length'] - retained_tokens) \
+                fix_r = (vtoken_length - retained_tokens) \
                     // (window_size[0] * window_size[1] - 1)
             merge = conditional_pooling(hidden_states, spatial_threshold, window_size, fix_r)
-            hidden_states, size = merge_wavg(merge, hidden_states, None)
-            return (hidden_states,)
+            hidden_states, size, index_masks = merge_wavg(merge, hidden_states, None)
+
+            if not llava_next:
+                return (hidden_states,)

-        self.blocks[self.pruning_loc - 1].register_forward_hook(
-            functools.partial(spatial_merge_hook, pruning_paras=self.pruning_paras),
-            with_kwargs=True,
+            pruning_paras['index_masks'] = index_masks
+            return outs
+
+        def update_index_masks_hook(module, inps, outs, pruning_paras):
+            module.index_masks = pruning_paras['index_masks']
+
+        self.blocks[self.pruning_loc].register_forward_hook(
+            functools.partial(
+                spatial_merge_hook,
+                pruning_paras=self.pruning_paras,
+                llava_next=self.special_config['vision_token_length'] is None
+            ),
         )
+
+        if self.special_config['vision_token_length'] is None:
+
+            self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
+                prepare_inputs_labels_for_multimodal_with_index_masks,
+                self.model.vlm_model
+            )
+
+            self.model.vision_model.register_forward_hook(
+                functools.partial(update_index_masks_hook, pruning_paras=self.pruning_paras),
+            )
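The `index_masks` added to `merge()` record which original token positions survive as merge destinations or unmerged tokens; on the LLaVA-NeXT path they are stashed on the vision model so that `prepare_inputs_labels_for_multimodal_with_index_masks` can presumably select the matching vision features. A toy illustration of the scatter pattern, with made-up indices in place of `dst_idx`/`unm_idx`:

```python
# Sketch only: building a boolean keep-mask with scatter_, mirroring the
# index_masks bookkeeping added in merge(). Shapes and indices are made up.
import torch

n, t1 = 1, 12                               # batch size, tokens before merging
dst_idx = torch.tensor([[0, 3, 6, 9]])      # tokens kept as merge destinations
unm_idx = torch.tensor([[1, 4, 7]])         # tokens left unmerged

index_masks = torch.zeros((n, t1), dtype=torch.bool)
index_masks.scatter_(1, dst_idx.view(n, -1), True)
index_masks.scatter_(1, unm_idx.view(n, -1), True)

print(index_masks.int().tolist())
# [[1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0]] -> 7 of 12 positions survive
```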
