
Commit c038265
update vispruner (#425)
1 parent aa80886

9 files changed: +298 −31 lines
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+base:
+    seed: &seed 42
+model:
+    type: Llava
+    path: model path
+    torch_dtype: auto
+eval:
+
+    eval_pos: [pretrain, transformed]
+    type: vqa
+    name: [mme]
+    download: False
+    path: MME dataset path
+    bs: 1
+    inference_per_block: False
+sparse:
+    vision:
+        method: TokenReduction
+        special:
+            method: VisPruner
+            prune_ratio: 0.778 # 0.667 0.778 0.889
+            important_ratio: 0.5
+save:
+    save_trans: False
+    save_fake: False
+    save_path: /path/to/save/
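For orientation, the VisPruner module added below computes its keep budget as round(vision_token_length * (1 - prune_ratio)) and then splits it with important_ratio. A small sketch of that arithmetic, assuming the 576 visual tokens of a LLaVA-1.5 vision tower (the model path in this config is only a placeholder):

# Token-budget arithmetic implied by the config above; 576 visual tokens is an
# assumption about the Llava model being pointed to, not something the config states.
vision_token_length = 576
important_ratio = 0.5
for prune_ratio in (0.667, 0.778, 0.889):
    kept = round(vision_token_length * (1 - prune_ratio))
    important = int(kept * important_ratio)
    diverse = kept - important
    print(f'{prune_ratio}: keep {kept} ({important} important + {diverse} diverse)')
# 0.667: keep 192 (96 important + 96 diverse)
# 0.778: keep 128 (64 important + 64 diverse)
# 0.889: keep 64 (32 important + 32 diverse)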

llmc/compression/token_reduction/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -13,3 +13,5 @@
 from .sparsevlm import SparseVLM
 from .tome import ToMe
 from .visionzip import VisionZip
+from .vispruner import VisPruner
+from .visualizer import Visualizer
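These two imports are what make the new methods selectable from a config: the @TOKEN_REDUCTION_REGISTRY.register('VisPruner') decorator in the new module runs at import time and maps the string used under special.method to the class. A minimal sketch of a decorator-based registry of this kind, assuming a plain dict design rather than llmc's actual implementation:

# Hypothetical registry sketch; llmc's TOKEN_REDUCTION_REGISTRY may differ in detail.
class Registry:
    def __init__(self):
        self._classes = {}

    def register(self, name):
        def decorator(cls):
            self._classes[name] = cls  # executed when the module is imported
            return cls
        return decorator

    def build(self, name, *args, **kwargs):
        return self._classes[name](*args, **kwargs)


TOKEN_REDUCTION_REGISTRY = Registry()


@TOKEN_REDUCTION_REGISTRY.register('VisPruner')
class DummyPruner:  # stand-in for the real class
    pass


print(type(TOKEN_REDUCTION_REGISTRY.build('VisPruner')).__name__)  # DummyPruner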

llmc/compression/token_reduction/visionzip.py

Lines changed: 16 additions & 5 deletions
@@ -454,7 +454,7 @@ def wrapper(self, *args, **kwargs):

 def merger_hook(module, inputs, kwargs, layer_outs, pruning_paras):
     with torch.no_grad():
-        attn_mean = pruning_paras['attn_logits'].mean(dim=0)
+        attn_mean = pruning_paras['attn_logits'].mean(dim=0)  # 16 1120, 1120 -> 1120, 1120
         attn_key = pruning_paras['attn_key']

         window_index, _ = module.get_window_index(kwargs['grid_thw'])

@@ -539,10 +539,21 @@ def prune_qwenv25vl_hook(module, args, kwargs, pruning_paras):
     st_idx = torch.nonzero(img_mask, as_tuple=True)[0]

     if st_idx.numel() > 0:
-        first, last = st_idx[0].item(), st_idx[-1].item()
-        img_mask[first: last + 1] = ~select_mask
+        discontinuities = torch.where(st_idx[1:] - st_idx[:-1] != 1)[0]
+        if discontinuities.numel() > 0:
+            raise ValueError('Visual tokens are not contiguous in input_ids!')
+        segment_starts = [st_idx[0].item()] + [st_idx[i + 1].item() for i in discontinuities.tolist()]  # noqa
+        segment_ends = [st_idx[i].item() for i in discontinuities.tolist()] + [st_idx[-1].item()]  # noqa
+        offset = 0
+        for first, last in zip(segment_starts, segment_ends):
+            length = last - first + 1
+            # [15 1502] [1505 3289]
+            img_mask[first: last + 1] = ~select_mask[offset: offset + length]
+    else:
+        first, last = st_idx[0].item(), st_idx[-1].item()
+        img_mask[first: last + 1] = ~select_mask
     img_mask = ~img_mask
-    contexual_input_idx = false_pos[target_indices] + first
+    contextual_input_idx = false_pos[target_indices] + first

     hidden_states_filtered = inputs_embeds[:, first: last + 1][:, contextual_mask]
     hidden_to_merge = hidden_states_filtered[

@@ -562,7 +573,7 @@ def prune_qwenv25vl_hook(module, args, kwargs, pruning_paras):

     kwargs['position_ids'] = position_ids[:, :, img_mask]
     kwargs['attention_mask'] = attention_mask[:, img_mask]
-    inputs_embeds[:, contexual_input_idx] = contextual_tokens
+    inputs_embeds[:, contextual_input_idx] = contextual_tokens
     kwargs['inputs_embeds'] = inputs_embeds[:, img_mask]
     del contextual_tokens, hidden_states_filtered, hidden_to_merge, aggregated_hidden
     torch.cuda.empty_cache()
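The rewritten branch in prune_qwenv25vl_hook splits the visual-token positions into contiguous runs before applying the keep/drop mask, instead of assuming one [first, last] span; as written, a detected gap raises, so the segment loop effectively guards the single-run case. A standalone sketch of the discontinuity test on a dummy index tensor (the values are illustrative, not from the repository):

import torch

# Positions of visual tokens in input_ids: two contiguous runs, with a gap after 6.
st_idx = torch.tensor([3, 4, 5, 6, 9, 10, 11])

# A gap exists wherever consecutive positions differ by more than 1.
discontinuities = torch.where(st_idx[1:] - st_idx[:-1] != 1)[0]  # tensor([3])

segment_starts = [st_idx[0].item()] + [st_idx[i + 1].item() for i in discontinuities.tolist()]
segment_ends = [st_idx[i].item() for i in discontinuities.tolist()] + [st_idx[-1].item()]

print(list(zip(segment_starts, segment_ends)))  # [(3, 6), (9, 11)]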
Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
+import functools
+
+import torch
+
+from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
+
+from .token_reduction_module import TokenReductionModule
+
+
+@TOKEN_REDUCTION_REGISTRY.register('VisPruner')
+class VisPruner(TokenReductionModule):
+    def __init__(self, config, model, blocks):
+        super().__init__(config, model, blocks)
+        self.add_sparse_config()
+        self.register_reduction_modules()
+
+    def add_sparse_config(self):
+        self.special_config['select_layer'] = self.model.pruning_config.get(
+            'select_layer', -1
+        )
+        self.special_config['select_feature'] = self.model.pruning_config.get(
+            'select_feature', None
+        )
+
+        self.pruning_paras = self.special_config
+
+    def register_reduction_modules(self):
+
+        def update_output_attentions_hook(module, args, kwargs):
+            kwargs['output_attentions'] = True
+
+        def store_attention_hook(module, inps, outs, pruning_paras):
+            image_attentions = outs.attentions[pruning_paras['select_layer']]
+            if pruning_paras['select_feature'] == 'patch':
+                image_attentions = image_attentions[:, :, 0, 1:]
+            elif pruning_paras['select_feature'] == 'cls_patch':
+                image_attentions = image_attentions
+                raise ValueError(f'Unexpected select feature: {self.select_feature}')
+
+            pruning_paras['image_attentions'] = image_attentions.to(inps[0].dtype)
+
+        def get_index_masks_hook(module, args, pruning_paras):
+            image_features = args[0]
+            image_attentions = pruning_paras['image_attentions']
+
+            B, N, C = image_features.shape
+            device = image_features.device
+            index_masks = torch.ones(B, N, dtype=torch.bool, device=device)
+
+            visual_token_num = round(
+                self.special_config['vision_token_length'] * (
+                    1 - self.special_config['prune_ratio']
+                )
+            )  # T
+            important_ratio = self.pruning_paras['important_ratio']  # r
+            important_token_num = int(visual_token_num * important_ratio)  # T_imp = T * r
+            diverse_token_num = visual_token_num - important_token_num  # T_div = T * (1 - r)
+
+            # [VisPruner] Select important tokens using attention scores
+            image_attentions = image_attentions.mean(dim=1)  # (B, N)
+            token_indices = image_attentions.argsort(dim=-1, descending=True)  # (B, N)
+            important_indices = token_indices[:, :important_token_num]  # (B, T_imp)
+            residual_indices = token_indices[:, important_token_num:]  # (B, N - T_imp)
+
+            # [VisPruner] Remove duplicate tokens by iterative matching and pruning
+            image_normalized = image_features / image_features.norm(dim=-1, keepdim=True)
+            while diverse_token_num > 0:
+                R = residual_indices.shape[1]
+                r = min(8, R - diverse_token_num)
+                if r <= 0:
+                    break
+
+                residual_tokens = image_normalized[
+                    torch.arange(B).unsqueeze(-1).expand(-1, R),
+                    residual_indices
+                ]  # (B, R, C)
+                a, b = residual_tokens[..., ::2, :], residual_tokens[..., 1::2, :]  # (B, R // 2, C)
+                scores = a @ b.transpose(-1, -2)  # (B, R // 2, R // 2)
+                scores = scores.max(dim=-1).values  # (B, R // 2)
+
+                distinct_indices = scores.argsort(dim=-1, descending=True)[:, r:]  # (B, R // 2 - r)
+                residual_indices = torch.cat([
+                    residual_indices[..., ::2][
+                        torch.arange(B).unsqueeze(-1).expand(-1, R // 2 - r),
+                        distinct_indices
+                    ],
+                    residual_indices[..., 1::2]
+                ], dim=-1)  # (B, R - r)
+
+            if diverse_token_num > 0:
+                selected_indices = torch.cat([important_indices, residual_indices], dim=-1)
+            else:
+                selected_indices = important_indices  # (B, T)
+            index_masks = torch.zeros(B, N, dtype=torch.bool, device=device)
+            index_masks.scatter_(1, selected_indices, True)
+
+            pruning_paras['index_masks'] = index_masks
+
+        def prune_hook(module, inputs, outputs, pruning_paras):
+            image_features = outputs
+            index_masks = pruning_paras['index_masks']
+            return image_features[index_masks].unsqueeze(0)
+
+        self.model.vision_model.vision_tower.register_forward_pre_hook(
+            update_output_attentions_hook,
+            with_kwargs=True
+        )
+
+        self.model.vision_model.vision_tower.register_forward_hook(
+            functools.partial(store_attention_hook, pruning_paras=self.pruning_paras),
+        )
+
+        self.model.vision_projector.register_forward_pre_hook(
+            functools.partial(get_index_masks_hook, pruning_paras=self.pruning_paras),
+        )
+
+        self.model.vision_projector.register_forward_hook(
+            functools.partial(prune_hook, pruning_paras=self.pruning_paras),
+        )
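To make the two-stage selection in get_index_masks_hook concrete, here is a standalone sketch on random tensors. It is lightly restructured (the while condition replaces the break), and the shapes are assumptions: one image, 576 visual tokens, a 64-dim feature, 16 attention heads; prune_ratio and important_ratio come from the config above.

import torch

B, N, C = 1, 576, 64                      # batch, visual tokens, feature dim (dummy values)
prune_ratio, important_ratio = 0.778, 0.5

image_features = torch.randn(B, N, C)
image_attentions = torch.rand(B, 16, N)   # dummy per-head [CLS]-to-patch attention

visual_token_num = round(N * (1 - prune_ratio))                  # T = 128
important_token_num = int(visual_token_num * important_ratio)    # T_imp = 64
diverse_token_num = visual_token_num - important_token_num       # T_div = 64

# Stage 1: keep the tokens ranked highest by mean attention.
attn = image_attentions.mean(dim=1)                              # (B, N)
order = attn.argsort(dim=-1, descending=True)
important_indices = order[:, :important_token_num]
residual_indices = order[:, important_token_num:]

# Stage 2: thin the remainder by repeatedly dropping the tokens most similar
# to another residual token, until T_div diverse tokens are left.
feats = image_features / image_features.norm(dim=-1, keepdim=True)
while residual_indices.shape[1] > diverse_token_num:
    R = residual_indices.shape[1]
    r = min(8, R - diverse_token_num)
    tokens = feats[torch.arange(B).unsqueeze(-1).expand(-1, R), residual_indices]
    a, b = tokens[..., ::2, :], tokens[..., 1::2, :]
    sim = (a @ b.transpose(-1, -2)).max(dim=-1).values           # (B, R // 2)
    keep = sim.argsort(dim=-1, descending=True)[:, r:]           # drop the r most redundant
    residual_indices = torch.cat([
        residual_indices[..., ::2].gather(-1, keep),
        residual_indices[..., 1::2],
    ], dim=-1)

selected = torch.cat([important_indices, residual_indices], dim=-1)      # (B, 128)
mask = torch.zeros(B, N, dtype=torch.bool).scatter_(1, selected, True)
print(mask.sum(dim=1))  # tensor([128])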
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+import functools
+
+from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
+from llmc.utils.visualizer import (visualize_grid_to_grid, visualize_heads,
+                                   visualize_kept_patches)
+
+from .token_reduction_module import TokenReductionModule
+from .utils import prefill_wrapper
+
+
+@TOKEN_REDUCTION_REGISTRY.register('Visualizer')
+class Visualizer(TokenReductionModule):
+    def __init__(self, config, model, blocks):
+        super().__init__(config, model, blocks)
+        self.add_sparse_config()
+        self.register_reduction_modules()
+
+    def add_sparse_config(self):
+        self.pruning_paras = self.special_config
+        self.pruning_paras['attentions'] = []
+
+    def register_reduction_modules(self):
+
+        @prefill_wrapper
+        def update_attentions_hook(module, args, kwargs):
+            kwargs['output_attentions'] = True
+            return args, kwargs
+
+        @prefill_wrapper
+        def get_images_hook(module, input_args, pruning_paras):
+            pruning_paras['images'] = input_args[0]
+            return input_args
+
+        @prefill_wrapper
+        def get_attentions_hook(module, inps, layer_outs, pruning_paras):
+            pruning_paras['attentions'].append(layer_outs[1])
+            return layer_outs
+
+        @prefill_wrapper
+        def visualizer_hook(module, inps, layer_outs, pruning_paras):
+            attention_maps = pruning_paras['attentions'][0]
+            visual_attention_maps = attention_maps[:, :, 35: 35 + 576, 35: 35 + 576]
+            image = pruning_paras['images'][0]
+
+            visualize_heads(
+                visual_attention_maps[:, :6],
+                cols=4,
+                save_path=''
+            )
+            visualize_grid_to_grid(
+                visual_attention_maps[0, 4, :, :],
+                300,
+                image,
+                grid_size=24,
+                save_path=''
+            )
+            visualize_kept_patches(
+                pruning_paras['images'][0],
+                pruning_paras['visual_keep_indexs'],
+                save_path='',
+            )
+            return layer_outs
+
+        self.model.vision_model.register_forward_pre_hook(
+            functools.partial(get_images_hook, pruning_paras=self.pruning_paras),
+        )
+
+        for idx, blk in enumerate(self.blocks):
+            if idx == 5:
+                blk.register_forward_pre_hook(update_attentions_hook, with_kwargs=True)
+                blk.register_forward_hook(
+                    functools.partial(get_attentions_hook, pruning_paras=self.pruning_paras),
+                )
+            if idx == (len(self.blocks) - 1):
+                blk.register_forward_hook(
+                    functools.partial(visualizer_hook, pruning_paras=self.pruning_paras),
+                )
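Both new modules do their work entirely through PyTorch forward hooks rather than by editing model code: a pre-hook forces output_attentions on, forward hooks capture or replace module outputs, and functools.partial threads the shared pruning_paras dict into each hook. A minimal sketch of that pattern on a toy module (the Toy class and its scale argument are assumptions for illustration; with_kwargs=True needs PyTorch 2.0 or newer):

import functools

import torch
import torch.nn as nn


class Toy(nn.Module):
    def forward(self, x, scale=1.0):
        return x * scale


def force_kwarg_hook(module, args, kwargs):
    # pre-hook registered with with_kwargs=True may rewrite kwargs before forward runs
    kwargs['scale'] = 2.0
    return args, kwargs


def capture_output_hook(module, inputs, output, store):
    # forward hook observes (or could replace) the module output
    store['out'] = output.detach()


captured = {}
toy = Toy()
toy.register_forward_pre_hook(force_kwarg_hook, with_kwargs=True)
toy.register_forward_hook(functools.partial(capture_output_hook, store=captured))

print(toy(torch.ones(3)), captured['out'])  # tensor([2., 2., 2.]) twice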

llmc/eval/eval_vqa.py

Lines changed: 4 additions & 0 deletions
@@ -89,6 +89,10 @@ def eval(
         datetime_str: str = get_datetime_str(),
         cli_args=None,
     ):
+        import argparse
+        cli_args = argparse.Namespace(
+            process_with_media=True,
+        )

         model = llmc_model.eval_name
         model_args = 'pretrained=' + self.model_path + ',device_map=auto'

llmc/models/llava.py

Lines changed: 3 additions & 8 deletions
@@ -39,15 +39,10 @@ def build_tokenizer(self):
         pass

     def build_model(self):
-        self.llava_config = LlavaConfig.from_pretrained(
-            self.model_path, trust_remote_code=True
-        )
+
         self.vlm_model_config = AutoConfig.from_pretrained(
             self.model_path, trust_remote_code=True
         )
-        # llava need: use_cache
-        self.llava_config.use_cache = True
-        self.vlm_model_config.use_cache = True
         logger.info(f'self.vlm_model_config : {self.vlm_model_config}')

         self.tokenizer, self.vlm_model, self.image_processor, context_len = load_pretrained_model(

@@ -70,8 +65,8 @@ def build_model(self):
         self.pruning_config = {
             'is_video_model': False,
             'image_token_length': self.vlm_model_config.image_seq_length,
-            'select_layer': self.vlm_model_config.vision_feature_layer,
-            'select_feature': self.vlm_model_config.vision_feature_select_strategy,
+            'select_layer': self.vision_model.select_layer,
+            'select_feature': self.vision_model.select_feature,
             'image_token_index': IMAGE_TOKEN_INDEX,
             'IMAGE_TOKEN_INDEX': IMAGE_TOKEN_INDEX,  # for llava
             'vision_token_start_index': 35,

llmc/models/qwen2_5vl.py

Lines changed: 1 addition & 2 deletions
@@ -62,8 +62,6 @@ def build_model(self):

         self.min_pixels = 256 * 28 * 28
         self.max_pixels = 1280 * 28 * 28
-        logger.warning(f'min_pixels is set to: {self.min_pixels}')
-        logger.warning(f'max_pixels is set to: {self.max_pixels}')
         self.processor = AutoProcessor.from_pretrained(
             self.model_path,
             min_pixels=self.min_pixels,

@@ -76,6 +74,7 @@ def build_model(self):
             'vision_start_token_id': self.vlm_model_config.vision_start_token_id,
             'vision_token_start_index': 15
         }
+        self.first_turn_question = True

     # todo: check
     def get_subsets_in_block(self, block):
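The retained min_pixels / max_pixels values bound the visual-token count per image: Qwen2.5-VL's processor yields roughly one visual token per 28x28-pixel patch, so these settings correspond to about 256 to 1280 image tokens (a back-of-the-envelope check, not something the diff states):

# Rough token-budget check for the pixel bounds kept above
# (assumes ~one visual token per 28x28 patch, as in the Qwen2-VL / Qwen2.5-VL processors).
patch_area = 28 * 28
min_pixels = 256 * patch_area    # 200704
max_pixels = 1280 * patch_area   # 1003520
print(min_pixels // patch_area, max_pixels // patch_area)  # 256 1280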
