Commit 21dcfdc

dycoke
1 parent ec771c0 commit 21dcfdc

4 files changed (+167, −0)
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
base:
    seed: &seed 42
model:
    type: Llava OneVision
    path: model path
    torch_dtype: auto
eval:
    eval_pos: [pretrain, transformed]
    type: vqa
    name: [mme]
    download: False
    path: MME dataset path
    bs: 1
    inference_per_block: False
sparse:
    method: TokenReduction
    special:
        method: DyCoke
        dycoke_layer_idx: 3
        num_tokens_per_frame: 196
        merging_ratio: 0.7
        dycoke_radio: 0.7
save:
    save_trans: False
    save_fake: False
    save_path: /path/to/save/
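
For reference, merging_ratio is the fraction of tokens merged away in each compressed frame: dycoke_ttm below keeps int((1 - merging_ratio) * num_tokens_per_frame) tokens per merged frame (the per-frame token count is taken from the feature shape at runtime; the 196 here presumably matches LLaVA-OneVision's per-frame visual token count). A quick sanity check with the values above:

# Values from the config above.
merging_ratio = 0.7
num_tokens_per_frame = 196

# Each merged frame keeps the 30% least-similar tokens.
num_tokens_to_keep = int((1 - merging_ratio) * num_tokens_per_frame)
print(num_tokens_to_keep)  # 58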

llmc/compression/token_reduction/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 from .base_blockwise_token_reduction import TokenReduction
+from .dycoke import DyCoke
 from .fastervlm import FasterVLM
 from .fastv import FastV
 from .pyramiddrop import PyramidDrop
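
Re-exporting the class here makes it importable alongside the other token-reduction methods, and importing the package also runs the @TOKEN_REDUCTION_REGISTRY.register('DyCoke') decorator in dycoke.py:

from llmc.compression.token_reduction import DyCoke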
llmc/compression/token_reduction/dycoke.py

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
import functools
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from loguru import logger

try:
    from llava.model.llava_arch import LlavaMetaForCausalLM
except ModuleNotFoundError:
    logger.info('LlavaMetaForCausalLM not found; if needed, please install llava first.')
from transformers.cache_utils import Cache, DynamicCache

from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

from .token_reduction_module import TokenReductionModule
from .utils import prefill_wrapper

def dycoke_ttm(image_feature, pruning_paras):
    # DyCoke temporal token merging (TTM): prune redundant visual tokens
    # across adjacent video frames based on token-wise cosine similarity.
    bs, num_tokens_per_frame, _ = image_feature.shape
    image_feature = image_feature.flatten(0, 1)
    # Split frames into tokens
    num_frames = image_feature.shape[0] // num_tokens_per_frame
    merging_ratio = 1 - pruning_paras['merging_ratio']  # fraction of tokens to keep

    # Calculate similarities between adjacent even frames
    similarities = []
    for i in range(0, num_frames - 1, 2):
        # Get tokens for adjacent frames
        frame1_tokens = image_feature[
            i * num_tokens_per_frame: (i + 1) * num_tokens_per_frame
        ]
        frame2_tokens = image_feature[
            (i + 1) * num_tokens_per_frame: (i + 2) * num_tokens_per_frame
        ]

        # Calculate cosine similarity between normalized tokens
        frame1_norm = F.normalize(frame1_tokens, p=2, dim=1)
        frame2_norm = F.normalize(frame2_tokens, p=2, dim=1)
        similarity = F.cosine_similarity(frame1_norm, frame2_norm, dim=1)
        similarities.append(similarity)

    similarities = torch.stack(similarities)

    # Process even frames
    modified_image_feature = []
    for i in range(0, num_frames - 1, 2):
        frame1_tokens = image_feature[
            i * num_tokens_per_frame: (i + 1) * num_tokens_per_frame
        ]
        frame2_tokens = image_feature[
            (i + 1) * num_tokens_per_frame: (i + 2) * num_tokens_per_frame
        ]

        avg_similarity = similarities[i // 2]
        num_tokens_to_keep = int(merging_ratio * num_tokens_per_frame)
        # Keep the least similar tokens of the second frame in each pair
        tokens_to_keep = avg_similarity.topk(num_tokens_to_keep, largest=False).indices

        modified_image_feature.append(frame1_tokens)
        modified_image_feature.append(frame2_tokens[tokens_to_keep])

    # Process odd frames
    odd_similarities = []
    for i in range(0, num_frames - 4, 4):
        frame1_tokens = image_feature[
            i * num_tokens_per_frame: (i + 1) * num_tokens_per_frame
        ]
        frame2_tokens = image_feature[
            (i + 2) * num_tokens_per_frame: (i + 3) * num_tokens_per_frame
        ]

        similarity = F.cosine_similarity(frame1_tokens, frame2_tokens, dim=1)
        odd_similarities.append(similarity)

    odd_similarities = torch.stack(odd_similarities)

    for i in range(0, num_frames - 4, 4):
        frame1_tokens = image_feature[
            i * num_tokens_per_frame: (i + 1) * num_tokens_per_frame
        ]
        frame2_tokens = image_feature[
            (i + 2) * num_tokens_per_frame: (i + 3) * num_tokens_per_frame
        ]

        avg_similarity = odd_similarities[i // 4]
        num_tokens_to_keep = int(merging_ratio * num_tokens_per_frame)
        tokens_to_keep = avg_similarity.topk(num_tokens_to_keep, largest=False).indices

        modified_image_feature[i] = frame1_tokens
        modified_image_feature[i + 2] = frame2_tokens[tokens_to_keep]

    # Combine all tokens
    combined_tokens = torch.cat(modified_image_feature, dim=0).unsqueeze(0)
    return combined_tokens

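For intuition, here is a hedged, self-contained sketch of calling dycoke_ttm on dummy features (the shapes are hypothetical; in the real pipeline the input is the per-frame output of get_2dPool):

import torch

# Hypothetical shapes: 8 frames, 196 tokens per frame, hidden size 1024.
image_feature = torch.randn(8, 196, 1024)
pruning_paras = {'merging_ratio': 0.7}

out = dycoke_ttm(image_feature, pruning_paras)
# The first frame of each pair stays intact (196 tokens); merged frames
# keep int(0.3 * 196) = 58 tokens. For 8 frames this yields
# 3 * 196 + 5 * 58 = 878 tokens in total.
print(out.shape)  # torch.Size([1, 878, 1024])
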
def add_dycoke_ttm_to_get_2dPool(model, post_hook_fn, pruning_paras):
    # Monkey-patch the model's get_2dPool so its output is passed through
    # post_hook_fn (here: dycoke_ttm) before being returned.
    original_fn = model.get_2dPool

    def wrapped_fn(*args, **kwargs):
        result = original_fn(*args, **kwargs)
        return post_hook_fn(result, pruning_paras)

    model.get_2dPool = wrapped_fn

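The patch itself is a plain post-hook: get_2dPool is swapped for a closure that calls the original and then runs the hook on its result. A toy illustration of the mechanics (the class and numbers here are purely hypothetical):

class _Toy:
    def get_2dPool(self, x):
        return x * 2

toy = _Toy()
# After patching, every call returns post_hook_fn(original_result, paras).
add_dycoke_ttm_to_get_2dPool(toy, lambda result, paras: result + paras['bias'], {'bias': 1})
print(toy.get_2dPool(3))  # (3 * 2) + 1 = 7
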
@TOKEN_REDUCTION_REGISTRY.register('DyCoke')
class DyCoke(TokenReductionModule):
    def __init__(self, config, model, blocks):
        super().__init__(config, model, blocks)
        self.add_sparse_config()
        self.register_reduction_modules()

    def add_sparse_config(self):
        self.special_config['different_token_idxs'] = []
        self.dycoke_layer_idx = self.special_config['dycoke_layer_idx']
        self.model.model.pruning_paras = self.special_config

    def register_reduction_modules(self):
        if isinstance(self.model.model, LlavaMetaForCausalLM):
            add_dycoke_ttm_to_get_2dPool(
                self.model.model, dycoke_ttm, self.model.model.pruning_paras
            )
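
For context, the special.method: DyCoke entry in the config above is resolved through this registration: llmc looks the name up in TOKEN_REDUCTION_REGISTRY and instantiates the class. A rough sketch of that dispatch, assuming the registry supports dict-style lookup (the exact llmc internals may differ):

# Illustrative only; config, model and blocks come from the llmc pipeline.
cls = TOKEN_REDUCTION_REGISTRY['DyCoke']
token_reduction = cls(config, model, blocks)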

llmc/models/llava_onevision.py

Lines changed: 8 additions & 0 deletions
@@ -105,6 +105,14 @@ def build_model(self):
 
         self.processor = None
 
+    def find_blocks(self):
+        if self.get_modality() == 'language':
+            super().find_blocks()
+        elif self.get_modality() == 'vision':
+            self.blocks = self.vision_model.vision_tower.vision_model.encoder.layers
+        else:
+            raise Exception(f'Llava_OneVision does not support {self.get_modality()} modality.')
+
 
 @MODEL_REGISTRY
 class Llava_OneVision_Eval(LLaVA_OV):
