diff --git a/llmc/compression/token_reduction/holitom.py b/llmc/compression/token_reduction/holitom.py
index ddd7f039..0208d6c6 100644
--- a/llmc/compression/token_reduction/holitom.py
+++ b/llmc/compression/token_reduction/holitom.py
@@ -35,7 +35,7 @@ def SigLipEncoder_forward(
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
-) -> Union[Tuple, BaseModelOutput]:
+) -> Union[Tuple]:
     output_attentions = (
         output_attentions
         if output_attentions is not None
diff --git a/llmc/compression/token_reduction/pyramiddrop.py b/llmc/compression/token_reduction/pyramiddrop.py
index 04ece8e3..cf2e41cf 100644
--- a/llmc/compression/token_reduction/pyramiddrop.py
+++ b/llmc/compression/token_reduction/pyramiddrop.py
@@ -1,5 +1,7 @@
 import functools
 import math
+from functools import wraps
+from types import MethodType
 
 import torch
 from torch import nn
@@ -26,13 +28,17 @@ def add_sparse_config(self):
         image_token_ratio_list = self.special_config['image_token_ratio_list']
         image_token_ratio_list.insert(0, 1.0)
         self.special_config['image_token_ratio_list'] = image_token_ratio_list
 
+        if self.model.__class__.__name__ == 'LlavaHf':
+            llama_model = self.model.vlm_model.language_model.model
+        elif self.model.__class__.__name__ == 'Llava':
+            llama_model = self.model.vlm_model.model
         self.special_config['tokenizer_padding_side'] = getattr(
-            self.model.vlm_model.language_model.model.config,
+            llama_model.config,
             'tokenizer_padding_side',
             'right',
         )
-        self.model.model.parameters = self.special_config
+        self.pruning_paras = self.special_config
 
     def register_reduction_modules(self):
         @prefill_wrapper
@@ -214,8 +220,12 @@ def pruning_hook(module, args, kwargs, pruning_pars, cur_num, layer_idx):
                 attention_mask_list.append(new_attention_mask)
 
             # Truncate sequences to max length as image embeddings can make the sequence longer
+            if self.model.__class__.__name__ == 'LlavaHf':
+                llama_model = self.model.vlm_model.language_model.model
+            elif self.model.__class__.__name__ == 'Llava':
+                llama_model = self.model.vlm_model.model
             tokenizer_model_max_length = getattr(
-                self.model.vlm_model.language_model.model.config,
+                llama_model.config,
                 'tokenizer_model_max_length',
                 2048,
             )
@@ -321,6 +331,39 @@ def input_hook(module, input_args, pruning_pars):
 
             return input_args
 
+        def input_hook_llava(fn, pruning_paras):
+            @wraps(fn)
+            def wrapper(self, *args, **kwargs):
+                if len(args) == 0:
+                    return fn(*args, **kwargs)
+                input_args = args[0]
+                if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
+                    return fn(*args, **kwargs)
+
+                input_ids = args[0]
+                attention_mask = args[2]
+
+                image_token_posi = []
+                prompt_len = []
+                vision_tokens = []
+                for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask):
+                    seq = cur_input_ids[cur_attention_mask]
+                    image_index = torch.where(seq == IMAGE_TOKEN_INDEX)[0].tolist()
+                    if image_index == []:
+                        image_token_posi.append(-1)
+                        prompt_len.append(cur_input_ids.shape[0])
+                    else:
+                        image_token_posi.append(image_index[0])
+                        prompt_len.append(cur_input_ids.shape[0] - 1)
+                    vision_tokens.append(pruning_paras['vision_token_length'])
+
+                pruning_paras['image_token_posi'] = image_token_posi
+                pruning_paras['prompt_len'] = prompt_len
+                pruning_paras['image_tokens'] = vision_tokens
+
+                return fn(*args, **kwargs)
+            return wrapper
+
         @prefill_wrapper
         def read_parameter_hook(module, args, kwargs, pruning_pars):
             kwargs['attention_mask'] = pruning_pars['attention_mask']
@@ -330,9 +373,19 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
 
             return args, kwargs
 
-        self.model.embed_tokens.register_forward_pre_hook(
-            functools.partial(input_hook, pruning_pars=self.model.model.parameters)
-        )
+        if self.model.__class__.__name__ == 'LlavaHf':
+            self.model.embed_tokens.register_forward_pre_hook(
+                functools.partial(input_hook, pruning_pars=self.pruning_paras)
+            )
+        elif self.model.__class__.__name__ == 'Llava':
+            from llava.constants import IMAGE_TOKEN_INDEX
+            hook_fn = input_hook_llava(
+                self.model.vlm_model.prepare_inputs_labels_for_multimodal,
+                self.pruning_paras
+            )
+            self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
+                hook_fn, self.model.vlm_model
+            )
 
         for layer_idx in range(self.pruning_loc[0], len(self.blocks)):
             if layer_idx in self.pruning_loc:
@@ -340,7 +393,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                 self.blocks[layer_idx].register_forward_pre_hook(
                     functools.partial(
                         pruning_hook,
-                        pruning_pars=self.model.model.parameters,
+                        pruning_pars=self.pruning_paras,
                         cur_num=stage,
                         layer_idx=layer_idx,
                     ),
@@ -349,7 +402,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
             else:
                 self.blocks[layer_idx].register_forward_pre_hook(
                     functools.partial(
-                        read_parameter_hook, pruning_pars=self.model.model.parameters
+                        read_parameter_hook, pruning_pars=self.pruning_paras
                     ),
                     with_kwargs=True,
                 )
diff --git a/llmc/compression/token_reduction/sparsevlm.py b/llmc/compression/token_reduction/sparsevlm.py
index 92b9659c..7b903569 100755
--- a/llmc/compression/token_reduction/sparsevlm.py
+++ b/llmc/compression/token_reduction/sparsevlm.py
@@ -1,4 +1,6 @@
 import functools
+from functools import wraps
+from types import MethodType
 
 import einops as ein
 import torch
@@ -27,7 +29,7 @@ def add_sparse_config(self):
         special_config['token_length_list'] = []
         special_config['image_shape'] = self.model.pruning_config['image_token_length']
         special_config['image_token_index'] = self.model.pruning_config['image_token_index']
-        self.model.model.parameters = special_config
+        self.pruning_paras = special_config
 
     def register_reduction_modules(self):
         @prefill_wrapper
@@ -52,16 +54,48 @@ def input_hook(module, input_args, pruning_pars):
 
             return input_args
 
+        def input_hook_llava(fn, pruning_paras):
+            @wraps(fn)
+            def wrapper(self, *args, **kwargs):
+                if len(args) == 0:
+                    return fn(*args, **kwargs)
+                input_args = args[0]
+                if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
+                    return fn(*args, **kwargs)
+
+                input_ids = args[0]
+                attention_mask = args[2]
+
+                pre_prompt_length_list = []
+                for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask):
+                    seq = cur_input_ids[cur_attention_mask]
+                    image_token_index = torch.where(seq == IMAGE_TOKEN_INDEX)[0].tolist()
+                    if len(image_token_index) > 0:
+                        pre_prompt_length_list.append(image_token_index[0])
+                    else:
+                        pre_prompt_length_list.append(0)
+                pruning_paras['pre_prompt_length_list'] = pre_prompt_length_list
+
+                outputs = fn(*args, **kwargs)
+
+                token_length_list = []
+                for cur_attention_mask in outputs[2]:
+                    token_length_list.append(cur_attention_mask.sum().item())
+                pruning_paras['token_length_list'] = token_length_list
+
+                return outputs
+            return wrapper
+
         @prefill_wrapper_model
         def register_module_pars(module, args, kwargs, pruning_pars):
             pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
             inputs_embeds = kwargs['inputs_embeds']
             if inputs_embeds is None:
-                inputs_embeds = self.embed_tokens(kwargs['input_ids'])
+                inputs_embeds = module.embed_tokens(kwargs['input_ids'])
 
             hidden_states = inputs_embeds  # shape: (B, L, C)
-            pruning_pars['B'], L, _ = hidden_states.shape
-            B = pruning_pars['B']
+            B, L, _ = hidden_states.shape
+            pruning_pars['B'] = B
             init_n = pruning_pars['init_token_total_shape'] + \
                 pruning_pars['generate_process_count']  # 668
             pruning_pars['prev_decision'] = torch.ones(
@@ -80,7 +114,7 @@ def register_module_pars(module, args, kwargs, pruning_pars):
             if (len(pre_prompt_length_list) != 0 and hidden_states.shape[1] != 1):
                 v_t = hidden_states[:, v_token_start: text_token_start, :]
                 t_t = hidden_states[:, text_token_start:, :]
-                m_v_t = v_t @ t_t.transpose(1, 2)  # [1, 576, 53]
+                m_v_t = v_t @ t_t.transpose(1, 2)  # [1, 576, 53]  # 52?
                 m_v_t = m_v_t.softmax(2).mean(1)  # [1, 53]
                 pruning_pars['t_token_idx'] = torch.where(m_v_t > m_v_t.mean())
 
@@ -206,17 +240,31 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
 
             return args, kwargs
 
-        self.model.embed_tokens.register_forward_pre_hook(
-            functools.partial(
-                input_hook,
-                pruning_pars=self.model.model.parameters
+        if self.model.__class__.__name__ == 'LlavaHf':
+            self.model.embed_tokens.register_forward_pre_hook(
+                functools.partial(
+                    input_hook,
+                    pruning_pars=self.pruning_paras
+                )
+            )
+        elif self.model.__class__.__name__ == 'Llava':
+            from llava.constants import IMAGE_TOKEN_INDEX
+            hook_fn = input_hook_llava(
+                self.model.vlm_model.prepare_inputs_labels_for_multimodal,
+                self.pruning_paras
+            )
+            self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
+                hook_fn, self.model.vlm_model
             )
-        )
 
-        self.model.model.register_forward_pre_hook(
+        if self.model.__class__.__name__ == 'LlavaHf':
+            llama_model = self.model.model
+        elif self.model.__class__.__name__ == 'Llava':
+            llama_model = self.model.model.model
+        llama_model.register_forward_pre_hook(
             functools.partial(
                 register_module_pars,
-                pruning_pars=self.model.model.parameters),
+                pruning_pars=self.pruning_paras),
             with_kwargs=True
         )
 
@@ -228,7 +276,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                 self.blocks[block_idx].register_forward_pre_hook(
                     functools.partial(
                         update_output_attentions_hook,
-                        pruning_pars=self.model.model.parameters,
+                        pruning_pars=self.pruning_paras,
                         layer_idx=block_idx,
                     ),
                     with_kwargs=True
@@ -236,7 +284,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                 self.blocks[block_idx].register_forward_hook(
                     functools.partial(
                         decoder_attn_hook,
-                        pruning_pars=self.model.model.parameters,
+                        pruning_pars=self.pruning_paras,
                         layer_idx=block_idx,
                     ),
                     with_kwargs=True
@@ -245,7 +293,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                 self.blocks[block_idx].register_forward_pre_hook(
                     functools.partial(
                         read_parameter_hook,
-                        pruning_pars=self.model.model.parameters
+                        pruning_pars=self.pruning_paras
                     ),
                     with_kwargs=True
                 )
@@ -278,6 +326,7 @@ def attn_postprocess_topk(
 
     self_attn_weights = self_attn_weights.mean(1)  # B, L[Q], L[K]
     t_token_idx = t_token_idx[1] + text_token_start
+
    relation_vis_text = self_attn_weights[:, t_token_idx,
                                          v_token_start: v_token_start + v_token_num]  # B, L2, L1
diff --git a/llmc/compression/token_reduction/tome.py b/llmc/compression/token_reduction/tome.py
index 5142778e..759c795d 100644
--- a/llmc/compression/token_reduction/tome.py
+++ b/llmc/compression/token_reduction/tome.py
@@ -43,7 +43,7 @@ def add_sparse_config(self):
         else:
             raise ValueError('Invalid r format. Expected int or (start, step) tuple.')
 
-        self.model.model.parameters = special_config
+        self.pruning_paras = special_config
 
     def patch_layer(self):
         for idx, block in enumerate(self.blocks):
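
Note on the hook pattern introduced in pyramiddrop.py and sparsevlm.py: input_hook_llava wraps the model's already-bound prepare_inputs_labels_for_multimodal, records pruning parameters from the call arguments, then delegates to the original method, and the wrapper is re-attached to the instance with types.MethodType. The sketch below illustrates that pattern in isolation; ToyModel, prepare_inputs and make_hook are illustrative stand-ins, not names from this repository.

from functools import wraps
from types import MethodType


class ToyModel:
    def prepare_inputs(self, input_ids):
        return [len(seq) for seq in input_ids]


def make_hook(fn, paras):
    # fn is the original *bound* method; the wrapper records side information
    # into the shared dict, then delegates to fn unchanged.
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        paras['num_sequences'] = len(args[0])
        return fn(*args, **kwargs)
    return wrapper


model = ToyModel()
paras = {}
hooked = make_hook(model.prepare_inputs, paras)
# Rebind the wrapper as an instance method, mirroring the MethodType usage in the diff.
model.prepare_inputs = MethodType(hooked, model)

print(model.prepare_inputs([[1, 2, 3], [4, 5]]))  # [3, 2]
print(paras)  # {'num_sequences': 2}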