-
Notifications
You must be signed in to change notification settings - Fork 66
fastv,fastervlm,visionzip for llava1.5 #393
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -24,28 +24,39 @@ def add_sparse_config(self): | |||||||||||
| special_config['select_feature'] = self.model.pruning_config['select_feature'] | ||||||||||||
| special_config['image_token_index'] = self.model.pruning_config['image_token_index'] | ||||||||||||
|
|
||||||||||||
| self.model.model.parameters = special_config | ||||||||||||
| special_config['image_attentions_list'] = [] | ||||||||||||
|
|
||||||||||||
| self.pruning_paras = special_config | ||||||||||||
|
|
||||||||||||
| def register_reduction_modules(self): | ||||||||||||
|
|
||||||||||||
| def update_output_attentions_hook(module, args, kwargs): | ||||||||||||
| kwargs['output_attentions'] = True | ||||||||||||
| return args, kwargs | ||||||||||||
|
|
||||||||||||
| def store_attention_hook(m, x, image_forward_outs, pruning_pars): | ||||||||||||
| image_attentions = image_forward_outs.attentions[pruning_pars['select_layer']] | ||||||||||||
| if pruning_pars['select_feature'] == 'default': # patch | ||||||||||||
| image_attentions = image_attentions[:, :, 0, 1:] | ||||||||||||
| elif pruning_pars['select_feature'] == 'full': | ||||||||||||
| image_attentions = image_attentions | ||||||||||||
| def clear_attentions_hook(m, x, pruning_paras): | ||||||||||||
| pruning_paras['image_attentions_list'].clear() | ||||||||||||
|
|
||||||||||||
| def store_attention_hook(m, x, image_forward_outs, pruning_paras): | ||||||||||||
| image_attentions = image_forward_outs.attentions[pruning_paras['select_layer']] | ||||||||||||
| if pruning_paras['select_feature'] in ('default', 'patch'): | ||||||||||||
| image_attention = image_attentions[:, :, 0, 1:] | ||||||||||||
| elif pruning_paras['select_feature'] in ('full', 'cls_patch'): | ||||||||||||
| image_attention = image_attentions | ||||||||||||
| else: | ||||||||||||
| raise ValueError(f'Unexpected select feature: {self.select_feature}') | ||||||||||||
| pruning_pars['image_attentions'] = image_attentions | ||||||||||||
| pruning_paras['image_attentions_list'].append(image_attention.to(x[0].dtype)) | ||||||||||||
|
|
||||||||||||
| def update_attentions_hook(m, x, outs, pruning_paras): | ||||||||||||
| if len(pruning_paras['image_attentions_list']) == 1: | ||||||||||||
| pruning_paras['image_attentions'] = pruning_paras['image_attentions_list'][0] | ||||||||||||
| else: | ||||||||||||
| pruning_paras['image_attentions'] = pruning_paras['image_attentions_list'] | ||||||||||||
|
Comment on lines
+53
to
+54
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If
Suggested change
|
||||||||||||
|
|
||||||||||||
| def pruning_hook(module, args, kwargs, pruning_pars): | ||||||||||||
| def pruning_hook(module, args, kwargs, pruning_paras): | ||||||||||||
|
|
||||||||||||
| image_features = args[0] | ||||||||||||
| image_attentions = pruning_pars['image_attentions'] | ||||||||||||
| image_attentions = pruning_paras['image_attentions'] | ||||||||||||
|
|
||||||||||||
| # image_attentions = image_attentions.max(dim=1)[0] # (B, N) = (1, 576) | ||||||||||||
| image_attentions = image_attentions.mean(dim=1) # (B, N) = (1, 576) | ||||||||||||
|
|
@@ -66,22 +77,22 @@ def pruning_hook(module, args, kwargs, pruning_pars): | |||||||||||
| index_mask = torch.zeros(B, N, dtype=torch.bool, device=image_features.device) # (B, N) | ||||||||||||
| index_mask.scatter_(1, token_indices, True) # (B, N) | ||||||||||||
|
|
||||||||||||
| pruning_pars['index_mask'] = index_mask | ||||||||||||
| pruning_pars['image_attentions'] = image_attentions | ||||||||||||
| pruning_paras['index_mask'] = index_mask | ||||||||||||
| pruning_paras['image_attentions'] = image_attentions | ||||||||||||
|
|
||||||||||||
| return (image_features,), kwargs | ||||||||||||
|
|
||||||||||||
| def get_image_mask_hook(module, args, kwargs, pruning_pars): | ||||||||||||
| pruning_pars['image_mask'] = ( | ||||||||||||
| kwargs['input_ids'] == pruning_pars['image_token_index'] | ||||||||||||
| def get_image_mask_hook(module, args, kwargs, pruning_paras): | ||||||||||||
| pruning_paras['image_mask'] = ( | ||||||||||||
| kwargs['input_ids'] == pruning_paras['image_token_index'] | ||||||||||||
| ) # (B, len) | ||||||||||||
|
|
||||||||||||
| def prepare_inputs_for_llm_hook(module, args, kwargs, pruning_pars): | ||||||||||||
| def prepare_inputs_for_llm_hook(module, args, kwargs, pruning_paras): | ||||||||||||
|
|
||||||||||||
| # Only batch size 1 is currently supported. | ||||||||||||
| inputs_embeds = kwargs['inputs_embeds'] | ||||||||||||
| image_mask = pruning_pars['image_mask'][0] | ||||||||||||
| index_mask = pruning_pars['index_mask'][0] | ||||||||||||
| image_mask = pruning_paras['image_mask'][0] | ||||||||||||
| index_mask = pruning_paras['index_mask'][0] | ||||||||||||
|
|
||||||||||||
| B, L = inputs_embeds.shape[:2] | ||||||||||||
| device = inputs_embeds.device | ||||||||||||
|
|
@@ -109,28 +120,67 @@ def prepare_inputs_for_llm_hook(module, args, kwargs, pruning_pars): | |||||||||||
|
|
||||||||||||
| return args, kwargs | ||||||||||||
|
|
||||||||||||
| self.model.vision_model.register_forward_pre_hook( | ||||||||||||
| update_output_attentions_hook, | ||||||||||||
| with_kwargs=True | ||||||||||||
| ) | ||||||||||||
| def prepare_inputs_hook(module, inputs, outputs, pruning_paras): | ||||||||||||
|
|
||||||||||||
| self.model.vision_model.register_forward_hook( | ||||||||||||
| functools.partial(store_attention_hook, pruning_pars=self.model.model.parameters), | ||||||||||||
| ) | ||||||||||||
| image_features = outputs | ||||||||||||
| index_masks = pruning_paras['index_mask'] | ||||||||||||
| # image_attentions = pruning_paras['image_attentions'] | ||||||||||||
| new_image_features = [] | ||||||||||||
| for image_feature, index_mask in zip(image_features, index_masks): | ||||||||||||
| image_feature = image_feature[index_mask] | ||||||||||||
| new_image_features.append(image_feature) | ||||||||||||
| image_features = torch.stack(new_image_features, dim=0) | ||||||||||||
|
|
||||||||||||
| outputs = image_features | ||||||||||||
| pruning_paras['image_features_shape'] = image_features[0].shape[0] | ||||||||||||
|
|
||||||||||||
| return outputs | ||||||||||||
|
|
||||||||||||
| if self.model.__class__.__name__ == 'LlavaHf': | ||||||||||||
| self.model.vision_model.register_forward_pre_hook( | ||||||||||||
| update_output_attentions_hook, | ||||||||||||
| with_kwargs=True | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| self.model.vision_model.register_forward_hook( | ||||||||||||
| functools.partial(store_attention_hook, pruning_paras=self.pruning_paras), | ||||||||||||
| ) | ||||||||||||
| elif self.model.__class__.__name__ == 'Llava': | ||||||||||||
| self.model.vision_model.register_forward_pre_hook( | ||||||||||||
| functools.partial(clear_attentions_hook, pruning_paras=self.pruning_paras), | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| self.model.vision_model.register_forward_hook( | ||||||||||||
| functools.partial(update_attentions_hook, pruning_paras=self.pruning_paras), | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| self.model.vision_model.vision_tower.register_forward_pre_hook( | ||||||||||||
| update_output_attentions_hook, | ||||||||||||
| with_kwargs=True | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| self.model.vision_model.vision_tower.register_forward_hook( | ||||||||||||
| functools.partial(store_attention_hook, pruning_paras=self.pruning_paras), | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| self.model.vision_projector.register_forward_pre_hook( | ||||||||||||
| functools.partial(pruning_hook, pruning_pars=self.model.model.parameters), | ||||||||||||
| functools.partial(pruning_hook, pruning_paras=self.pruning_paras), | ||||||||||||
| with_kwargs=True | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| self.model.vlm_model.register_forward_pre_hook( | ||||||||||||
| functools.partial(get_image_mask_hook, pruning_pars=self.model.model.parameters), | ||||||||||||
| with_kwargs=True | ||||||||||||
| ) | ||||||||||||
| if self.model.__class__.__name__ == 'LlavaHf': | ||||||||||||
| self.model.vlm_model.register_forward_pre_hook( | ||||||||||||
| functools.partial(get_image_mask_hook, pruning_paras=self.pruning_paras), | ||||||||||||
| with_kwargs=True | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| self.model.model.register_forward_pre_hook( | ||||||||||||
| functools.partial( | ||||||||||||
| prepare_inputs_for_llm_hook, pruning_pars=self.model.model.parameters | ||||||||||||
| ), | ||||||||||||
| with_kwargs=True | ||||||||||||
| ) | ||||||||||||
| self.model.model.register_forward_pre_hook( | ||||||||||||
| functools.partial( | ||||||||||||
| prepare_inputs_for_llm_hook, pruning_paras=self.pruning_paras | ||||||||||||
| ), | ||||||||||||
| with_kwargs=True | ||||||||||||
| ) | ||||||||||||
| elif self.model.__class__.__name__ == 'Llava': | ||||||||||||
| self.model.vision_projector.register_forward_hook( | ||||||||||||
| functools.partial(prepare_inputs_hook, pruning_paras=self.pruning_paras), | ||||||||||||
| ) | ||||||||||||
|
Comment on lines
+139
to
+186
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| import functools | ||
| from types import MethodType | ||
|
|
||
| import torch | ||
|
|
||
|
|
@@ -8,6 +9,8 @@ | |
| from .token_reduction_module import TokenReductionModule | ||
| from .utils import prefill_wrapper | ||
|
|
||
| IMAGE_TOKEN_INDEX = -200 | ||
|
|
||
|
|
||
| @TOKEN_REDUCTION_REGISTRY.register('FastV') | ||
| class FastV(TokenReductionModule): | ||
|
|
@@ -23,41 +26,61 @@ def add_sparse_config(self): | |
| self.model.pruning_config['image_token_length'] | ||
| self.special_config['attn_scores'] = None | ||
|
|
||
| self.model.model.parameters = self.special_config | ||
| self.pruning_paras = self.special_config | ||
|
|
||
| def register_reduction_modules(self): | ||
|
|
||
| @prefill_wrapper | ||
| def input_hook(module, input_args, pruning_pars): | ||
| def input_hook(module, input_args, pruning_paras): | ||
| input_ids = input_args[0] | ||
| image_token_idxs = (input_ids[0] == | ||
| pruning_pars['vision_token_index']).nonzero(as_tuple=True)[0] | ||
| pruning_pars['image_token_start_index'] = image_token_idxs[0].item() | ||
| pruning_paras['vision_token_index']).nonzero(as_tuple=True)[0] | ||
| pruning_paras['image_token_start_index'] = image_token_idxs[0].item() | ||
|
|
||
| return input_args | ||
|
|
||
| def make_hook_prepare_inputs_labels_for_multimodal(pruning_paras): | ||
| def hook_prepare_inputs_labels_for_multimodal( | ||
| self, | ||
| input_ids, | ||
| position_ids, | ||
| attention_mask, | ||
| past_key_values, | ||
| labels, | ||
| images, | ||
| image_sizes | ||
| ): | ||
| if 'image_token_start_index' not in pruning_paras: | ||
| token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX | ||
| pruning_paras['image_token_start_index'] = torch.where(token_indices)[0].item() | ||
| return self._original_prepare_inputs_labels_for_multimodal( | ||
| input_ids, position_ids, attention_mask, | ||
| past_key_values, labels, images, image_sizes | ||
| ) | ||
| return hook_prepare_inputs_labels_for_multimodal | ||
|
|
||
| def update_output_attentions_hook(module, args, kwargs): | ||
| kwargs['output_attentions'] = True | ||
| return args, kwargs | ||
|
|
||
| def store_attention_hook(m, x, layer_outputs, pruning_pars): | ||
| def store_attention_hook(m, x, layer_outputs, pruning_paras): | ||
| layer_attention = layer_outputs[1] | ||
| pruning_pars['attn_scores'] = layer_attention | ||
| pruning_paras['attn_scores'] = layer_attention | ||
|
|
||
| @prefill_wrapper | ||
| def fastv_pruning_hook(module, args, kwargs, pruning_pars): | ||
| def fastv_pruning_hook(module, args, kwargs, pruning_paras): | ||
|
|
||
| rate = pruning_pars['rate'] | ||
| image_token_start_index = pruning_pars['image_token_start_index'] | ||
| image_token_length = pruning_pars['image_token_length'] | ||
| rate = pruning_paras['rate'] | ||
| image_token_start_index = pruning_paras['image_token_start_index'] | ||
| image_token_length = pruning_paras['image_token_length'] | ||
|
|
||
| hidden_states = args[0] | ||
| causal_mask = kwargs['attention_mask'] | ||
| cache_position = kwargs['cache_position'] | ||
|
|
||
| device = hidden_states.device | ||
| # last_layer_attention = layer_outputs[1] | ||
| last_layer_attention = pruning_pars['attn_scores'] | ||
| last_layer_attention = pruning_paras['attn_scores'] | ||
| # compute average attention over different head | ||
| last_layer_attention_avg = torch.mean(last_layer_attention, dim=1)[0] | ||
| # generate new attention mask based on the average attention, | ||
|
|
@@ -98,42 +121,51 @@ def fastv_pruning_hook(module, args, kwargs, pruning_pars): | |
| kwargs['cache_position'] = cache_position[:new_seq_length] | ||
| kwargs['position_ids'] = position_ids | ||
| kwargs['position_embeddings'] = None | ||
| pruning_pars['attention_mask'] = causal_mask | ||
| pruning_pars['cache_position'] = cache_position[:new_seq_length] | ||
| pruning_pars['position_ids'] = position_ids | ||
| pruning_pars['position_embeddings'] = None | ||
| pruning_paras['attention_mask'] = causal_mask | ||
| pruning_paras['cache_position'] = cache_position[:new_seq_length] | ||
| pruning_paras['position_ids'] = position_ids | ||
| pruning_paras['position_embeddings'] = None | ||
|
|
||
| return (hidden_states,), kwargs | ||
|
|
||
| @prefill_wrapper | ||
| def read_parameter_hook(module, args, kwargs, pruning_pars): | ||
| kwargs['attention_mask'] = pruning_pars['attention_mask'] | ||
| kwargs['cache_position'] = pruning_pars['cache_position'] | ||
| kwargs['position_ids'] = pruning_pars['position_ids'] | ||
| kwargs['position_embeddings'] = pruning_pars['position_embeddings'] | ||
| def read_parameter_hook(module, args, kwargs, pruning_paras): | ||
| kwargs['attention_mask'] = pruning_paras['attention_mask'] | ||
| kwargs['cache_position'] = pruning_paras['cache_position'] | ||
| kwargs['position_ids'] = pruning_paras['position_ids'] | ||
| kwargs['position_embeddings'] = pruning_paras['position_embeddings'] | ||
|
|
||
| return args, kwargs | ||
|
|
||
| self.model.embed_tokens.register_forward_pre_hook( | ||
| functools.partial(input_hook, pruning_pars=self.model.model.parameters) | ||
| ) | ||
| if self.model.__class__.__name__ == 'LlavaHf': | ||
| self.model.embed_tokens.register_forward_pre_hook( | ||
| functools.partial(input_hook, pruning_paras=self.pruning_paras) | ||
| ) | ||
| elif self.model.__class__.__name__ == 'Llava': | ||
| hook_fn = make_hook_prepare_inputs_labels_for_multimodal(self.pruning_paras) | ||
| self.model.vlm_model._original_prepare_inputs_labels_for_multimodal = ( | ||
| self.model.vlm_model.prepare_inputs_labels_for_multimodal | ||
| ) | ||
| self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType( | ||
| hook_fn, self.model.vlm_model | ||
| ) | ||
|
Comment on lines
+145
to
+151
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| self.blocks[self.pruning_loc - 1].register_forward_pre_hook( | ||
| update_output_attentions_hook, | ||
| with_kwargs=True | ||
| ) | ||
|
|
||
| self.blocks[self.pruning_loc - 1].register_forward_hook( | ||
| functools.partial(store_attention_hook, pruning_pars=self.model.model.parameters), | ||
| functools.partial(store_attention_hook, pruning_paras=self.pruning_paras), | ||
| ) | ||
|
|
||
| self.blocks[self.pruning_loc].register_forward_pre_hook( | ||
| functools.partial(fastv_pruning_hook, pruning_pars=self.model.model.parameters), | ||
| functools.partial(fastv_pruning_hook, pruning_paras=self.pruning_paras), | ||
| with_kwargs=True | ||
| ) | ||
|
|
||
| for idx in range(self.pruning_loc + 1, len(self.blocks)): | ||
| self.blocks[idx].register_forward_pre_hook( | ||
| functools.partial(read_parameter_hook, pruning_pars=self.model.model.parameters), | ||
| functools.partial(read_parameter_hook, pruning_paras=self.pruning_paras), | ||
| with_kwargs=True | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The variable
self.select_featureis not defined in this class context. Usepruning_paras['select_feature']instead.