3 changes: 1 addition & 2 deletions configs/sparsification/methods/SparseVLM/sparsevlm.yml
@@ -17,8 +17,7 @@ sparse:
special:
method: SparseVLM
pruning_loc: [2, 6, 15]
retained_tokens: 192
prune_flag: True
reduction_ratio: 0.6667
merge_flag: True
save:
save_trans: False
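The file header above reports one addition and two deletions: `retained_tokens: 192` and `prune_flag: True` appear to be replaced by the single `reduction_ratio: 0.6667`. A quick sanity check that the new ratio encodes the same budget, assuming the backbone is LLaVA-1.5 with 576 vision tokens per image (the formula and token count are assumptions, not taken from the SparseVLM code):

```python
# Assumption: reduction_ratio is the fraction of vision tokens to prune,
# and the backbone provides 576 vision tokens per image (LLaVA-1.5).
num_vision_tokens = 576
reduction_ratio = 0.6667
retained = round(num_vision_tokens * (1 - reduction_ratio))
print(retained)  # 192 -- matches the removed retained_tokens value
```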
26 changes: 14 additions & 12 deletions llmc/compression/token_reduction/dart.py
@@ -1,5 +1,5 @@
import functools
import math
from types import MethodType

import torch

@@ -24,26 +24,20 @@ def add_sparse_config(self):
def register_reduction_modules(self):

@prefill_wrapper
def vtoken_length_hook(module, input_args, pruning_paras):

input_ids = input_args[0]
def vtoken_length_hook(module, args, pruning_paras):
input_ids = args[0]
token_indices = torch.where(
input_ids[0] == pruning_paras['vision_token_index']
)[0]
pruning_paras['vision_token_length'] = token_indices.shape[0]

return input_args

@prefill_wrapper
def get_any_states_hook(module, args, kwargs, layer_outs, pruning_paras, layer_idx):

past_key_value = kwargs['past_key_value']
if past_key_value is None:
raise ValueError('DART needs past_key_value but got None.')
pruning_paras['any_states'] = past_key_value.key_cache[layer_idx]

return layer_outs

@prefill_wrapper
def pruning_hook(module, args, kwargs, pruning_paras, normlayer):

@@ -95,9 +89,17 @@ def pruning_hook(module, args, kwargs, pruning_paras, normlayer):
return (hidden_states,), kwargs

if self.special_config['vision_token_length'] is None:
self.model.embed_tokens.register_forward_pre_hook(
functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
)
if self.model.__class__.__name__ == 'Llava':
self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
self.vtoken_length_for_llava_hook(
self.model.vlm_model.prepare_inputs_labels_for_multimodal,
self.pruning_paras
), self.model.vlm_model
)
Comment on lines +92 to +98


medium

The logic for handling LLaVA models by monkey-patching prepare_inputs_labels_for_multimodal is also present in llmc/compression/token_reduction/fastv.py. This code duplication can lead to maintenance issues if this logic needs to be updated in the future.

To improve maintainability, consider refactoring this logic into a shared method in the TokenReductionModule base class (a rough sketch follows this list). This would likely involve:

  1. Moving the vtoken_length_hook function (which is also duplicated) to token_reduction_module.py or a shared utility file.
  2. Creating a new method in TokenReductionModule, for example _register_vtoken_length_hook, that contains this if/else block.
  3. Calling this new method from the register_reduction_modules method of both DART and FastV subclasses.
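A minimal sketch of what such a shared helper could look like, assuming the hook signatures shown in this diff; the method name `_register_vtoken_length_hook` and its placement are suggestions, not existing llmc APIs:

```python
import functools
from types import MethodType


class TokenReductionModule:
    # ...existing methods...

    def _register_vtoken_length_hook(self, vtoken_length_hook):
        """Register the vision-token-length hook once, shared by DART and FastV."""
        if self.model.__class__.__name__ == 'Llava':
            # LLaVA assembles its multimodal embeddings itself, so wrap
            # prepare_inputs_labels_for_multimodal instead of hooking embed_tokens.
            self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
                self.vtoken_length_for_llava_hook(
                    self.model.vlm_model.prepare_inputs_labels_for_multimodal,
                    self.pruning_paras,
                ), self.model.vlm_model
            )
        else:
            self.model.embed_tokens.register_forward_pre_hook(
                functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
            )
```

Each subclass's register_reduction_modules would then reduce to a guarded call such as `if self.special_config['vision_token_length'] is None: self._register_vtoken_length_hook(vtoken_length_hook)`.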

else:
self.model.embed_tokens.register_forward_pre_hook(
functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
)

self.blocks[self.pruning_loc - 1].register_forward_hook(
functools.partial(
15 changes: 12 additions & 3 deletions llmc/compression/token_reduction/fastv.py
@@ -1,4 +1,5 @@
import functools
from types import MethodType

import torch

@@ -104,9 +105,17 @@ def fastv_pruning_hook(module, args, kwargs, pruning_paras):
return (hidden_states,), kwargs

if self.special_config['vision_token_length'] is None:
self.model.embed_tokens.register_forward_pre_hook(
functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
)
if self.model.__class__.__name__ == 'Llava':
self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
self.vtoken_length_for_llava_hook(
self.model.vlm_model.prepare_inputs_labels_for_multimodal,
self.pruning_paras
), self.model.vlm_model
)
else:
self.model.embed_tokens.register_forward_pre_hook(
functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
)

self.blocks[self.pruning_loc - 1].register_forward_pre_hook(
functools.partial(update_output_attentions_hook, pruning_paras=self.pruning_paras),
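For readers unfamiliar with the pattern used in both hunks, here is a self-contained toy example of rebinding a wrapped method onto an instance with types.MethodType; the names are illustrative only and unrelated to llmc:

```python
from types import MethodType


class Model:
    def prepare(self, x):
        return x * 2


def make_length_recorder(orig_fn, store):
    # orig_fn is the already-bound original method, so it is called without self.
    def wrapper(self, x):
        out = orig_fn(x)
        store['last_output'] = out  # side channel, analogous to pruning_paras
        return out
    return wrapper


m = Model()
store = {}
m.prepare = MethodType(make_length_recorder(m.prepare, store), m)
print(m.prepare(3), store)  # 6 {'last_output': 6}
```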