Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions configs/sparsification/methods/PruneVid/prunevid.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
base:
    seed: &seed 42
model:
    type: Llava OneVision
    path: model path
    torch_dtype: auto
eval:
    eval_pos: [pretrain, transformed]
    type: vqa
    name: [mme]
    download: False
    path: MME dataset path
    bs: 1
    inference_per_block: False
sparse:
    method: TokenReduction
    special:
        method: PruneVid
        lora_alpha: 14
        selected_layers: 10
        alphas: 0.4
        taus: 0.8
        temporal_segment_ratios: 0.25
        cluster_ratios: 0.5
save:
    save_trans: False
    save_fake: False
    save_path: /path/to/save/
1 change: 1 addition & 0 deletions llmc/compression/token_reduction/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .dycoke import DyCoke
from .fastervlm import FasterVLM
from .fastv import FastV
from .prunevid import PruneVid
from .pyramiddrop import PyramidDrop
from .sparsevlm import SparseVLM
from .tome import ToMe
Expand Down
14 changes: 2 additions & 12 deletions llmc/compression/token_reduction/dycoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

from .token_reduction_module import TokenReductionModule
from .utils import prefill_wrapper
from .utils import add_post_hook_to_get_2dPool


def dycole_ttm(image_feature, pruning_paras):
Expand Down Expand Up @@ -102,16 +102,6 @@ def dycole_ttm(image_feature, pruning_paras):
return combined_tokens


def add_dycole_ttm_to_get_2dPool(model, post_hook_fn, pruning_paras):
    """Patch ``model.get_2dPool`` so its result is passed through a post-hook.

    The callable currently stored at ``model.get_2dPool`` is captured and the
    attribute is rebound on ``model`` to a wrapper. The wrapper forwards all
    positional and keyword arguments to the captured callable and returns
    ``post_hook_fn(result, pruning_paras)`` instead of the raw result.

    Args:
        model: Object exposing a callable ``get_2dPool`` attribute. The
            attribute is replaced in place on this object.
        post_hook_fn: Callable invoked as
            ``post_hook_fn(result, pruning_paras)``; its return value becomes
            the return value of the patched ``get_2dPool``.
        pruning_paras: Passed unchanged as the second argument to
            ``post_hook_fn`` on every call.

    Returns:
        None. The function works by mutating ``model``.
    """
    # Function-scope import keeps this block self-contained without touching
    # the file-level import section.
    import functools

    original_fn = model.get_2dPool

    # functools.wraps keeps the original name/docstring on the patched
    # attribute and exposes the unpatched callable as ``__wrapped__``.
    @functools.wraps(original_fn)
    def wrapped_fn(*args, **kwargs):
        result = original_fn(*args, **kwargs)
        return post_hook_fn(result, pruning_paras)

    model.get_2dPool = wrapped_fn


@TOKEN_REDUCTION_REGISTRY.register('DyCoke')
class DyCoke(TokenReductionModule):
def __init__(self, config, model, blocks):
Expand All @@ -127,6 +117,6 @@ def add_sparse_config(self):
def register_reduction_modules(self):

if isinstance(self.model.model, LlavaMetaForCausalLM):
add_dycole_ttm_to_get_2dPool(
add_post_hook_to_get_2dPool(
self.model.model, dycole_ttm, self.model.model.pruning_paras
)
Loading