diff --git a/llmc/compression/token_reduction/mustdrop.py b/llmc/compression/token_reduction/mustdrop.py index 7f72f800..ab00fe53 100644 --- a/llmc/compression/token_reduction/mustdrop.py +++ b/llmc/compression/token_reduction/mustdrop.py @@ -15,6 +15,7 @@ def __init__(self, config, model, blocks): self.register_reduction_modules() def add_sparse_config(self): + self.pruning_loc = self.special_config['pruning_loc'] self.pruning_paras = self.special_config def register_reduction_modules(self): @@ -30,6 +31,7 @@ def conditional_pooling( feat: torch.Tensor, threshold: float, window_size: Tuple[int, int], + fix_r: int = 0, ) -> Tuple[Callable, Callable]: with torch.no_grad(): @@ -91,7 +93,8 @@ def conditional_pooling( node_mean = node_mean.repeat(1, n_H) r = torch.ge(similarity_map, node_mean).sum(dim=1).min() # -------------# - + if fix_r != 0: + r = fix_r # get top k similar super patches _, sim_super_patch_idxs = similarity_map.topk(r, dim=-1) @@ -184,17 +187,20 @@ def merge_wavg( return x, size - def spatial_merge_hook(module, args, kwargs, pruning_paras): + def spatial_merge_hook(module, args, kwargs, layer_outs, pruning_paras): spatial_threshold = pruning_paras['spatial_threshold'] window_size = pruning_paras['window_size'] - hidden_states = args[0] - merge = conditional_pooling(hidden_states, spatial_threshold, window_size) + hidden_states = layer_outs[0] + fix_r = 0 + if pruning_paras.get('retained_tokens', None) is not None: + retained_tokens = pruning_paras['retained_tokens'] + fix_r = (pruning_paras['vision_token_length'] - retained_tokens) \ + // (window_size[0] * window_size[1] - 1) + merge = conditional_pooling(hidden_states, spatial_threshold, window_size, fix_r) hidden_states, size = merge_wavg(merge, hidden_states, None) - return (hidden_states,) + args[1:], kwargs + return (hidden_states,) - self.model.set_modality('vision') - self.model.find_blocks() - self.model.blocks[1].register_forward_pre_hook( + self.blocks[self.pruning_loc - 1].register_forward_hook( functools.partial(spatial_merge_hook, pruning_paras=self.pruning_paras), with_kwargs=True, )