diff --git a/llmc/compression/token_reduction/__init__.py b/llmc/compression/token_reduction/__init__.py
index ec27af88..6501cf2a 100755
--- a/llmc/compression/token_reduction/__init__.py
+++ b/llmc/compression/token_reduction/__init__.py
@@ -6,6 +6,7 @@ from .fastv import FastV
 from .fastvid import FastVID
 from .holitom import HoliTom
+from .mustdrop import MustDrop
 from .prunevid import PruneVid
 from .pyramiddrop import PyramidDrop
 from .sparsevlm import SparseVLM
diff --git a/llmc/compression/token_reduction/mustdrop.py b/llmc/compression/token_reduction/mustdrop.py
new file mode 100644
index 00000000..7f72f800
--- /dev/null
+++ b/llmc/compression/token_reduction/mustdrop.py
@@ -0,0 +1,200 @@
+import functools
+
+import torch
+
+from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
+
+from .token_reduction_module import TokenReductionModule
+
+
+@TOKEN_REDUCTION_REGISTRY.register('MustDrop')
+class MustDrop(TokenReductionModule):
+    def __init__(self, config, model, blocks):
+        super().__init__(config, model, blocks)
+        self.add_sparse_config()
+        self.register_reduction_modules()
+
+    def add_sparse_config(self):
+        self.pruning_paras = self.special_config
+
+    def register_reduction_modules(self):
+
+        import math
+        from typing import Callable, Tuple
+
+        import torch.nn.functional as F
+        from einops import rearrange
+
+        def conditional_pooling(
+            feat: torch.Tensor,
+            threshold: float,
+            window_size: Tuple[int, int],
+        ) -> Callable:
+
+            with torch.no_grad():
+
+                ws_h, ws_w = int(window_size[0]), int(window_size[1])  # window size, e.g. 2x2
+                stride_h, stride_w = ws_h, ws_w
+                num_token_window = stride_h * stride_w  # tokens per window, e.g. 4
+
+                _, feat = (
+                    feat[:, :1, :],
+                    feat[:, 1:, :],
+                )  # take all tokens except the cls token: 576 vision tokens in total
+                B, N, D = feat.size()
+                base_grid_H = int(math.sqrt(N))
+                base_grid_W = base_grid_H
+                assert (
+                    base_grid_H * base_grid_W == N
+                    and base_grid_H % ws_h == 0
+                    and base_grid_W % ws_w == 0
+                )
+
+                feat = rearrange(feat, 'b (h w) c -> b c h w', h=base_grid_H)
+
+                feat = rearrange(
+                    feat,
+                    'b c (gh ps_h) (gw ps_w) -> b gh gw c ps_h ps_w',
+                    gh=base_grid_H // ws_h,
+                    gw=base_grid_W // ws_w,
+                )
+                b, gh, gw, c, ps_h, ps_w = feat.shape
+
+                # Flatten mxm window for pairwise operations
+                tensor_flattened = feat.reshape(b, gh, gw, c, -1)
+
+                # Expand dims for pairwise operations
+                tensor_1 = tensor_flattened.unsqueeze(-1)
+                tensor_2 = tensor_flattened.unsqueeze(-2)
+
+                # Compute cosine similarities
+                sims = F.cosine_similarity(tensor_1, tensor_2, dim=3)
+
+                # Exclude the self-similarity (i.e., similarity with oneself will be 1)
+                sims_mask = 1 - torch.eye(ps_h * ps_w).to(sims.device)
+                sims = sims * sims_mask
+
+                # Average similarities (excluding the self-similarity)
+                similarity_map = sims.sum(-1).sum(-1) / (
+                    (ps_h * ps_w) * (ps_h * ps_w - 1)
+                )
+
+                similarity_map = rearrange(
+                    similarity_map.unsqueeze(1), 'b c h w -> b (c h w)'
+                )
+
+                # --- adaptive section ---#
+
+                n_B, n_H = similarity_map.shape
+                node_mean = torch.tensor(threshold).to(sims.device)
+                node_mean = node_mean.repeat(1, n_H)
+                r = torch.ge(similarity_map, node_mean).sum(dim=1).min()
+                # -------------#
+
+                # Get the top-r most similar super patches
+                _, sim_super_patch_idxs = similarity_map.topk(r, dim=-1)
+
+                # --- creating the mergeable and unmergeable super patches
+                tensor = (
+                    torch.arange(base_grid_H * base_grid_W)
+                    .reshape(base_grid_H, base_grid_W)
+                    .to(feat.device)
+                )
+
+                # Repeat the tensor to create a batch of size B
+                tensor = tensor.unsqueeze(0).repeat(B, 1, 1)
+
+                # Apply unfold operation on last two dimensions to create the sliding window
+                windowed_tensor = tensor.unfold(1, ws_h, stride_h).unfold(
+                    2, ws_w, stride_w
+                )
+
+                # Reshape the tensor to the desired shape
+                windowed_tensor = windowed_tensor.reshape(B, -1, num_token_window)
+
+                # Use torch.gather to collect the desired elements
+                gathered_tensor = torch.gather(
+                    windowed_tensor,
+                    1,
+                    sim_super_patch_idxs.unsqueeze(-1).expand(-1, -1, num_token_window),
+                )
+
+                # Create a mask for all indices, for each batch
+                mask = torch.ones((B, windowed_tensor.shape[1]), dtype=bool).to(
+                    feat.device
+                )
+
+                # Create a tensor that matches the shape of indices and fill it with False
+                mask_values = torch.zeros_like(
+                    sim_super_patch_idxs, dtype=torch.bool
+                ).to(feat.device)
+
+                # Use scatter_ to update the mask.
+                # This will set mask[b, indices[b]] = False for all b
+                mask.scatter_(1, sim_super_patch_idxs, mask_values)
+
+                # Get the remaining tensor
+                remaining_tensor = windowed_tensor[
+                    mask.unsqueeze(-1).expand(-1, -1, num_token_window)
+                ].reshape(B, -1, num_token_window)
+                unm_idx = (
+                    remaining_tensor.reshape(B, -1).sort(dim=-1).values.unsqueeze(-1)
+                )
+                dim_index = num_token_window - 1
+                src_idx = gathered_tensor[:, :, :dim_index].reshape(B, -1).unsqueeze(-1)
+                dst_idx = gathered_tensor[:, :, dim_index].reshape(B, -1).unsqueeze(-1)
+                merge_idx = (
+                    torch.arange(src_idx.shape[1] // dim_index)
+                    .repeat_interleave(dim_index)
+                    .repeat(B, 1)
+                    .unsqueeze(-1)
+                    .to(feat.device)
+                )
+
+            def merge(x: torch.Tensor, mode='mean') -> torch.Tensor:
+                # r, num_token_window and the *_idx tensors are captured
+                # from the enclosing conditional_pooling scope
+
+                x_cls, x_feat = x[:, :1, :], x[:, 1:, :]
+                n, t1, c = x_feat.shape
+                src = x_feat.gather(dim=-2, index=src_idx.expand(n, r * dim_index, c))
+                dst = x_feat.gather(dim=-2, index=dst_idx.expand(n, r, c))
+                unm = x_feat.gather(
+                    dim=-2, index=unm_idx.expand(n, t1 - (r * num_token_window), c)
+                )
+                dst = dst.scatter_reduce(
+                    -2, merge_idx.expand(n, r * dim_index, c), src, reduce=mode
+                )
+                x = torch.cat([dst, unm], dim=1)
+                x = torch.cat((x_cls, x), dim=1)
+                return x
+
+            return merge
+
+        def merge_wavg(
+            merge: Callable, x: torch.Tensor, size: torch.Tensor = None
+        ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+            if size is None:
+                size = torch.ones_like(x[..., 0, None])
+
+            x = merge(x * size, mode='sum')
+            size = merge(size, mode='sum')
+            x = x / size
+
+            return x, size
+
+        def spatial_merge_hook(module, args, kwargs, pruning_paras):
+            spatial_threshold = pruning_paras['spatial_threshold']
+            window_size = pruning_paras['window_size']
+            hidden_states = args[0]
+            merge = conditional_pooling(hidden_states, spatial_threshold, window_size)
+            hidden_states, size = merge_wavg(merge, hidden_states, None)
+            return (hidden_states,) + args[1:], kwargs
+
+        self.model.set_modality('vision')
+        self.model.find_blocks()
+        self.model.blocks[1].register_forward_pre_hook(
+            functools.partial(spatial_merge_hook, pruning_paras=self.pruning_paras),
+            with_kwargs=True,
+        )
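Review note: below is a minimal, self-contained smoke test of the spatial merging in this patch. It treats conditional_pooling and merge_wavg as module-level functions (in the patch they are nested inside register_reduction_modules), and builds a synthetic input in which every 2x2 window holds four identical tokens, so every window clears the threshold and the 576 vision tokens collapse to 144. The batch size, hidden dim, and the 0.9 threshold are illustrative assumptions, not values from any shipped config; at runtime, spatial_threshold and window_size come from the module's special_config.

import torch

B, D = 2, 64
# One feature per 2x2 window, tiled so each window contains four identical tokens.
grid = torch.randn(B, 12, 12, D)
patches = grid.repeat_interleave(2, dim=1).repeat_interleave(2, dim=2)  # (B, 24, 24, D)
vision_tokens = patches.reshape(B, 24 * 24, D)                          # (B, 576, D)
cls_token = torch.randn(B, 1, D)
x = torch.cat([cls_token, vision_tokens], dim=1)                        # (B, 577, D)

# Every window has intra-window cosine similarity 1.0 >= 0.9, so r = 144 windows
# are selected and merge() pools each window's four tokens into one.
merge = conditional_pooling(x, threshold=0.9, window_size=(2, 2))
merged, size = merge_wavg(merge, x)
print(x.shape, '->', merged.shape)  # (2, 577, 64) -> (2, 145, 64): cls + 144 merged tokens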