Skip to content

Commit 50da6b6

Browse files
authored
divprune (#409)
1 parent cd5cfd9 commit 50da6b6

File tree

2 files changed

+150
-0
lines changed

2 files changed

+150
-0
lines changed

llmc/compression/token_reduction/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .base_blockwise_token_reduction import TokenReduction
22
from .dart import DART
3+
from .divprune import DivPrune
34
from .dycoke import DyCoke
45
from .fastervlm import FasterVLM
56
from .fastv import FastV
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
import functools
2+
from functools import wraps
3+
from types import MethodType
4+
5+
import torch
6+
7+
from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
8+
9+
from .token_reduction_module import TokenReductionModule
10+
from .utils import prefill_wrapper
11+
12+
13+
def pairwise_cosine_similarity(matrix):
    """Return the cosine similarity between every pair of rows of ``matrix``.

    Args:
        matrix: 2-D floating-point tensor; each row is treated as one vector.

    Returns:
        Square tensor whose ``[i, j]`` entry is the cosine similarity between
        rows ``i`` and ``j``. An all-zero row gets similarity 0 with every
        row (itself included) rather than NaN.
    """
    # Clamp the row norms away from zero: dividing an all-zero row by a zero
    # norm would fill that row and column of the result with NaN.
    smallest = torch.finfo(matrix.dtype).tiny
    row_norms = matrix.norm(dim=1, keepdim=True).clamp_min(smallest)
    norm_matrix = matrix / row_norms
    return torch.mm(norm_matrix, norm_matrix.t())
17+
18+
19+
def divprune(
    visual_feature_vectors,
    image_feature_length,
    cosine_matrix=None,
    threshold_ratio=0.1,
):
    """Greedily select a subset of tokens that are far apart from each other.

    The first pick is the token whose second-smallest distance to the tokens
    (the smallest is its zero distance to itself) is largest. Every later pick
    is the token whose minimum distance to the already selected tokens is
    largest.

    Args:
        visual_feature_vectors: Token features, one row per token. Only used
            to build the distance matrix when ``cosine_matrix`` is ``None``
            and to choose the device of the returned index tensor.
        image_feature_length: Token count that ``threshold_ratio`` is applied
            to.
        cosine_matrix: Optional precomputed pairwise distance matrix. When
            omitted it is computed as ``1 - pairwise_cosine_similarity(...)``.
        threshold_ratio: Fraction of ``image_feature_length`` to select; the
            number of picks is ``round(threshold_ratio * image_feature_length)``.

    Returns:
        Tuple ``(s, cosine_matrix)``: ``s`` is a 1-D ``torch.long`` tensor of
        the selected row indices in pick order, ``cosine_matrix`` is the
        distance matrix that was used (so callers can reuse it).
    """
    threshold_terms = int(round(threshold_ratio * image_feature_length))
    if cosine_matrix is None:
        cosine_matrix = 1.0 - pairwise_cosine_similarity(visual_feature_vectors)

    s = torch.empty(
        threshold_terms, dtype=torch.long, device=visual_feature_vectors.device
    )
    if threshold_terms == 0:
        return s, cosine_matrix

    # Seed scores: per column, the second-smallest distance. With a single
    # token there is no second entry, so fall back to the only one available
    # instead of letting topk fail on k=2.
    k = min(2, cosine_matrix.shape[0])
    scores = torch.topk(cosine_matrix, k, dim=0, largest=False).values[k - 1, :]

    for i in range(threshold_terms):
        picked = torch.argmax(scores)
        s[i] = picked
        # Keep a running "distance to the nearest selected token" instead of
        # re-gathering every selected row on each iteration; the element-wise
        # minimum over the rows picked so far is the same quantity.
        picked_row = torch.index_select(
            cosine_matrix, 0, picked.reshape(1)
        ).squeeze(0)
        scores = picked_row if i == 0 else torch.minimum(scores, picked_row)
    return s, cosine_matrix
54+
55+
56+
def divprune_post_hook(
    input_ids,
    position_ids,
    attention_mask,
    past_key_values,
    inputs_embeds,
    labels,
    pruning_paras=None,
):
    """Drop all but a diverse subset of the image tokens from prepared inputs.

    The image tokens occupy positions
    ``[image_token_start_index, image_token_start_index + image_token_length)``
    of the sequence axis (axis 1) of ``inputs_embeds``. ``divprune`` chooses
    which of them to keep; every position before and after the image span is
    always kept. The choice is computed from batch element 0 only and the
    same positions are kept for every batch row.

    Args:
        input_ids: Returned unchanged.
        position_ids: Optional tensor indexed ``[batch, seq, ...]``; filtered
            along axis 1 when not ``None``.
        attention_mask: Optional tensor indexed ``[batch, seq]``; filtered
            along axis 1 when not ``None``.
        past_key_values: Returned unchanged.
        inputs_embeds: Embeddings indexed ``[batch, seq, ...]``.
        labels: Returned unchanged (not filtered).
        pruning_paras: Mapping with keys ``'rate'``,
            ``'image_token_start_index'`` and ``'image_token_length'``.

    Returns:
        The six positional inputs in their original order, with
        ``position_ids``, ``attention_mask`` and ``inputs_embeds`` reduced to
        the kept sequence positions.
    """
    rate = pruning_paras['rate']
    image_start = pruning_paras['image_token_start_index']
    img_feature_len = pruning_paras['image_token_length']
    image_end = image_start + img_feature_len
    device = inputs_embeds.device

    visual_tokens = inputs_embeds[0][image_start:image_end]
    selected_visual_tokens, _ = divprune(
        visual_tokens, img_feature_len, None, threshold_ratio=rate
    )

    # divprune returns indices relative to the image span; shift them to
    # absolute sequence positions, then restore the original token order.
    keep_indexs = torch.cat(
        (
            torch.arange(image_start, device=device),
            selected_visual_tokens + image_start,
            torch.arange(image_end, inputs_embeds.shape[1], device=device),
        )
    )
    keep_indexs = keep_indexs.sort().values

    inputs_embeds = inputs_embeds[:, keep_indexs]
    if position_ids is not None:
        # No trailing ``:`` here: it is redundant for 3-D position ids and
        # raises IndexError for 2-D ``(batch, seq)`` ones.
        position_ids = position_ids[:, keep_indexs]
    if attention_mask is not None:
        attention_mask = attention_mask[:, keep_indexs]

    return (
        input_ids,
        position_ids,
        attention_mask,
        past_key_values,
        inputs_embeds,
        labels,
    )
100+
101+
102+
@TOKEN_REDUCTION_REGISTRY.register('DivPrune')
class DivPrune(TokenReductionModule):
    """Token-reduction module registered under the name ``'DivPrune'``.

    For a model whose class is named ``'Llava'`` it wraps
    ``vlm_model.prepare_inputs_labels_for_multimodal`` so that the prepared
    inputs are passed through ``divprune_post_hook`` before being returned.
    """

    def __init__(self, config, model, blocks):
        super().__init__(config, model, blocks)
        self.add_sparse_config()
        self.register_reduction_modules()

    def add_sparse_config(self):
        """Copy the model's image token length into the pruning parameters."""
        self.special_config['image_token_length'] = self.model.pruning_config[
            'image_token_length'
        ]

        # pruning_paras is the same dict object as special_config, so values
        # the hook writes later (image_token_start_index) land there too.
        self.pruning_paras = self.special_config

    def register_reduction_modules(self):
        """Install the pruning wrapper on the underlying vision-language model."""

        def input_hook_llava(fn, pruning_paras):
            @wraps(fn)
            def wrapper(self, *args, **kwargs):
                # ``fn`` is already a bound method, so the ``self`` that
                # MethodType supplies is intentionally not forwarded.
                if not args:
                    return fn(*args, **kwargs)
                input_ids = args[0]
                # A single-token sequence is passed through without pruning.
                if hasattr(input_ids[0], 'shape') and input_ids[0].shape[0] == 1:
                    return fn(*args, **kwargs)

                if len(args) > 2:
                    attention_mask = args[2]
                else:
                    attention_mask = kwargs.get('attention_mask')

                prompt_ids = input_ids[0]
                if attention_mask is not None:
                    # Force a boolean mask: an integer 0/1 mask would be
                    # treated as gather indices and select the wrong tokens.
                    prompt_ids = prompt_ids[attention_mask[0].bool()]
                image_positions = torch.where(prompt_ids == IMAGE_TOKEN_INDEX)[0]

                outputs = fn(*args, **kwargs)
                if image_positions.numel() == 0:
                    # No image placeholder in the prompt: nothing to prune.
                    return outputs

                # ``.item()`` requires exactly one placeholder; a prompt with
                # several image placeholders still raises here.
                pruning_paras['image_token_start_index'] = image_positions.item()

                return divprune_post_hook(*outputs, pruning_paras=pruning_paras)

            return wrapper

        if self.model.__class__.__name__ == 'Llava':
            # Imported lazily so llava is only required for Llava models; the
            # wrapper above reads this name through its closure.
            from llava.constants import IMAGE_TOKEN_INDEX

            hook_fn = input_hook_llava(
                self.model.vlm_model.prepare_inputs_labels_for_multimodal,
                self.pruning_paras,
            )
            self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
                hook_fn, self.model.vlm_model
            )

0 commit comments

Comments
 (0)