divprune #409
Conversation
Summary of Changes
Hello @chengtao-lv, I'm Gemini Code Assist[^1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request integrates a new token reduction method, "DivPrune," into the llmc library. This method is designed to efficiently prune visual tokens in multimodal models, particularly LLaVA, by selecting a diverse subset of image features. The aim is to reduce computational load and improve inference efficiency without sacrificing model performance.
Highlights
- **New Token Reduction Method**: Introduces DivPrune, a novel token reduction technique that prunes visual tokens based on diversity, aiming to improve efficiency in multimodal models.
- **DivPrune Algorithm Implementation**: Implements the core `divprune` function, which selects a subset of visual feature vectors by maximizing diversity using pairwise cosine similarity. This includes a `divprune_post_hook` for applying the reduction to model inputs (see the sketch after this list).
- **LLaVA Model Integration**: Integrates DivPrune specifically with the LLaVA model by hooking into its `prepare_inputs_labels_for_multimodal` method to apply the token reduction to image features before model processing.
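The diff itself isn't visible in this thread, so here is a minimal sketch of the diversity-selection idea described above: greedy max-min selection over pairwise cosine similarity. The function names, the greedy seeding/selection strategy, and the `keep_ratio` parameter are illustrative assumptions, not the PR's actual code.

```python
import torch

def pairwise_cosine_similarity(matrix: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Row-normalize, then take inner products; eps guards zero-norm rows.
    norm_matrix = matrix / (matrix.norm(dim=1, keepdim=True) + eps)
    return norm_matrix @ norm_matrix.T

def divprune_sketch(visual_tokens: torch.Tensor, keep_ratio: float) -> torch.Tensor:
    """Greedy max-min selection (assumed strategy): each step keeps the token
    whose maximum similarity to the already-selected set is smallest."""
    n = visual_tokens.shape[0]
    k = max(1, int(n * keep_ratio))
    sim = pairwise_cosine_similarity(visual_tokens)  # (n, n)
    selected = [0]                                   # arbitrary seed token
    max_sim = sim[0].clone()                         # max similarity to selected set
    max_sim[0] = float('inf')                        # block re-selection
    for _ in range(k - 1):
        idx = int(torch.argmin(max_sim))             # most dissimilar remaining token
        selected.append(idx)
        max_sim = torch.maximum(max_sim, sim[idx])
        max_sim[idx] = float('inf')
    return torch.tensor(sorted(selected))            # keep original token order

# Usage sketch: keep 25% of LLaVA's visual tokens.
# keep_idx = divprune_sketch(image_features, keep_ratio=0.25)
# image_features = image_features[keep_idx]
```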
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a `.gemini/` folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
[^1]: Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution.
Code Review
This pull request introduces DivPrune, a new token reduction method based on selecting diverse visual tokens using cosine similarity. There are several areas for improvement regarding robustness, clarity, and adherence to best practices. Key issues include potential runtime errors from division by zero and incorrect function arguments, brittleness due to hardcoded batch and argument indices, and a local import. Addressing these points will significantly improve the code's quality and maintainability.
```python
    past_key_values,
    inputs_embeds,
    labels,
    pruning_paras=None,
```
The `pruning_paras` argument defaults to `None` but is accessed on line 65 without a check, which can raise a `TypeError` if the function is called without this argument. Removing the default value makes it a required argument, clarifying the function's dependencies.
```diff
-    pruning_paras=None,
+    pruning_paras,
```
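For context, a minimal reduction of the failure mode (the function shape here is hypothetical; only the `None` default and the subscript access mirror the PR):

```python
# Hypothetical hook: subscripting the None default raises
# "TypeError: 'NoneType' object is not subscriptable" at the access site.
def post_hook(inputs_embeds, pruning_paras=None):
    start = pruning_paras['image_token_start_index']  # TypeError when pruning_paras is None
    return inputs_embeds[:, start:]
```

Making `pruning_paras` required moves the error to the call site, where the missing dependency is obvious.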
```python
    SYS_TOKEN_LEN = pruning_paras['image_token_start_index']
    img_feature_len = pruning_paras['image_token_length']
    device = inputs_embeds.device
    visual_tokens = inputs_embeds[0][SYS_TOKEN_LEN: SYS_TOKEN_LEN + img_feature_len]
```
This line hardcodes the batch index `[0]`, assuming a batch size of 1. This assumption is fragile and will cause errors if the function is used with a batch size greater than 1. Add an assertion to enforce a batch size of 1.
```diff
+    assert inputs_embeds.shape[0] == 1, 'Batch size must be 1'
     visual_tokens = inputs_embeds[0][SYS_TOKEN_LEN: SYS_TOKEN_LEN + img_feature_len]
```
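A quick illustration of the failure the assertion guards against (the tensor shapes here are made up):

```python
import torch

inputs_embeds = torch.randn(2, 16, 8)   # hypothetical batch of 2 sequences
# inputs_embeds[0] would silently drop the second sample's tokens;
# the assertion turns that silent corruption into an immediate failure:
assert inputs_embeds.shape[0] == 1, 'Batch size must be 1'  # raises AssertionError
```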
```python
    token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX
    pruning_paras['image_token_start_index'] = torch.where(token_indices)[
        0
    ].item()
```
The logic here hardcodes the batch index `[0]` for `input_ids` and `attention_mask`, which assumes a batch size of 1. This will fail for multi-batch inputs. Add an assertion at the beginning of the wrapper to enforce a batch size of 1:

```python
assert input_ids.shape[0] == 1, 'Batch size must be 1'
assert attention_mask.shape[0] == 1, 'Batch size must be 1'
```
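For reference, a self-contained sketch of what the index lookup computes, assuming llava's `IMAGE_TOKEN_INDEX = -200` sentinel (the example tensors are made up):

```python
import torch

IMAGE_TOKEN_INDEX = -200  # llava's image-placeholder id (assumed here)

input_ids = torch.tensor([[1, 15, IMAGE_TOKEN_INDEX, 42, 7]])  # batch of 1
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)

token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX
print(torch.where(token_indices)[0].item())  # -> 2, the image token's position
```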
```python
def pairwise_cosine_similarity(matrix):
    norm_matrix = matrix / matrix.norm(dim=1, keepdim=True)
```
The calculation of `norm_matrix` can result in a division by zero if any vector in the input matrix has a norm of zero, leading to NaN values. Adding a small epsilon to the denominator ensures numerical stability.
```diff
-    norm_matrix = matrix / matrix.norm(dim=1, keepdim=True)
+    norm_matrix = matrix / (matrix.norm(dim=1, keepdim=True) + 1e-8)
```
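A tiny check of the failure and the fix (the example tensor is made up):

```python
import torch

m = torch.tensor([[0.0, 0.0], [3.0, 4.0]])
print(m / m.norm(dim=1, keepdim=True))           # first row: [nan, nan]
print(m / (m.norm(dim=1, keepdim=True) + 1e-8))  # first row: [0., 0.]
```

`torch.nn.functional.normalize(matrix, dim=1)` achieves the same guarded normalization via its built-in `eps` parameter.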
```python
    pruning_paras=None,
):
    rate = pruning_paras['rate']
    SYS_TOKEN_LEN = pruning_paras['image_token_start_index']
```
The variable name `SYS_TOKEN_LEN` is misleading, as it stores a start index, not a length. Renaming it to `image_token_start_index` would improve code clarity. Update its usages in this function as well.
```diff
-    SYS_TOKEN_LEN = pruning_paras['image_token_start_index']
+    image_token_start_index = pruning_paras['image_token_start_index']
```
```python
    input_ids = args[0]
    attention_mask = args[2]
```
```python
        return wrapper

    if self.model.__class__.__name__ == 'Llava':
        from llava.constants import IMAGE_TOKEN_INDEX
```
No description provided.