22 changes: 14 additions & 8 deletions llmc/compression/token_reduction/mustdrop.py
@@ -15,6 +15,7 @@ def __init__(self, config, model, blocks):
self.register_reduction_modules()

def add_sparse_config(self):
self.pruning_loc = self.special_config['pruning_loc']

Severity: medium

Directly accessing self.special_config['pruning_loc'] can raise a KeyError if the key is missing. Consider adding a check to ensure the key exists before accessing it to prevent potential crashes.
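A minimal sketch of one possible guard (the error message is illustrative, not part of the PR):

if 'pruning_loc' not in self.special_config:
    raise KeyError("special_config must define 'pruning_loc' for this token-reduction module")
self.pruning_loc = self.special_config['pruning_loc']

If a sensible default layer index exists, self.special_config.get('pruning_loc', default_loc) avoids the hard failure entirely; default_loc is hypothetical here.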

self.pruning_paras = self.special_config

def register_reduction_modules(self):
@@ -30,6 +31,7 @@ def conditional_pooling(
feat: torch.Tensor,
threshold: float,
window_size: Tuple[int, int],
fix_r: int = 0,
) -> Tuple[Callable, Callable]:

with torch.no_grad():
@@ -91,7 +93,8 @@ def conditional_pooling(
node_mean = node_mean.repeat(1, n_H)
r = torch.ge(similarity_map, node_mean).sum(dim=1).min()
# -------------#

if fix_r != 0:
r = fix_r
# get top k similar super patches
_, sim_super_patch_idxs = similarity_map.topk(r, dim=-1)

@@ -184,17 +187,20 @@ def merge_wavg(

return x, size

def spatial_merge_hook(module, args, kwargs, pruning_paras):
def spatial_merge_hook(module, args, kwargs, layer_outs, pruning_paras):
spatial_threshold = pruning_paras['spatial_threshold']
window_size = pruning_paras['window_size']
hidden_states = args[0]
merge = conditional_pooling(hidden_states, spatial_threshold, window_size)
hidden_states = layer_outs[0]
fix_r = 0
if pruning_paras.get('retained_tokens', None) is not None:
retained_tokens = pruning_paras['retained_tokens']
fix_r = (pruning_paras['vision_token_length'] - retained_tokens) \
// (window_size[0] * window_size[1] - 1)
Comment on lines +197 to +198

Severity: high

The denominator window_size[0] * window_size[1] - 1 in the fix_r calculation can be zero if window_size is (1, 1), leading to a ZeroDivisionError. Add a check to prevent this.

if (window_size[0] * window_size[1] - 1) > 0:
    fix_r = (pruning_paras['vision_token_length'] - retained_tokens) \
        // (window_size[0] * window_size[1] - 1)
else:
    fix_r = 0
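Falling back to fix_r = 0 in the degenerate window_size == (1, 1) case lets conditional_pooling keep its adaptive r (the new fix_r != 0 branch is simply skipped), which appears to be the intended behaviour when a window contains only a single token.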

merge = conditional_pooling(hidden_states, spatial_threshold, window_size, fix_r)
hidden_states, size = merge_wavg(merge, hidden_states, None)
return (hidden_states,) + args[1:], kwargs
return (hidden_states,)

Severity: critical

The spatial_merge_hook is registered as a forward_hook. Returning only (hidden_states,) replaces the block's entire output, potentially discarding other outputs like attention weights, which can cause errors in subsequent layers. Ensure the hook preserves other elements of the output tuple.

return (hidden_states,) + layer_outs[1:]
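One caveat: layer_outs[1:] assumes the block's output is a tuple. If the layer returns a bare tensor, the hook should simply return the merged hidden_states instead; the sketch after the registration code below handles both cases.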


self.model.set_modality('vision')
self.model.find_blocks()
self.model.blocks[1].register_forward_pre_hook(
self.blocks[self.pruning_loc - 1].register_forward_hook(
functools.partial(spatial_merge_hook, pruning_paras=self.pruning_paras),
with_kwargs=True,
)
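
Putting the two hook-level suggestions together (the division guard and the output-tuple handling), a rough sketch of how the hook could look after the fixes. conditional_pooling and merge_wavg are the closures defined earlier in this diff; the isinstance checks are an assumption, since whether the block returns a bare tensor or a tuple depends on the model:

def spatial_merge_hook(module, args, kwargs, layer_outs, pruning_paras):
    spatial_threshold = pruning_paras['spatial_threshold']
    window_size = pruning_paras['window_size']
    # forward hooks registered with with_kwargs=True receive (module, args, kwargs, output)
    hidden_states = layer_outs[0] if isinstance(layer_outs, tuple) else layer_outs
    fix_r = 0
    if pruning_paras.get('retained_tokens', None) is not None:
        denom = window_size[0] * window_size[1] - 1
        if denom > 0:  # window_size == (1, 1) would otherwise divide by zero
            fix_r = (pruning_paras['vision_token_length']
                     - pruning_paras['retained_tokens']) // denom
    merge = conditional_pooling(hidden_states, spatial_threshold, window_size, fix_r)
    hidden_states, size = merge_wavg(merge, hidden_states, None)
    # a forward hook's return value replaces the block's output, so keep the tail
    if isinstance(layer_outs, tuple):
        return (hidden_states,) + layer_outs[1:]
    return hidden_states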