-
Notifications
You must be signed in to change notification settings - Fork 66
divprune #409
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
divprune #409
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,149 @@ | ||||||
| import functools | ||||||
| from functools import wraps | ||||||
| from types import MethodType | ||||||
|
|
||||||
| import torch | ||||||
|
|
||||||
| from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY | ||||||
|
|
||||||
| from .token_reduction_module import TokenReductionModule | ||||||
| from .utils import prefill_wrapper | ||||||
|
|
||||||
|
|
||||||
def pairwise_cosine_similarity(matrix):
    """Compute the pairwise cosine-similarity matrix of the rows of `matrix`.

    Args:
        matrix: 2-D tensor of shape (n, d); each row is a feature vector.

    Returns:
        (n, n) tensor whose (i, j) entry is the cosine similarity between
        rows i and j.
    """
    # Clamp the norm away from zero so an all-zero row yields 0 similarity
    # instead of NaN (division by zero).
    norm_matrix = matrix / matrix.norm(dim=1, keepdim=True).clamp_min(1e-8)
    cosine_similarity = torch.mm(norm_matrix, norm_matrix.t())
    return cosine_similarity
|
|
||||||
|
|
||||||
def divprune(
    visual_feature_vectors,
    image_feature_length,
    cosine_matrix=None,
    threshold_ratio=0.1,
):
    """Greedily select a diverse subset of visual tokens (max-min selection).

    At each step the token whose minimum cosine distance to the already
    selected set is largest gets added (farthest-point style greedy).

    Args:
        visual_feature_vectors: (n, d) tensor of visual token features.
        image_feature_length: number of image tokens n.
        cosine_matrix: optional precomputed (n, n) cosine-distance matrix;
            computed from the features when None.
        threshold_ratio: fraction of tokens to keep.

    Returns:
        Tuple of (selected token indices as a long tensor, the cosine
        distance matrix used).
    """
    num_keep = int(round(threshold_ratio * image_feature_length))
    if cosine_matrix is None:
        # Distance = 1 - cosine similarity.
        cosine_matrix = 1.0 - pairwise_cosine_similarity(visual_feature_vectors)

    selected = torch.empty(
        num_keep, dtype=torch.long, device=visual_feature_vectors.device
    )
    for step in range(num_keep):
        if step == 0:
            # Nothing selected yet: score each token by the distance to its
            # nearest neighbour (row 0 of the topk is the zero self-distance).
            scores = torch.topk(cosine_matrix, 2, dim=0, largest=False).values[1, :]
        else:
            # Distance from every token to its closest already-selected token.
            chosen_rows = cosine_matrix[selected[:step]]
            scores = torch.min(chosen_rows, dim=0).values
        # Add the token farthest from the current selection.
        selected[step] = torch.argmax(scores)
    return selected, cosine_matrix
