Skip to content

Commit bbe1658

Browse files
authored
prunevid (#389)
1 parent 62d47ef commit bbe1658

File tree

5 files changed

+452
-12
lines changed

5 files changed

+452
-12
lines changed
Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,28 @@
1+
base:
2+
seed: &seed 42
3+
model:
4+
type: Llava OneVision
5+
path: model path
6+
torch_dtype: auto
7+
eval:
8+
eval_pos: [pretrain, transformed]
9+
type: vqa
10+
name: [mme]
11+
download: False
12+
path: MME dataset path
13+
bs: 1
14+
inference_per_block: False
15+
sparse:
16+
method: TokenReduction
17+
special:
18+
method: PruneVid
19+
lora_alpha: 14
20+
selected_layers: 10
21+
alphas: 0.4
22+
taus: 0.8
23+
temporal_segment_ratios: 0.25
24+
cluster_ratios: 0.5
25+
save:
26+
save_trans: False
27+
save_fake: False
28+
save_path: /path/to/save/

llmc/compression/token_reduction/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
from .dycoke import DyCoke
33
from .fastervlm import FasterVLM
44
from .fastv import FastV
5+
from .prunevid import PruneVid
56
from .pyramiddrop import PyramidDrop
67
from .sparsevlm import SparseVLM
78
from .tome import ToMe

llmc/compression/token_reduction/dycoke.py

Lines changed: 2 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
1515

1616
from .token_reduction_module import TokenReductionModule
17-
from .utils import prefill_wrapper
17+
from .utils import add_post_hook_to_get_2dPool
1818

1919

2020
def dycole_ttm(image_feature, pruning_paras):
@@ -102,16 +102,6 @@ def dycole_ttm(image_feature, pruning_paras):
102102
return combined_tokens
103103

104104

105-
def add_dycole_ttm_to_get_2dPool(model, post_hook_fn, pruning_paras):
106-
original_fn = model.get_2dPool
107-
108-
def wrapped_fn(*args, **kwargs):
109-
result = original_fn(*args, **kwargs)
110-
return post_hook_fn(result, pruning_paras)
111-
112-
model.get_2dPool = wrapped_fn
113-
114-
115105
@TOKEN_REDUCTION_REGISTRY.register('DyCoke')
116106
class DyCoke(TokenReductionModule):
117107
def __init__(self, config, model, blocks):
@@ -127,6 +117,6 @@ def add_sparse_config(self):
127117
def register_reduction_modules(self):
128118

129119
if isinstance(self.model.model, LlavaMetaForCausalLM):
130-
add_dycole_ttm_to_get_2dPool(
120+
add_post_hook_to_get_2dPool(
131121
self.model.model, dycole_ttm, self.model.model.pruning_paras
132122
)

0 commit comments

Comments
 (0)