import functools

import torch

from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

from .token_reduction_module import TokenReductionModule


@TOKEN_REDUCTION_REGISTRY.register('VisPruner')
class VisPruner(TokenReductionModule):
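    """Prune visual tokens by keeping the patches that the vision encoder's
    [CLS] token attends to most, plus a diverse remainder obtained by
    iteratively dropping near-duplicate patch features."""
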
    def __init__(self, config, model, blocks):
        super().__init__(config, model, blocks)
        self.add_sparse_config()
        self.register_reduction_modules()

    def add_sparse_config(self):
        self.special_config['select_layer'] = self.model.pruning_config.get(
            'select_layer', -1
        )
        self.special_config['select_feature'] = self.model.pruning_config.get(
            'select_feature', None
        )

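        # The hooks registered below share this dict: it carries the pruning
        # config and is also used to pass cached attentions and index masks
        # from one hook to the next.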
        self.pruning_paras = self.special_config

    def register_reduction_modules(self):

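        # Hook pipeline, in forward order:
        #   1. force the vision tower to return attention maps,
        #   2. cache the selected layer's [CLS]-to-patch attentions,
        #   3. build a boolean keep-mask before the projector runs,
        #   4. drop the masked-out tokens from the projector output.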
        def update_output_attentions_hook(module, args, kwargs):
            kwargs['output_attentions'] = True
            return args, kwargs

        def store_attention_hook(module, inps, outs, pruning_paras):
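            # outs.attentions is a per-layer tuple of (B, heads, seq, seq)
            # maps; [:, :, 0, 1:] keeps the [CLS] query's attention over the
            # patch tokens.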
            image_attentions = outs.attentions[pruning_paras['select_layer']]
            if pruning_paras['select_feature'] == 'patch':
                image_attentions = image_attentions[:, :, 0, 1:]
            elif pruning_paras['select_feature'] == 'cls_patch':
                image_attentions = image_attentions
            else:
                raise ValueError(
                    f"Unexpected select feature: {pruning_paras['select_feature']}"
                )

            pruning_paras['image_attentions'] = image_attentions.to(inps[0].dtype)

        def get_index_masks_hook(module, args, pruning_paras):
            image_features = args[0]
            image_attentions = pruning_paras['image_attentions']

            B, N, C = image_features.shape
            device = image_features.device
            index_masks = torch.ones(B, N, dtype=torch.bool, device=device)

            visual_token_num = round(
                self.special_config['vision_token_length'] * (
                    1 - self.special_config['prune_ratio']
                )
            )  # T
            important_ratio = self.pruning_paras['important_ratio']  # r
            important_token_num = int(visual_token_num * important_ratio)  # T_imp = T * r
            diverse_token_num = visual_token_num - important_token_num  # T_div = T * (1 - r)

            # [VisPruner] Select important tokens using attention scores
            image_attentions = image_attentions.mean(dim=1)  # (B, N)
            token_indices = image_attentions.argsort(dim=-1, descending=True)  # (B, N)
            important_indices = token_indices[:, :important_token_num]  # (B, T_imp)
            residual_indices = token_indices[:, important_token_num:]  # (B, N - T_imp)

            # [VisPruner] Remove duplicate tokens by iterative matching and pruning
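            # Each pass splits the remaining residual tokens into alternating
            # halves, scores each even-indexed token by its highest cosine
            # similarity to the odd half, and drops up to r of the most
            # redundant ones, until only T_div residual tokens are left.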
            image_normalized = image_features / image_features.norm(dim=-1, keepdim=True)
            while diverse_token_num > 0:
                R = residual_indices.shape[1]
                r = min(8, R - diverse_token_num)
                if r <= 0:
                    break

                residual_tokens = image_normalized[
                    torch.arange(B).unsqueeze(-1).expand(-1, R),
                    residual_indices
                ]  # (B, R, C)
                a, b = residual_tokens[..., ::2, :], residual_tokens[..., 1::2, :]  # (B, R // 2, C)
                scores = a @ b.transpose(-1, -2)  # (B, R // 2, R // 2)
                scores = scores.max(dim=-1).values  # (B, R // 2)

                distinct_indices = scores.argsort(dim=-1, descending=True)[:, r:]  # (B, R // 2 - r)
                residual_indices = torch.cat([
                    residual_indices[..., ::2][
                        torch.arange(B).unsqueeze(-1).expand(-1, distinct_indices.shape[1]),
                        distinct_indices
                    ],
                    residual_indices[..., 1::2]
                ], dim=-1)  # (B, R - r)

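            # Only T_div residual tokens survive the loop; merge them with the
            # attention-ranked important tokens to form the final keep set.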
            if diverse_token_num > 0:
                selected_indices = torch.cat([important_indices, residual_indices], dim=-1)
            else:
                selected_indices = important_indices  # (B, T)
            index_masks = torch.zeros(B, N, dtype=torch.bool, device=device)
            index_masks.scatter_(1, selected_indices, True)

            pruning_paras['index_masks'] = index_masks

        def prune_hook(module, inputs, outputs, pruning_paras):
            image_features = outputs
            index_masks = pruning_paras['index_masks']
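            # Returning a value from a forward hook replaces the module's
            # output. Boolean-mask indexing flattens the batch dimension, so
            # the unsqueeze(0) assumes a single image per forward pass.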
            return image_features[index_masks].unsqueeze(0)

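        # Wire the hooks up: the vision tower gets the attention hooks, the
        # multimodal projector gets the masking and pruning hooks.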
        self.model.vision_model.vision_tower.register_forward_pre_hook(
            update_output_attentions_hook,
            with_kwargs=True
        )

        self.model.vision_model.vision_tower.register_forward_hook(
            functools.partial(store_attention_hook, pruning_paras=self.pruning_paras),
        )

        self.model.vision_projector.register_forward_pre_hook(
            functools.partial(get_index_masks_hook, pruning_paras=self.pruning_paras),
        )

        self.model.vision_projector.register_forward_hook(
            functools.partial(prune_hook, pruning_paras=self.pruning_paras),
        )
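The selection logic in get_index_masks_hook can be exercised on its own. Below is a minimal, self-contained sketch on dummy tensors; the shapes, the random attn stand-in for the cached [CLS]-to-patch attentions, and the budget values are illustrative assumptions, not part of llmc. It keeps the top-attention "important" tokens and fills the rest of the budget with "diverse" tokens by repeatedly dropping the most redundant residuals.

import torch

B, N, C, heads = 1, 16, 8, 4       # batch, patch tokens, feature dim, attention heads
T, important_ratio = 8, 0.5        # kept-token budget T and its "important" share r
feats = torch.randn(B, N, C)
attn = torch.rand(B, heads, N)     # stand-in for the cached [CLS]-to-patch attentions

t_imp = int(T * important_ratio)
t_div = T - t_imp

# Rank tokens by mean attention and split into important / residual sets.
order = attn.mean(dim=1).argsort(dim=-1, descending=True)        # (B, N)
important, residual = order[:, :t_imp], order[:, t_imp:]

# Iteratively drop the most redundant residual tokens (highest cosine similarity).
normed = feats / feats.norm(dim=-1, keepdim=True)
while residual.shape[1] > t_div:
    R = residual.shape[1]
    r = min(8, R - t_div)
    tok = normed[torch.arange(B).unsqueeze(-1).expand(-1, R), residual]
    a, b = tok[..., ::2, :], tok[..., 1::2, :]
    redundancy = (a @ b.transpose(-1, -2)).max(dim=-1).values    # per even-indexed token
    keep = redundancy.argsort(dim=-1, descending=True)[:, r:]    # drop the r most redundant
    residual = torch.cat([
        residual[..., ::2][torch.arange(B).unsqueeze(-1).expand(-1, keep.shape[1]), keep],
        residual[..., 1::2],
    ], dim=-1)

kept = torch.cat([important, residual], dim=-1)                  # (B, T) indices to keep
mask = torch.zeros(B, N, dtype=torch.bool)
mask.scatter_(1, kept, True)
print(f'kept {mask.sum().item()} of {N} visual tokens')

Run as-is, this prints that exactly T of the N dummy tokens survive, mirroring how the hook builds its boolean keep-mask before the projector output is pruned.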