|
|
||||||
|
|
||||||
def divprune_post_hook(
    input_ids,
    position_ids,
    attention_mask,
    past_key_values,
    inputs_embeds,
    labels,
    pruning_paras=None,
):
    """Prune image tokens from the prepared multimodal inputs with DivPrune.

    Keeps all tokens before the image span, the DivPrune-selected subset of
    image tokens, and all tokens after the image span; `position_ids` and
    `attention_mask` are sliced to stay aligned with `inputs_embeds`.

    Args:
        input_ids, past_key_values, labels: passed through unchanged.
        position_ids: optional position ids, sliced along the sequence axis.
        attention_mask: optional mask, sliced along the sequence axis.
        inputs_embeds: (1, seq_len, dim) embeddings containing the image span.
        pruning_paras: dict with 'rate' (keep ratio),
            'image_token_start_index' and 'image_token_length'.

    Returns:
        Tuple mirroring the input signature with the pruned sequences.

    Raises:
        ValueError: if `inputs_embeds` has a batch size other than 1.
    """
    rate = pruning_paras['rate']
    sys_token_len = pruning_paras['image_token_start_index']
    img_feature_len = pruning_paras['image_token_length']
    device = inputs_embeds.device

    # The slicing below hardcodes sample 0; fail loudly rather than silently
    # pruning the wrong tokens for batched inputs.
    if inputs_embeds.shape[0] != 1:
        raise ValueError('divprune_post_hook only supports batch size 1')

    visual_tokens = inputs_embeds[0][sys_token_len: sys_token_len + img_feature_len]
    selected_visual_tokens, _ = divprune(
        visual_tokens, img_feature_len, None, threshold_ratio=rate
    )

    # Map image-local indices back to absolute positions in the sequence.
    selected_visual_tokens += sys_token_len
    keep_indexs = torch.cat(
        (
            torch.arange(sys_token_len, device=device),  # prefix tokens
            selected_visual_tokens,  # kept image tokens
            torch.arange(  # suffix tokens
                sys_token_len + img_feature_len, inputs_embeds.shape[1], device=device
            ),
        )
    )
    keep_indexs = keep_indexs.sort().values  # restore original token order

    inputs_embeds = inputs_embeds[:, keep_indexs]
    if position_ids is not None:
        # NOTE(review): this assumes position_ids is 3-D — confirm against
        # the caller; a 2-D (batch, seq) tensor would fail here.
        position_ids = position_ids[:, keep_indexs, :]
    if attention_mask is not None:
        attention_mask = attention_mask[:, keep_indexs]

    return (
        input_ids,
        position_ids,
        attention_mask,
        past_key_values,
        inputs_embeds,
        labels,
    )
|
|
||||||
|
|
||||||
@TOKEN_REDUCTION_REGISTRY.register('DivPrune')
class DivPrune(TokenReductionModule):
    """Token-reduction module that applies DivPrune to LLaVA-style models.

    Wraps the model's multimodal input-preparation method so that, during
    prefill, a diverse subset of image tokens is kept and the rest pruned.
    """

    def __init__(self, config, model, blocks):
        super().__init__(config, model, blocks)
        self.add_sparse_config()
        self.register_reduction_modules()

    def add_sparse_config(self):
        """Populate pruning parameters from the model's pruning config."""
        # Number of image tokens the model produces per image.
        self.special_config['image_token_length'] = self.model.pruning_config[
            'image_token_length'
        ]
        # Alias: hooks read and write pruning parameters through this dict.
        self.pruning_paras = self.special_config

    def register_reduction_modules(self):
        """Install the DivPrune hook on supported model types."""

        def input_hook_llava(fn, pruning_paras):
            # Wrap `fn` (the bound prepare_inputs_labels_for_multimodal) so
            # its outputs are pruned by divprune_post_hook during prefill.
            @wraps(fn)
            def wrapper(self, *args, **kwargs):
                if len(args) == 0:
                    return fn(*args, **kwargs)
                input_args = args[0]
                # Single-token input (decode step): nothing to prune.
                if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
                    return fn(*args, **kwargs)

                input_ids = args[0]
                attention_mask = args[2]
                # Find the image placeholder position among unmasked tokens.
                # NOTE(review): batch index 0 is hardcoded (assumes batch
                # size 1), and .item() assumes exactly one image token —
                # confirm against callers.
                token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX
                pruning_paras['image_token_start_index'] = torch.where(token_indices)[
                    0
                ].item()

                outputs = fn(*args, **kwargs)

                # Prune the prepared embeddings before they reach the LLM.
                return divprune_post_hook(*outputs, pruning_paras=pruning_paras)

            return wrapper

        if self.model.__class__.__name__ == 'Llava':
            # Imported lazily so non-LLaVA setups need not have llava
            # installed; the wrapper closes over IMAGE_TOKEN_INDEX.
            from llava.constants import IMAGE_TOKEN_INDEX

            # Rebind the preparation method with the pruning wrapper;
            # MethodType keeps it bound to the vlm_model instance.
            hook_fn = input_hook_llava(
                self.model.vlm_model.prepare_inputs_labels_for_multimodal,
                self.pruning_paras,
            )
            self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
                hook_fn, self.model.vlm_model
            )
Reviewer note: the computation of `norm_matrix` in `pairwise_cosine_similarity` can divide by zero when any input vector has a zero norm, producing `NaN` values; adding a small epsilon (or clamping) in the denominator ensures numerical stability.