Commit ee525d4

random (#415)
1 parent 43e54ac commit ee525d4

File tree

2 files changed: +139 -0 lines changed

llmc/compression/token_reduction/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
 from .mustdrop import MustDrop
 from .prunevid import PruneVid
 from .pyramiddrop import PyramidDrop
+from .random import RandomPrune
 from .sparsevlm import SparseVLM
 from .tome import ToMe
 from .visionzip import VisionZip
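This one-line import is what makes the new method visible: the `@TOKEN_REDUCTION_REGISTRY.register('RandomPrune')` decorator in `random.py` only runs when the module is imported, so the package `__init__` pulls it in alongside the other token-reduction methods. A minimal sketch of how such a name-to-class registry typically works (illustrative only; the actual `llmc.utils.registry_factory` implementation may differ):

    # Illustrative registry-factory sketch; llmc's real implementation may differ.
    class Registry:
        def __init__(self):
            self._classes = {}

        def register(self, name):
            def decorator(cls):
                # Runs at import time, mapping the config name to the class.
                self._classes[name] = cls
                return cls
            return decorator

        def build(self, name, *args, **kwargs):
            # Later, a config string like 'RandomPrune' selects the class.
            return self._classes[name](*args, **kwargs)

    TOKEN_REDUCTION_REGISTRY = Registry()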
llmc/compression/token_reduction/random.py

Lines changed: 138 additions & 0 deletions

@@ -0,0 +1,138 @@
import functools
from functools import wraps
from types import MethodType

import torch

from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

from .token_reduction_module import TokenReductionModule
from .utils import prefill_wrapper


@TOKEN_REDUCTION_REGISTRY.register('RandomPrune')
class RandomPrune(TokenReductionModule):
    def __init__(self, config, model, blocks):
        super().__init__(config, model, blocks)
        self.add_sparse_config()
        self.register_reduction_modules()

    def add_sparse_config(self):
        self.pruning_loc = self.special_config['pruning_loc']
        self.special_config['image_token_length'] = self.model.pruning_config[
            'image_token_length'
        ]
        self.pruning_paras = self.special_config

    def register_reduction_modules(self):

        def input_hook_llava(fn, pruning_paras):
            # Wraps prepare_inputs_labels_for_multimodal to record where the
            # image tokens start; skips the decode phase (single-token input).
            @wraps(fn)
            def wrapper(self, *args, **kwargs):
                if len(args) == 0:
                    return fn(*args, **kwargs)
                input_args = args[0]
                if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
                    return fn(*args, **kwargs)

                input_ids = args[0]
                attention_mask = args[2]
                # IMAGE_TOKEN_INDEX is bound in the enclosing scope (see the
                # 'Llava' branch below) before this wrapper is ever called.
                token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX
                pruning_paras['image_token_start_index'] = torch.where(token_indices)[
                    0
                ][0].item()

                outputs = fn(*args, **kwargs)
                return outputs

            return wrapper

        @prefill_wrapper
        def input_hook(module, input_args, pruning_paras):
            # Record the first vision-token position for the HF LLaVA path.
            input_ids = input_args[0]
            image_token_idxs = (
                input_ids[0] == pruning_paras['vision_token_index']
            ).nonzero(as_tuple=True)[0]
            pruning_paras['image_token_start_index'] = image_token_idxs[0].item()

            return input_args

        @prefill_wrapper
        def random_pruning_hook(module, args, kwargs, pruning_paras):
            rate = pruning_paras['rate']
            image_token_start_index = pruning_paras['image_token_start_index']
            image_token_length = pruning_paras['image_token_length']

            hidden_states = args[0]
            causal_mask = kwargs['attention_mask']

            # Uniformly sample the vision tokens to keep.
            device = hidden_states.device
            vision_indexes = torch.arange(
                image_token_start_index,
                image_token_start_index + image_token_length,
                device=device,
            )
            num_keep = round(image_token_length * (1 - rate))
            rand_idx = torch.randperm(image_token_length, device=device)[:num_keep]
            vision_indexes = vision_indexes[rand_idx]
            # Keep all text tokens plus the sampled vision tokens.
            keep_indexs = torch.cat(
                (
                    torch.arange(image_token_start_index, device=device),
                    vision_indexes,
                    torch.arange(
                        image_token_start_index + image_token_length,
                        hidden_states.shape[1],
                        device=device,
                    ),
                )
            )

            keep_indexs = keep_indexs.sort().values
            # filter hidden states
            hidden_states = hidden_states[:, keep_indexs, :]
            # update position ids
            position_ids = keep_indexs.unsqueeze(0)
            # update attention mask in place so downstream layers see the
            # pruned shapes
            if causal_mask is not None:
                causal_mask = causal_mask[
                    :, :, : hidden_states.shape[1], : hidden_states.shape[1]
                ]
                kwargs['attention_mask'].resize_as_(causal_mask).copy_(
                    causal_mask.clone()
                )
            kwargs['cache_position'].resize_as_(position_ids.squeeze(0)).copy_(
                position_ids.squeeze(0).clone()
            )
            kwargs['position_ids'].resize_as_(position_ids).copy_(position_ids.clone())

            # update rotary position embeddings (cos, sin) for the kept tokens
            position_embeddings = kwargs['position_embeddings']
            new_pe0 = position_embeddings[0][:, keep_indexs, :].clone()
            new_pe1 = position_embeddings[1][:, keep_indexs, :].clone()
            position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
            position_embeddings[1].resize_as_(new_pe1).copy_(new_pe1)

            return (hidden_states,), kwargs

        if self.model.__class__.__name__ == 'LlavaHf':
            self.model.embed_tokens.register_forward_pre_hook(
                functools.partial(input_hook, pruning_paras=self.pruning_paras)
            )
        elif self.model.__class__.__name__ == 'Llava':
            from llava.constants import IMAGE_TOKEN_INDEX

            hook_fn = input_hook_llava(
                self.model.vlm_model.prepare_inputs_labels_for_multimodal,
                self.pruning_paras,
            )
            self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
                hook_fn, self.model.vlm_model
            )

        self.blocks[self.pruning_loc].register_forward_pre_hook(
            functools.partial(random_pruning_hook, pruning_paras=self.pruning_paras),
            with_kwargs=True,
        )
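To see what the hook does to the sequence, here is a standalone sketch of the same index construction with made-up values for `image_token_start_index`, `image_token_length`, `rate`, and the sequence length (all hypothetical):

    import torch

    # Hypothetical values for illustration only.
    image_token_start_index = 5   # first vision-token position in the sequence
    image_token_length = 576      # e.g. a 24x24 patch grid
    rate = 0.75                   # fraction of vision tokens to drop
    seq_len = 600

    vision_indexes = torch.arange(
        image_token_start_index, image_token_start_index + image_token_length
    )
    num_keep = round(image_token_length * (1 - rate))          # 144 tokens survive
    rand_idx = torch.randperm(image_token_length)[:num_keep]   # uniform, no replacement
    kept_vision = vision_indexes[rand_idx]

    keep_indexs = torch.cat(
        (
            torch.arange(image_token_start_index),              # text before the image
            kept_vision,                                        # surviving vision tokens
            torch.arange(image_token_start_index + image_token_length, seq_len),
        )
    ).sort().values

    # 600 - (576 - 144) = 168 positions remain
    assert keep_indexs.numel() == seq_len - (image_token_length - num_keep)

Two design points worth noting: `register_forward_pre_hook(..., with_kwargs=True)` is what lets `random_pruning_hook` receive the block's keyword arguments and return a replacement `(args, kwargs)` pair, and the in-place `resize_as_().copy_()` calls keep tensors that callers may already hold references to (`attention_mask`, `cache_position`, `position_ids`, the rotary embeddings) consistent with the pruned length. Per the code, `special_config` supplies `pruning_loc`, `rate`, and, on the `LlavaHf` path, `vision_token_index`, while `image_token_length` comes from the model's `pruning_config`.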
