-
Notifications
You must be signed in to change notification settings - Fork 66
dycoke #388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
dycoke #388
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| base: | ||
| seed: &seed 42 | ||
| model: | ||
| type: Llava OneVision | ||
| path: model path | ||
| torch_dtype: auto | ||
| eval: | ||
| eval_pos: [pretrain, transformed] | ||
| type: vqa | ||
| name: [mme] | ||
| download: False | ||
| path: MME dataset path | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| bs: 1 | ||
| inference_per_block: False | ||
| sparse: | ||
| method: TokenReduction | ||
| special: | ||
| method: DyCoke | ||
| dycoke_layer_idx: 3 | ||
| num_tokens_per_frame: 196 | ||
| merging_ratio: 0.7 | ||
| dycoke_radio: 0.7 | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's a parameter Additionally, the Python code ( Could you clarify:
|
||
| save: | ||
| save_trans: False | ||
| save_fake: False | ||
| save_path: /path/to/save/ | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,132 @@ | ||||||
| import functools | ||||||
| from typing import List, Optional, Tuple, Union | ||||||
|
|
||||||
| import torch | ||||||
| import torch.nn.functional as F | ||||||
| from loguru import logger | ||||||
|
|
||||||
| try: | ||||||
| from llava.model.llava_arch import LlavaMetaForCausalLM | ||||||
| except ModuleNotFoundError: | ||||||
| logger.info('LlavaMetaForCausalLM not found, if need, please install llava first.') | ||||||
| from transformers.cache_utils import Cache, DynamicCache | ||||||
|
|
||||||
| from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY | ||||||
|
|
||||||
| from .token_reduction_module import TokenReductionModule | ||||||
| from .utils import prefill_wrapper | ||||||
|
|
||||||
|
|
||||||
| def dycole_ttm(image_feature, pruning_paras): | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There seems to be a typo in the function name. Should
Suggested change
|
||||||
| bs, num_tokens_per_frame, _ = image_feature.shape | ||||||
| image_feature = image_feature.flatten(0, 1) | ||||||
| # Split frames into tokens | ||||||
| num_frames = image_feature.shape[0] // num_tokens_per_frame | ||||||
| merging_ratio = 1 - pruning_paras['merging_ratio'] | ||||||
| # Calculate similarities between adjacent even frames | ||||||
| similarities = [] | ||||||
| for i in range(0, num_frames - 1, 2): | ||||||
| # Get tokens for adjacent frames | ||||||
| frame1_tokens = image_feature[ | ||||||
| i * num_tokens_per_frame: (i + 1) * num_tokens_per_frame | ||||||
| ] | ||||||
| frame2_tokens = image_feature[ | ||||||
| (i + 1) * num_tokens_per_frame: (i + 2) * num_tokens_per_frame | ||||||
| ] | ||||||
|
|
||||||
| # Calculate cosine similarity between normalized tokens | ||||||
| frame1_norm = torch.nn.functional.normalize(frame1_tokens, p=2, dim=1) | ||||||
| frame2_norm = torch.nn.functional.normalize(frame2_tokens, p=2, dim=1) | ||||||
| similarity = torch.nn.functional.cosine_similarity( | ||||||
| frame1_norm, frame2_norm, dim=1 | ||||||
| ) | ||||||
| similarities.append(similarity) | ||||||
|
|
||||||
| similarities = torch.stack( | ||||||
| [torch.tensor(similarity) for similarity in similarities] | ||||||
| ) | ||||||
|
|
||||||
| # Process even frames | ||||||
| modified_image_feature = [] | ||||||
| for i in range(0, num_frames - 1, 2): | ||||||
| frame1_tokens = image_feature[ | ||||||
| i * num_tokens_per_frame: (i + 1) * num_tokens_per_frame | ||||||
| ] | ||||||
| frame2_tokens = image_feature[ | ||||||
| (i + 1) * num_tokens_per_frame: (i + 2) * num_tokens_per_frame | ||||||
| ] | ||||||
|
|
||||||
| avg_similarity = similarities[i // 2] | ||||||
| num_tokens_to_keep = int(merging_ratio * num_tokens_per_frame) | ||||||
| tokens_to_keep = avg_similarity.topk(num_tokens_to_keep, largest=False).indices | ||||||
|
|
||||||
| modified_image_feature.append(frame1_tokens) | ||||||
| modified_image_feature.append(frame2_tokens[tokens_to_keep]) | ||||||
|
Comment on lines
+28
to
+64
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The first loop Could you review the logic for handling an odd number of frames? One common approach is to append the last frame unprocessed if A docstring explaining the overall algorithm, inputs, and outputs of this function would also be very beneficial for future understanding and maintenance. |
||||||
|
|
||||||
| # Process odd frames | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The comment Perhaps a more descriptive comment like |
||||||
| odd_similarities = [] | ||||||
| for i in range(0, num_frames - 4, 4): | ||||||
| frame1_tokens = image_feature[ | ||||||
| i * num_tokens_per_frame: (i + 1) * num_tokens_per_frame | ||||||
| ] | ||||||
| frame2_tokens = image_feature[ | ||||||
| (i + 2) * num_tokens_per_frame: (i + 3) * num_tokens_per_frame | ||||||
| ] | ||||||
|
|
||||||
| similarity = torch.nn.functional.cosine_similarity( | ||||||
| frame1_tokens, frame2_tokens, dim=1 | ||||||
| ) | ||||||
| odd_similarities.append(similarity) | ||||||
|
|
||||||
| odd_similarities = torch.stack( | ||||||
| [torch.tensor(similarity) for similarity in odd_similarities] | ||||||
| ) | ||||||
|
|
||||||
| for i in range(0, num_frames - 4, 4): | ||||||
| frame1_tokens = image_feature[ | ||||||
| i * num_tokens_per_frame: (i + 1) * num_tokens_per_frame | ||||||
| ] | ||||||
| frame2_tokens = image_feature[ | ||||||
| (i + 2) * num_tokens_per_frame: (i + 3) * num_tokens_per_frame | ||||||
| ] | ||||||
|
|
||||||
| avg_similarity = odd_similarities[i // 4] | ||||||
| num_tokens_to_keep = int(merging_ratio * num_tokens_per_frame) | ||||||
| tokens_to_keep = avg_similarity.topk(num_tokens_to_keep, largest=False).indices | ||||||
|
|
||||||
| modified_image_feature[i] = frame1_tokens | ||||||
| modified_image_feature[i + 2] = frame2_tokens[tokens_to_keep] | ||||||
|
|
||||||
| # Combine all tokens | ||||||
| combined_tokens = torch.cat(modified_image_feature, dim=0).unsqueeze(0) | ||||||
| return combined_tokens | ||||||
|
|
||||||
|
|
||||||
| def add_dycole_ttm_to_get_2dPool(model, post_hook_fn, pruning_paras): | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Typo in function name:
Suggested change
|
||||||
| original_fn = model.get_2dPool | ||||||
|
|
||||||
| def wrapped_fn(*args, **kwargs): | ||||||
| result = original_fn(*args, **kwargs) | ||||||
| return post_hook_fn(result, pruning_paras) | ||||||
|
|
||||||
| model.get_2dPool = wrapped_fn | ||||||
|
|
||||||
|
|
||||||
| @TOKEN_REDUCTION_REGISTRY.register('DyCoke') | ||||||
| class DyCoke(TokenReductionModule): | ||||||
| def __init__(self, config, model, blocks): | ||||||
| super().__init__(config, model, blocks) | ||||||
| self.add_sparse_config() | ||||||
| self.register_reduction_modules() | ||||||
|
|
||||||
| def add_sparse_config(self): | ||||||
| self.special_config['different_token_idxs'] = [] | ||||||
| self.dycoke_layer_idx = self.special_config['dycoke_layer_idx'] | ||||||
| self.model.model.pruning_paras = self.special_config | ||||||
|
|
||||||
| def register_reduction_modules(self): | ||||||
|
|
||||||
| if isinstance(self.model.model, LlavaMetaForCausalLM): | ||||||
| add_dycole_ttm_to_get_2dPool( | ||||||
| self.model.model, dycole_ttm, self.model.model.pruning_paras | ||||||
|
Comment on lines
+130
to
+131
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||
| ) | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The model path is set to a placeholder
model path. Is this intended to be a generic template? If so, it might be helpful to add a comment indicating that this needs to be replaced with an actual path. If a default or example path could be provided, that would be even better for usability.