
Commit f74e2bb

update vispruner
1 parent aa80886 commit f74e2bb

9 files changed: +278 additions, −29 deletions

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
base:
    seed: &seed 42
model:
    type: Llava
    path: model path
    torch_dtype: auto
eval:

    eval_pos: [pretrain, transformed]
    type: vqa
    name: [mme]
    download: False
    path: MME dataset path
    bs: 1
    inference_per_block: False
sparse:
    vision:
        method: TokenReduction
        special:
            method: VisPruner
            prune_ratio: 0.778 # 0.667 0.778 0.889
            important_ratio: 0.5
save:
    save_trans: False
    save_fake: False
    save_path: /path/to/save/

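For orientation, here is a rough sketch (not part of the commit) of how prune_ratio and important_ratio translate into token counts, assuming LLaVA-1.5's 576 visual tokens as vision_token_length:

# Illustrative arithmetic only; 576 is LLaVA-1.5's visual token count (24 x 24 patches).
vision_token_length = 576
prune_ratio = 0.778        # from the config above
important_ratio = 0.5

kept = round(vision_token_length * (1 - prune_ratio))   # T: ~128 tokens survive
important = int(kept * important_ratio)                  # T_imp = 64, picked by attention
diverse = kept - important                               # T_div = 64, picked for diversity
print(kept, important, diverse)                          # 128 64 64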
llmc/compression/token_reduction/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -13,3 +13,5 @@
 from .sparsevlm import SparseVLM
 from .tome import ToMe
 from .visionzip import VisionZip
+from .vispruner import VisPruner
+from .visualizer import Visualizer
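Both new modules register themselves with llmc's TOKEN_REDUCTION_REGISTRY via a class decorator. As a self-contained illustration of that pattern (a toy registry written for this note, not llmc's actual implementation), the string passed to register() is what the config's sparse.vision.special.method field later resolves to:

# Toy sketch of the registration pattern; llmc's real registry differs in detail.
class Registry(dict):
    def register(self, name):
        def decorator(cls):
            self[name] = cls          # map the config-facing name to the class
            return cls
        return decorator

TOKEN_REDUCTION_REGISTRY = Registry()

@TOKEN_REDUCTION_REGISTRY.register('VisPruner')
class VisPruner:
    pass

print(TOKEN_REDUCTION_REGISTRY['VisPruner'])   # <class '__main__.VisPruner'>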

llmc/compression/token_reduction/visionzip.py

Lines changed: 16 additions & 5 deletions
@@ -454,7 +454,7 @@ def wrapper(self, *args, **kwargs):

 def merger_hook(module, inputs, kwargs, layer_outs, pruning_paras):
     with torch.no_grad():
-        attn_mean = pruning_paras['attn_logits'].mean(dim=0)
+        attn_mean = pruning_paras['attn_logits'].mean(dim=0)  # (16, 1120, 1120) -> (1120, 1120)
         attn_key = pruning_paras['attn_key']

         window_index, _ = module.get_window_index(kwargs['grid_thw'])
@@ -539,10 +539,21 @@ def prune_qwenv25vl_hook(module, args, kwargs, pruning_paras):
     st_idx = torch.nonzero(img_mask, as_tuple=True)[0]

     if st_idx.numel() > 0:
-        first, last = st_idx[0].item(), st_idx[-1].item()
-        img_mask[first: last + 1] = ~select_mask
+        discontinuities = torch.where(st_idx[1:] - st_idx[:-1] != 1)[0]
+        if discontinuities.numel() > 0:
+            raise ValueError('Visual tokens are not contiguous in input_ids!')
+            segment_starts = [st_idx[0].item()] + [st_idx[i + 1].item() for i in discontinuities.tolist()]  # noqa
+            segment_ends = [st_idx[i].item() for i in discontinuities.tolist()] + [st_idx[-1].item()]  # noqa
+            offset = 0
+            for first, last in zip(segment_starts, segment_ends):
+                length = last - first + 1
+                # e.g. multi-image segments [15, 1502] and [1505, 3289]
+                img_mask[first: last + 1] = ~select_mask[offset: offset + length]
+        else:
+            first, last = st_idx[0].item(), st_idx[-1].item()
+            img_mask[first: last + 1] = ~select_mask
     img_mask = ~img_mask
-    contexual_input_idx = false_pos[target_indices] + first
+    contextual_input_idx = false_pos[target_indices] + first

     hidden_states_filtered = inputs_embeds[:, first: last + 1][:, contextual_mask]
     hidden_to_merge = hidden_states_filtered[
@@ -562,7 +573,7 @@ def prune_qwenv25vl_hook(module, args, kwargs, pruning_paras):

     kwargs['position_ids'] = position_ids[:, :, img_mask]
     kwargs['attention_mask'] = attention_mask[:, img_mask]
-    inputs_embeds[:, contexual_input_idx] = contextual_tokens
+    inputs_embeds[:, contextual_input_idx] = contextual_tokens
     kwargs['inputs_embeds'] = inputs_embeds[:, img_mask]
     del contextual_tokens, hidden_states_filtered, hidden_to_merge, aggregated_hidden
     torch.cuda.empty_cache()
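The new branch in prune_qwenv25vl_hook splits a possibly non-contiguous run of visual-token positions into contiguous segments before masking. A minimal, standalone sketch of just that splitting step (toy tensors, not llmc code):

import torch

# Toy visual-token positions in input_ids: two contiguous segments.
st_idx = torch.tensor([15, 16, 17, 40, 41])

discontinuities = torch.where(st_idx[1:] - st_idx[:-1] != 1)[0]   # tensor([2])
segment_starts = [st_idx[0].item()] + [st_idx[i + 1].item() for i in discontinuities.tolist()]
segment_ends = [st_idx[i].item() for i in discontinuities.tolist()] + [st_idx[-1].item()]
print(list(zip(segment_starts, segment_ends)))                    # [(15, 17), (40, 41)]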
llmc/compression/token_reduction/vispruner.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
import functools
import torch
from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
from .token_reduction_module import TokenReductionModule


@TOKEN_REDUCTION_REGISTRY.register('VisPruner')
class VisPruner(TokenReductionModule):
    def __init__(self, config, model, blocks):
        super().__init__(config, model, blocks)
        self.add_sparse_config()
        self.register_reduction_modules()

    def add_sparse_config(self):
        self.special_config['select_layer'] = self.model.pruning_config.get('select_layer', -1)
        self.special_config['select_feature'] = self.model.pruning_config.get('select_feature', None)

        self.pruning_paras = self.special_config

    def register_reduction_modules(self):

        def update_output_attentions_hook(module, args, kwargs):
            kwargs['output_attentions'] = True

        def store_attention_hook(module, inps, outs, pruning_paras):
            image_attentions = outs.attentions[pruning_paras['select_layer']]
            if pruning_paras['select_feature'] == 'patch':
                image_attentions = image_attentions[:, :, 0, 1:]
            elif pruning_paras['select_feature'] == 'cls_patch':
                image_attentions = image_attentions
            else:
                raise ValueError(f"Unexpected select feature: {pruning_paras['select_feature']}")

            pruning_paras['image_attentions'] = image_attentions.to(inps[0].dtype)

        def get_index_masks_hook(module, args, pruning_paras):
            image_features = args[0]
            image_attentions = pruning_paras['image_attentions']

            B, N, C = image_features.shape
            device = image_features.device
            index_masks = torch.ones(B, N, dtype=torch.bool, device=device)

            visual_token_num = round(
                self.special_config['vision_token_length'] * (1 - self.special_config['prune_ratio'])
            )  # T
            important_ratio = self.pruning_paras['important_ratio']  # r
            important_token_num = int(visual_token_num * important_ratio)  # T_imp = T * r
            diverse_token_num = visual_token_num - important_token_num  # T_div = T * (1 - r)

            # [VisPruner] Select important tokens using attention scores
            image_attentions = image_attentions.mean(dim=1)  # (B, N)
            token_indices = image_attentions.argsort(dim=-1, descending=True)  # (B, N)
            important_indices = token_indices[:, :important_token_num]  # (B, T_imp)
            residual_indices = token_indices[:, important_token_num:]  # (B, N - T_imp)

            # [VisPruner] Remove duplicate tokens by iterative matching and pruning
            image_normalized = image_features / image_features.norm(dim=-1, keepdim=True)  # (B, N, C)
            while diverse_token_num > 0:
                R = residual_indices.shape[1]
                r = min(8, R - diverse_token_num)
                if r <= 0:
                    break

                residual_tokens = image_normalized[torch.arange(B).unsqueeze(-1).expand(-1, R), residual_indices]  # (B, R, C)
                a, b = residual_tokens[..., ::2, :], residual_tokens[..., 1::2, :]  # (B, R // 2, C)
                scores = a @ b.transpose(-1, -2)  # (B, R // 2, R // 2)
                scores = scores.max(dim=-1).values  # (B, R // 2)

                distinct_indices = scores.argsort(dim=-1, descending=True)[:, r:]  # (B, R // 2 - r)
                residual_indices = torch.cat([
                    residual_indices[..., ::2][torch.arange(B).unsqueeze(-1).expand(-1, R // 2 - r), distinct_indices],
                    residual_indices[..., 1::2]
                ], dim=-1)  # (B, R - r)

            if diverse_token_num > 0:
                selected_indices = torch.cat([important_indices, residual_indices], dim=-1)  # (B, T)
            else:
                selected_indices = important_indices  # (B, T)
            index_masks = torch.zeros(B, N, dtype=torch.bool, device=device)
            index_masks.scatter_(1, selected_indices, True)

            pruning_paras['index_masks'] = index_masks

        def prune_hook(module, inputs, outputs, pruning_paras):
            image_features = outputs
            index_masks = pruning_paras['index_masks']
            return image_features[index_masks].unsqueeze(0)

        self.model.vision_model.vision_tower.register_forward_pre_hook(
            update_output_attentions_hook,
            with_kwargs=True
        )

        self.model.vision_model.vision_tower.register_forward_hook(
            functools.partial(store_attention_hook, pruning_paras=self.pruning_paras),
        )

        self.model.vision_projector.register_forward_pre_hook(
            functools.partial(get_index_masks_hook, pruning_paras=self.pruning_paras),
        )

        self.model.vision_projector.register_forward_hook(
            functools.partial(prune_hook, pruning_paras=self.pruning_paras),
        )
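The while-loop in get_index_masks_hook is a ToMe-style bipartite matching pass: the residual (non-important) tokens are split into even/odd halves, each even-half token is scored by its best cosine match in the odd half, and the r most redundant ones are dropped per iteration until roughly diverse_token_num remain. A self-contained toy run of the same scheme on random features (illustration only, not llmc code; all names here are made up):

import torch

torch.manual_seed(0)
B, N, C = 1, 32, 8
feats = torch.randn(B, N, C)
feats = feats / feats.norm(dim=-1, keepdim=True)       # cosine space

keep = 12                                               # tokens we want to retain
indices = torch.arange(N).unsqueeze(0)                  # (B, N) candidate indices

while indices.shape[1] > keep:
    R = indices.shape[1]
    r = min(8, R - keep)                                # drop at most 8 per pass
    tok = feats[torch.arange(B).unsqueeze(-1).expand(-1, R), indices]
    a, b = tok[..., ::2, :], tok[..., 1::2, :]          # bipartite split
    scores = (a @ b.transpose(-1, -2)).max(dim=-1).values        # best match per even token
    distinct = scores.argsort(dim=-1, descending=True)[:, r:]    # drop the r most redundant
    indices = torch.cat([
        indices[..., ::2][torch.arange(B).unsqueeze(-1).expand(-1, R // 2 - r), distinct],
        indices[..., 1::2],
    ], dim=-1)

print(indices.shape)   # torch.Size([1, 12])

The important tokens (top attention scores) are then concatenated with whatever survives this pruning, giving T = T_imp + T_div kept visual tokens.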
llmc/compression/token_reduction/visualizer.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
import functools

from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

from .token_reduction_module import TokenReductionModule
from .utils import prefill_wrapper
from llmc.utils.visualizer import *

@TOKEN_REDUCTION_REGISTRY.register('Visualizer')
class Visualizer(TokenReductionModule):
    def __init__(self, config, model, blocks):
        super().__init__(config, model, blocks)
        self.add_sparse_config()
        self.register_reduction_modules()

    def add_sparse_config(self):
        self.pruning_paras = self.special_config
        self.pruning_paras['attentions'] = []

    def register_reduction_modules(self):

        @prefill_wrapper
        def update_attentions_hook(module, args, kwargs):
            kwargs['output_attentions'] = True
            return args, kwargs

        @prefill_wrapper
        def get_images_hook(module, input_args, pruning_paras):
            pruning_paras['images'] = input_args[0]
            return input_args

        @prefill_wrapper
        def get_attentions_hook(module, inps, layer_outs, pruning_paras):
            pruning_paras['attentions'].append(layer_outs[1])
            return layer_outs

        @prefill_wrapper
        def visualizer_hook(module, inps, layer_outs, pruning_paras):
            attention_maps = pruning_paras['attentions'][0]
            visual_attention_maps = attention_maps[:, :, 35:35 + 576, 35:35 + 576]
            image = pruning_paras['images'][0]
            import pdb; pdb.set_trace()  # debugging breakpoint left in place
            print(__file__)
            visualize_heads(
                visual_attention_maps[:, :6],
                cols=4,
                save_path=''
            )
            visualize_grid_to_grid(
                visual_attention_maps[0, 4, :, :],
                300,
                image,
                grid_size=24,
                save_path='')

            visualize_kept_patches(
                pruning_paras['images'][0],
                pruning_paras['visual_keep_indexs'],
                patch_size=14,
                save_path='',
            )
            return layer_outs

        self.model.vision_model.register_forward_pre_hook(
            functools.partial(get_images_hook, pruning_paras=self.pruning_paras),
        )

        for idx, blk in enumerate(self.blocks):
            if idx == 5:
                blk.register_forward_pre_hook(update_attentions_hook, with_kwargs=True)
                blk.register_forward_hook(
                    functools.partial(get_attentions_hook, pruning_paras=self.pruning_paras),
                )
            if idx == (len(self.blocks) - 1):
                blk.register_forward_hook(
                    functools.partial(visualizer_hook, pruning_paras=self.pruning_paras),
                )

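The hard-coded slice 35:35 + 576 matches the LLaVA settings used elsewhere in this commit: llmc/models/llava.py below sets 'vision_token_start_index': 35, and a 336 px image with 14 px patches (the assumed LLaVA-1.5 defaults) yields a 24 x 24 grid of 576 visual tokens, hence grid_size=24 and patch_size=14 in the calls above. A quick check of that arithmetic:

image_size, patch_size = 336, 14        # assumed LLaVA-1.5 CLIP ViT-L/336 defaults
grid = image_size // patch_size         # 24
vision_tokens = grid * grid             # 576
start = 35                              # 'vision_token_start_index' from llava.py below
print(grid, vision_tokens, (start, start + vision_tokens))   # 24 576 (35, 611)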
llmc/eval/eval_vqa.py

Lines changed: 4 additions & 0 deletions
@@ -89,6 +89,10 @@ def eval(
         datetime_str: str = get_datetime_str(),
         cli_args=None,
     ):
+        import argparse
+        cli_args = argparse.Namespace(
+            process_with_media=True,
+        )

         model = llmc_model.eval_name
         model_args = 'pretrained=' + self.model_path + ',device_map=auto'

llmc/models/llava.py

Lines changed: 3 additions & 8 deletions
@@ -39,15 +39,10 @@ def build_tokenizer(self):
         pass

     def build_model(self):
-        self.llava_config = LlavaConfig.from_pretrained(
-            self.model_path, trust_remote_code=True
-        )
+
         self.vlm_model_config = AutoConfig.from_pretrained(
             self.model_path, trust_remote_code=True
         )
-        # llava need: use_cache
-        self.llava_config.use_cache = True
-        self.vlm_model_config.use_cache = True
         logger.info(f'self.vlm_model_config : {self.vlm_model_config}')

         self.tokenizer, self.vlm_model, self.image_processor, context_len = load_pretrained_model(
@@ -70,8 +65,8 @@ def build_model(self):
         self.pruning_config = {
             'is_video_model': False,
             'image_token_length': self.vlm_model_config.image_seq_length,
-            'select_layer': self.vlm_model_config.vision_feature_layer,
-            'select_feature': self.vlm_model_config.vision_feature_select_strategy,
+            'select_layer': self.vision_model.select_layer,
+            'select_feature': self.vision_model.select_feature,
             'image_token_index': IMAGE_TOKEN_INDEX,
             'IMAGE_TOKEN_INDEX': IMAGE_TOKEN_INDEX,  # for llava
             'vision_token_start_index': 35,

llmc/models/qwen2_5vl.py

Lines changed: 1 addition & 2 deletions
@@ -62,8 +62,6 @@ def build_model(self):

         self.min_pixels = 256 * 28 * 28
         self.max_pixels = 1280 * 28 * 28
-        logger.warning(f'min_pixels is set to: {self.min_pixels}')
-        logger.warning(f'max_pixels is set to: {self.max_pixels}')
         self.processor = AutoProcessor.from_pretrained(
             self.model_path,
             min_pixels=self.min_pixels,
@@ -76,6 +74,7 @@ def build_model(self):
             'vision_start_token_id': self.vlm_model_config.vision_start_token_id,
             'vision_token_start_index': 15
         }
+        self.first_turn_question = True

     # todo: check
     def get_subsets_in_block(self, block):
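For context (not part of the diff): in Qwen2.5-VL each visual token corresponds to a 28 x 28 pixel region after the 2 x 2 patch merger, so the min_pixels/max_pixels values above effectively bound the per-image visual token count:

min_pixels = 256 * 28 * 28     # 200_704
max_pixels = 1280 * 28 * 28    # 1_003_520

pixels_per_token = 28 * 28     # one merged visual token covers a 28 x 28 pixel area
print(min_pixels // pixels_per_token, max_pixels // pixels_per_token)   # 256 1280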
