PyramidDrop and SparseVLM for llava #396
First changed file:

```diff
@@ -1,5 +1,7 @@
 import functools
 import math
+from functools import wraps
+from types import MethodType

 import torch
 from torch import nn
```
```diff
@@ -26,13 +28,17 @@ def add_sparse_config(self):
         image_token_ratio_list = self.special_config['image_token_ratio_list']
         image_token_ratio_list.insert(0, 1.0)
         self.special_config['image_token_ratio_list'] = image_token_ratio_list
+        if self.model.__class__.__name__ == 'LlavaHf':
+            llama_model = self.model.vlm_model.language_model.model
+        elif self.model.__class__.__name__ == 'Llava':
+            llama_model = self.model.vlm_model.model
         self.special_config['tokenizer_padding_side'] = getattr(
-            self.model.vlm_model.language_model.model.config,
+            llama_model.config,
             'tokenizer_padding_side',
             'right',
         )
-        self.model.model.parameters = self.special_config
+        self.pruning_paras = self.special_config

     def register_reduction_modules(self):
         @prefill_wrapper
```
```diff
@@ -214,8 +220,12 @@ def pruning_hook(module, args, kwargs, pruning_pars, cur_num, layer_idx):
                 attention_mask_list.append(new_attention_mask)

             # Truncate sequences to max length as image embeddings can make the sequence longer
+            if self.model.__class__.__name__ == 'LlavaHf':
+                llama_model = self.model.vlm_model.language_model.model
+            elif self.model.__class__.__name__ == 'Llava':
+                llama_model = self.model.vlm_model.model
             tokenizer_model_max_length = getattr(
-                self.model.vlm_model.language_model.model.config,
+                llama_model.config,
                 'tokenizer_model_max_length',
                 2048,
             )
```
```diff
@@ -321,6 +331,39 @@ def input_hook(module, input_args, pruning_pars):

             return input_args

+        def input_hook_llava(fn, pruning_paras):
+            @wraps(fn)
+            def wrapper(self, *args, **kwargs):
+                if len(args) == 0:
+                    return fn(*args, **kwargs)
+                input_args = args[0]
+                if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
+                    return fn(*args, **kwargs)
+
+                input_ids = args[0]
+                attention_mask = args[2]
+
+                image_token_posi = []
+                prompt_len = []
+                vision_tokens = []
+                for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask):
+                    seq = cur_input_ids[cur_attention_mask]
+                    image_index = torch.where(seq == IMAGE_TOKEN_INDEX)[0].tolist()
+                    if image_index == []:
+                        image_token_posi.append(-1)
+                        prompt_len.append(cur_input_ids.shape[0])
+                    else:
+                        image_token_posi.append(image_index[0])
+                        prompt_len.append(cur_input_ids.shape[0] - 1)
+                    vision_tokens.append(pruning_paras['vision_token_length'])
+
+                pruning_paras['image_token_posi'] = image_token_posi
+                pruning_paras['prompt_len'] = prompt_len
+                pruning_paras['image_tokens'] = vision_tokens
+
+                return fn(*args, **kwargs)
+            return wrapper
+
         @prefill_wrapper
         def read_parameter_hook(module, args, kwargs, pruning_pars):
             kwargs['attention_mask'] = pruning_pars['attention_mask']
```
```diff
@@ -330,17 +373,27 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):

             return args, kwargs

-        self.model.embed_tokens.register_forward_pre_hook(
-            functools.partial(input_hook, pruning_pars=self.model.model.parameters)
-        )
+        if self.model.__class__.__name__ == 'LlavaHf':
+            self.model.embed_tokens.register_forward_pre_hook(
+                functools.partial(input_hook, pruning_pars=self.pruning_paras)
+            )
+        elif self.model.__class__.__name__ == 'Llava':
+            from llava.constants import IMAGE_TOKEN_INDEX
+            hook_fn = input_hook_llava(
+                self.model.vlm_model.prepare_inputs_labels_for_multimodal,
+                self.pruning_paras
+            )
+            self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
+                hook_fn, self.model.vlm_model
+            )

         for layer_idx in range(self.pruning_loc[0], len(self.blocks)):
             if layer_idx in self.pruning_loc:
                 stage = self.pruning_loc.index(layer_idx)
                 self.blocks[layer_idx].register_forward_pre_hook(
                     functools.partial(
                         pruning_hook,
-                        pruning_pars=self.model.model.parameters,
+                        pruning_pars=self.pruning_paras,
                         cur_num=stage,
                         layer_idx=layer_idx,
                     ),
```
```diff
@@ -349,7 +402,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
             else:
                 self.blocks[layer_idx].register_forward_pre_hook(
                     functools.partial(
-                        read_parameter_hook, pruning_pars=self.model.model.parameters
+                        read_parameter_hook, pruning_pars=self.pruning_paras
                    ),
                     with_kwargs=True,
                 )
```
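The non-HF `Llava` branch above intercepts `prepare_inputs_labels_for_multimodal` by wrapping the already-bound method and re-binding the wrapper with `MethodType`, so pruning metadata can be recorded before the original implementation runs. Below is a minimal, self-contained sketch of that wrap-and-rebind pattern; every name in it is hypothetical and only stands in for the real model and hook:

```python
from functools import wraps
from types import MethodType


class ToyModel:
    """Stand-in for the VLM; prepare_inputs is the method being intercepted."""

    def prepare_inputs(self, input_ids):
        return [t + 1 for t in input_ids]


def make_recording_wrapper(fn, store):
    # fn is captured as the *bound* method, so the wrapper ignores the `self`
    # that MethodType passes in and simply delegates after recording metadata.
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        store['last_batch_size'] = len(args[0])
        return fn(*args, **kwargs)
    return wrapper


store = {}
model = ToyModel()
model.prepare_inputs = MethodType(
    make_recording_wrapper(model.prepare_inputs, store), model
)
print(model.prepare_inputs([1, 2, 3]))  # [2, 3, 4]
print(store)                            # {'last_batch_size': 3}
```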
Second changed file:

```diff
@@ -1,4 +1,6 @@
 import functools
+from functools import wraps
+from types import MethodType

 import einops as ein
 import torch
```
```diff
@@ -27,7 +29,7 @@ def add_sparse_config(self):
         special_config['token_length_list'] = []
         special_config['image_shape'] = self.model.pruning_config['image_token_length']
         special_config['image_token_index'] = self.model.pruning_config['image_token_index']
-        self.model.model.parameters = special_config
+        self.pruning_paras = special_config

     def register_reduction_modules(self):
         @prefill_wrapper
```
```diff
@@ -52,16 +54,48 @@ def input_hook(module, input_args, pruning_pars):

             return input_args

+        def input_hook_llava(fn, pruning_paras):
+            @wraps(fn)
+            def wrapper(self, *args, **kwargs):
+                if len(args) == 0:
+                    return fn(*args, **kwargs)
+                input_args = args[0]
+                if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
+                    return fn(*args, **kwargs)
+
+                input_ids = args[0]
+                attention_mask = args[2]
+
+                pre_prompt_length_list = []
+                for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask):
+                    seq = cur_input_ids[cur_attention_mask]
+                    image_token_index = torch.where(seq == IMAGE_TOKEN_INDEX)[0].tolist()
+                    if len(image_token_index) > 0:
+                        pre_prompt_length_list.append(image_token_index[0])
+                    else:
+                        pre_prompt_length_list.append(0)
+                pruning_paras['pre_prompt_length_list'] = pre_prompt_length_list
+
+                outputs = fn(*args, **kwargs)
+
+                token_length_list = []
+                for cur_attention_mask in outputs[2]:
+                    token_length_list.append(cur_attention_mask.sum().item())
+                pruning_paras['token_length_list'] = token_length_list
+
+                return outputs
+            return wrapper
+
         @prefill_wrapper_model
         def register_module_pars(module, args, kwargs, pruning_pars):
             pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
             inputs_embeds = kwargs['inputs_embeds']
             if inputs_embeds is None:
-                inputs_embeds = self.embed_tokens(kwargs['input_ids'])
+                inputs_embeds = module.embed_tokens(kwargs['input_ids'])
             hidden_states = inputs_embeds  # shape: (B, L, C)

-            pruning_pars['B'], L, _ = hidden_states.shape
-            B = pruning_pars['B']
+            B, L, _ = hidden_states.shape
+            pruning_pars['B'] = B
             init_n = pruning_pars['init_token_total_shape'] + \
                 pruning_pars['generate_process_count']  # 668
             pruning_pars['prev_decision'] = torch.ones(
```
```diff
@@ -80,7 +114,7 @@ def register_module_pars(module, args, kwargs, pruning_pars):
             if (len(pre_prompt_length_list) != 0 and hidden_states.shape[1] != 1):
                 v_t = hidden_states[:, v_token_start: text_token_start, :]
                 t_t = hidden_states[:, text_token_start:, :]
-                m_v_t = v_t @ t_t.transpose(1, 2)  # [1, 576, 53]
+                m_v_t = v_t @ t_t.transpose(1, 2)  # [1, 576, 53] # 52?

                 m_v_t = m_v_t.softmax(2).mean(1)  # [1, 53]
                 pruning_pars['t_token_idx'] = torch.where(m_v_t > m_v_t.mean())
```
```diff
@@ -206,17 +240,31 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):

             return args, kwargs

-        self.model.embed_tokens.register_forward_pre_hook(
-            functools.partial(
-                input_hook,
-                pruning_pars=self.model.model.parameters
+        if self.model.__class__.__name__ == 'LlavaHf':
+            self.model.embed_tokens.register_forward_pre_hook(
+                functools.partial(
+                    input_hook,
+                    pruning_pars=self.pruning_paras
+                )
+            )
+        elif self.model.__class__.__name__ == 'Llava':
+            from llava.constants import IMAGE_TOKEN_INDEX
+            hook_fn = input_hook_llava(
+                self.model.vlm_model.prepare_inputs_labels_for_multimodal,
+                self.pruning_paras
+            )
+            self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
+                hook_fn, self.model.vlm_model
+            )
-        )

-        self.model.model.register_forward_pre_hook(
+        if self.model.__class__.__name__ == 'LlavaHf':
+            llama_model = self.model.model
+        elif self.model.__class__.__name__ == 'Llava':
+            llama_model = self.model.model.model
+        llama_model.register_forward_pre_hook(
             functools.partial(
                 register_module_pars,
-                pruning_pars=self.model.model.parameters),
+                pruning_pars=self.pruning_paras),
             with_kwargs=True
         )
```
> **Review comment (lines +243 to 269):** There are two consecutive …
```diff
@@ -228,15 +276,15 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                 self.blocks[block_idx].register_forward_pre_hook(
                     functools.partial(
                         update_output_attentions_hook,
-                        pruning_pars=self.model.model.parameters,
+                        pruning_pars=self.pruning_paras,
                         layer_idx=block_idx,
                     ),
                     with_kwargs=True
                 )
                 self.blocks[block_idx].register_forward_hook(
                     functools.partial(
                         decoder_attn_hook,
-                        pruning_pars=self.model.model.parameters,
+                        pruning_pars=self.pruning_paras,
                         layer_idx=block_idx,
                     ),
                     with_kwargs=True
```
```diff
@@ -245,7 +293,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                 self.blocks[block_idx].register_forward_pre_hook(
                     functools.partial(
                         read_parameter_hook,
-                        pruning_pars=self.model.model.parameters
+                        pruning_pars=self.pruning_paras
                     ),
                     with_kwargs=True
                 )
```
```diff
@@ -278,6 +326,7 @@ def attn_postprocess_topk(
     self_attn_weights = self_attn_weights.mean(1)  # B, L[Q], L[K]

     t_token_idx = t_token_idx[1] + text_token_start

     relation_vis_text = self_attn_weights[:, t_token_idx,
                                           v_token_start: v_token_start + v_token_num]  # B, L2, L1
```
> **Review comment:** The return type hint was changed to `Union[Tuple]`, but the function can still return a `BaseModelOutput` instance (lines 86-90) if `return_dict` is true. This makes the type hint incorrect.
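A minimal sketch of an annotation that covers both return paths is below. The function body is only illustrative; `BaseModelOutput` is the transformers class the comment refers to, and the actual output class used by the modified forward may differ:

```python
from typing import Optional, Tuple, Union

import torch
from transformers.modeling_outputs import BaseModelOutput


def forward_sketch(
    hidden_states: torch.Tensor,
    return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]:
    # When return_dict is true the outputs are wrapped in BaseModelOutput,
    # so the annotation has to include that class, not just Union[Tuple].
    if return_dict:
        return BaseModelOutput(last_hidden_state=hidden_states)
    return (hidden_states,)
```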