From e45ebed1b02b43a6186ad34bce44ba36fa43b545 Mon Sep 17 00:00:00 2001 From: SmudgedWings <2045955563@qq.com> Date: Sun, 27 Jul 2025 23:04:31 +0800 Subject: [PATCH] dart,fatev,vispruner for llava1.6,update sparsevlm --- .../methods/SparseVLM/sparsevlm.yml | 3 +- llmc/compression/token_reduction/dart.py | 26 +-- llmc/compression/token_reduction/fastv.py | 15 +- llmc/compression/token_reduction/sparsevlm.py | 172 ++++++++---------- .../token_reduction/token_reduction_module.py | 21 ++- llmc/compression/token_reduction/utils.py | 104 +++++++++++ llmc/compression/token_reduction/vispruner.py | 163 ++++++++++++++++- llmc/eval/eval_vqa.py | 8 +- llmc/models/llava.py | 6 +- 9 files changed, 395 insertions(+), 123 deletions(-) diff --git a/configs/sparsification/methods/SparseVLM/sparsevlm.yml b/configs/sparsification/methods/SparseVLM/sparsevlm.yml index 8dc69d92..e2c117ee 100644 --- a/configs/sparsification/methods/SparseVLM/sparsevlm.yml +++ b/configs/sparsification/methods/SparseVLM/sparsevlm.yml @@ -17,8 +17,7 @@ sparse: special: method: SparseVLM pruning_loc: [2, 6, 15] - retained_tokens: 192 - prune_flag: True + reduction_ratio: 0.6667 merge_flag: True save: save_trans: False diff --git a/llmc/compression/token_reduction/dart.py b/llmc/compression/token_reduction/dart.py index f237282d..79742bc3 100644 --- a/llmc/compression/token_reduction/dart.py +++ b/llmc/compression/token_reduction/dart.py @@ -1,5 +1,5 @@ import functools -import math +from types import MethodType import torch @@ -24,26 +24,20 @@ def add_sparse_config(self): def register_reduction_modules(self): @prefill_wrapper - def vtoken_length_hook(module, input_args, pruning_paras): - - input_ids = input_args[0] + def vtoken_length_hook(module, args, pruning_paras): + input_ids = args[0] token_indices = torch.where( input_ids[0] == pruning_paras['vision_token_index'] )[0] pruning_paras['vision_token_length'] = token_indices.shape[0] - return input_args - @prefill_wrapper def get_any_states_hook(module, args, kwargs, layer_outs, pruning_paras, layer_idx): - past_key_value = kwargs['past_key_value'] if past_key_value is None: raise ValueError('DART needs past_key_value but got None.') pruning_paras['any_states'] = past_key_value.key_cache[layer_idx] - return layer_outs - @prefill_wrapper def pruning_hook(module, args, kwargs, pruning_paras, normlayer): @@ -95,9 +89,17 @@ def pruning_hook(module, args, kwargs, pruning_paras, normlayer): return (hidden_states,), kwargs if self.special_config['vision_token_length'] is None: - self.model.embed_tokens.register_forward_pre_hook( - functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras) - ) + if self.model.__class__.__name__ == 'Llava': + self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType( + self.vtoken_length_for_llava_hook( + self.model.vlm_model.prepare_inputs_labels_for_multimodal, + self.pruning_paras + ), self.model.vlm_model + ) + else: + self.model.embed_tokens.register_forward_pre_hook( + functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras) + ) self.blocks[self.pruning_loc - 1].register_forward_hook( functools.partial( diff --git a/llmc/compression/token_reduction/fastv.py b/llmc/compression/token_reduction/fastv.py index 48080da5..2a2e8e47 100644 --- a/llmc/compression/token_reduction/fastv.py +++ b/llmc/compression/token_reduction/fastv.py @@ -1,4 +1,5 @@ import functools +from types import MethodType import torch @@ -104,9 +105,17 @@ def fastv_pruning_hook(module, args, kwargs, pruning_paras): return (hidden_states,), 
kwargs if self.special_config['vision_token_length'] is None: - self.model.embed_tokens.register_forward_pre_hook( - functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras) - ) + if self.model.__class__.__name__ == 'Llava': + self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType( + self.vtoken_length_for_llava_hook( + self.model.vlm_model.prepare_inputs_labels_for_multimodal, + self.pruning_paras + ), self.model.vlm_model + ) + else: + self.model.embed_tokens.register_forward_pre_hook( + functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras) + ) self.blocks[self.pruning_loc - 1].register_forward_pre_hook( functools.partial(update_output_attentions_hook, pruning_paras=self.pruning_paras), diff --git a/llmc/compression/token_reduction/sparsevlm.py b/llmc/compression/token_reduction/sparsevlm.py index f91c9564..02cb2c74 100755 --- a/llmc/compression/token_reduction/sparsevlm.py +++ b/llmc/compression/token_reduction/sparsevlm.py @@ -1,4 +1,3 @@ -import copy import functools import math from functools import wraps @@ -29,47 +28,37 @@ def __init__(self, config, model, blocks): self.register_reduction_modules() def add_sparse_config(self): - special_config = self.config.get('special', {}) - self.pruning_loc = special_config.get('pruning_loc', [2, 6, 15]) + self.pruning_loc = self.special_config.get('pruning_loc', [2, 6, 15]) global layer_dict, prune_flag, merge_flag layer_dict = {layer: idx for idx, layer in enumerate(self.pruning_loc)} - prune_flag = special_config.get('prune_flag', True) - merge_flag = special_config.get('merge_flag', True) + prune_flag = self.special_config.get('prune_flag', True) + merge_flag = self.special_config.get('merge_flag', True) update_list() - special_config['retained_tokens'] = special_config.get('retained_tokens', 192) - special_config['pre_prompt_length_list'] = [] - special_config['image_shape'] = self.model.pruning_config['image_token_length'] - special_config['image_token_index'] = self.model.pruning_config['image_token_index'] - self.pruning_paras = special_config + self.pruning_paras = self.special_config + self.pruning_paras['pre_prompt_length_list'] = [] def register_reduction_modules(self): @prefill_wrapper - def input_hook(module, input_args, pruning_pars): - input_ids = input_args[0] + def input_hook(module, args, pruning_paras): + input_ids = args[0] pre_prompt_length_list = [] - IMAGE_TOKEN_INDEX = pruning_pars['image_token_index'] # find the position of the first image token for seq in input_ids: image_token_index = ( - seq == IMAGE_TOKEN_INDEX + seq == pruning_paras['vision_token_index'] ).nonzero(as_tuple=True)[0] if len(image_token_index) > 0: pre_prompt_length_list.append(image_token_index[0].item()) else: pre_prompt_length_list.append(0) - pruning_pars['pre_prompt_length_list'] = pre_prompt_length_list - - return input_args + pruning_paras['pre_prompt_length_list'] = pre_prompt_length_list def input_hook_llava(fn, pruning_paras): @wraps(fn) def wrapper(self, *args, **kwargs): - if len(args) == 0: - return fn(*args, **kwargs) - input_args = args[0] - if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1: + if args[0].shape[1] == 1: return fn(*args, **kwargs) input_ids = args[0] @@ -85,7 +74,7 @@ def wrapper(self, *args, **kwargs): seq = cur_input_ids[cur_attention_mask] image_token_index = ( [-1] - + torch.where(seq == IMAGE_TOKEN_INDEX)[0].tolist() + + torch.where(seq == pruning_paras['vision_token_index'])[0].tolist() + [seq.shape[0]] ) 
pre_prompt_length_list.append(image_token_index[1]) @@ -96,57 +85,60 @@ def wrapper(self, *args, **kwargs): return wrapper @prefill_wrapper_model - def register_module_pars(module, args, kwargs, pruning_pars): - pre_prompt_length_list = pruning_pars['pre_prompt_length_list'] + def register_module_pars(module, args, kwargs, pruning_paras): + pre_prompt_length_list = pruning_paras['pre_prompt_length_list'] hidden_states = kwargs['inputs_embeds'] if hidden_states is None: hidden_states = module.embed_tokens(kwargs['input_ids']) B, L, _ = hidden_states.shape - pruning_pars['B'] = B + pruning_paras['B'] = B v_token_start = pre_prompt_length_list[0] if len( pre_prompt_length_list) != 0 else 0 - text_token_start = v_token_start + pruning_pars['image_shape'] - pruning_pars['v_token_start'] = v_token_start # 35 - pruning_pars['text_token_start'] = text_token_start # 611 - pruning_pars['v_token_num'] = pruning_pars['image_shape'] # 576 + text_token_start = v_token_start + pruning_paras['vision_token_length'] + pruning_paras['v_token_start'] = v_token_start # 35 + pruning_paras['text_token_start'] = text_token_start # 611 + pruning_paras['v_token_num'] = pruning_paras['vision_token_length'] # 576 + pruning_paras['retained_tokens'] = round( + pruning_paras['vision_token_length'] * (1 - pruning_paras['reduction_ratio']) + ) if (len(pre_prompt_length_list) != 0 and hidden_states.shape[1] != 1): v_t = hidden_states[:, v_token_start: text_token_start, :] t_t = hidden_states[:, text_token_start:, :] m_v_t = v_t @ t_t.transpose(1, 2) m_v_t = m_v_t.softmax(2).mean(1) - pruning_pars['t_token_idx'] = torch.where(m_v_t > m_v_t.mean()) + pruning_paras['t_token_idx'] = torch.where(m_v_t > m_v_t.mean()) return args, kwargs - def update_output_attentions_hook(module, args, kwargs, pruning_pars, layer_idx): + def update_output_attentions_hook(module, args, kwargs, pruning_paras, layer_idx): kwargs['output_attentions'] = True if layer_idx != self.pruning_loc[0]: - kwargs['position_ids'] = pruning_pars['position_ids'] - kwargs['attention_mask'] = pruning_pars['attention_mask'] - kwargs['cache_position'] = pruning_pars['cache_position'] - kwargs['position_embeddings'] = pruning_pars['position_embeddings'] + kwargs['position_ids'] = pruning_paras['position_ids'] + kwargs['attention_mask'] = pruning_paras['attention_mask'] + kwargs['cache_position'] = pruning_paras['cache_position'] + kwargs['position_embeddings'] = pruning_paras['position_embeddings'] return args, kwargs - def update_kwargs_hook(module, args, kwargs, pruning_pars, layer_idx): + def update_kwargs_hook(module, args, kwargs, pruning_paras, layer_idx): if len(kwargs['position_ids'][0]) == 1: return args, kwargs if layer_idx != self.pruning_loc[0]: - kwargs['position_ids'] = pruning_pars['position_ids'] - kwargs['attention_mask'] = pruning_pars['attention_mask'] - kwargs['cache_position'] = pruning_pars['cache_position'] - kwargs['position_embeddings'] = pruning_pars['position_embeddings'] + kwargs['position_ids'] = pruning_paras['position_ids'] + kwargs['attention_mask'] = pruning_paras['attention_mask'] + kwargs['cache_position'] = pruning_paras['cache_position'] + kwargs['position_embeddings'] = pruning_paras['position_embeddings'] else: - pruning_pars['position_ids'] = kwargs['position_ids'] - pruning_pars['attention_mask'] = kwargs['attention_mask'] - pruning_pars['cache_position'] = kwargs['cache_position'] - pruning_pars['position_embeddings'] = kwargs['position_embeddings'] + pruning_paras['position_ids'] = kwargs['position_ids'] + 
pruning_paras['attention_mask'] = kwargs['attention_mask'] + pruning_paras['cache_position'] = kwargs['cache_position'] + pruning_paras['position_embeddings'] = kwargs['position_embeddings'] return args, kwargs - def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_idx): + def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_paras, layer_idx): if len(kwargs['position_ids'][0]) == 1: return layer_outs @@ -160,9 +152,9 @@ def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_i past_key_value = layer_outs[2] attention_mask = kwargs['attention_mask'] - t_token_idx = pruning_pars['t_token_idx'] - v_token_start = pruning_pars['v_token_start'] - v_token_num = pruning_pars['v_token_num'] + t_token_idx = pruning_paras['t_token_idx'] + v_token_start = pruning_paras['v_token_start'] + v_token_num = pruning_paras['v_token_num'] bsz, q_len, _ = hidden_states.size() query_states = module.q_proj(hidden_states) @@ -201,27 +193,28 @@ def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_i attn_logits += attn_bias.to(query_states.device) attn_logits = torch.softmax(attn_logits, dim=-1) - pruning_pars['attn_logits'] = attn_logits + pruning_paras['attn_logits'] = attn_logits return layer_outs @prefill_wrapper - def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx): + def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_paras, layer_idx): - if 'attn_logits' not in pruning_pars: - attn_logits = layer_outputs[1] # for LlavaHf + if 'attn_logits' not in pruning_paras: + attn_logits = layer_outputs[1] # for LlavaHf, but error else: - attn_logits = pruning_pars['attn_logits'] - prune_flag = pruning_pars.get('prune_flag', True) - merge_flag = pruning_pars['merge_flag'] - v_token_start = pruning_pars['v_token_start'] - v_token_num = pruning_pars['v_token_num'] - text_token_start = pruning_pars['text_token_start'] - t_token_idx = pruning_pars['t_token_idx'] - retained_tokens = pruning_pars['retained_tokens'] - B = pruning_pars['B'] - pre_prompt_length_list = pruning_pars['pre_prompt_length_list'] - image_shape = pruning_pars['image_shape'] + attn_logits = pruning_paras['attn_logits'] + prune_flag = pruning_paras.get('prune_flag', True) + merge_flag = pruning_paras['merge_flag'] + v_token_start = pruning_paras['v_token_start'] + v_token_num = pruning_paras['v_token_num'] + text_token_start = pruning_paras['text_token_start'] + t_token_idx = pruning_paras['t_token_idx'] + retained_tokens = pruning_paras['retained_tokens'] + + B = pruning_paras['B'] + pre_prompt_length_list = pruning_paras['pre_prompt_length_list'] + vision_token_length = pruning_paras['vision_token_length'] attention_mask = kwargs['attention_mask'] position_embeddings = kwargs['position_embeddings'] @@ -248,7 +241,7 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer prompt_length = pre_prompt_length_list[batch] policy[batch, :prompt_length] = 1 # keep question - text_token_start = prompt_length + image_shape + text_token_start = prompt_length + vision_token_length policy[batch, text_token_start:] = 1 if self.model.first_turn_question: @@ -333,40 +326,35 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer new_pe1 = position_embeddings[1][:, keep_indexs, :].clone() position_embeddings = (new_pe0, new_pe1) - pruning_pars['v_token_num'] = v_token_num - pruning_pars['text_token_start'] = text_token_start + pruning_paras['v_token_num'] = v_token_num + 
pruning_paras['text_token_start'] = text_token_start - pruning_pars['position_ids'] = position_ids - pruning_pars['cache_position'] = cache_position - pruning_pars['position_embeddings'] = position_embeddings - pruning_pars['attention_mask'] = attention_mask + pruning_paras['position_ids'] = position_ids + pruning_paras['cache_position'] = cache_position + pruning_paras['position_embeddings'] = position_embeddings + pruning_paras['attention_mask'] = attention_mask return new_output @prefill_wrapper - def read_parameter_hook(module, args, kwargs, pruning_pars): - kwargs['position_ids'] = pruning_pars['position_ids'] - kwargs['attention_mask'] = pruning_pars['attention_mask'] - kwargs['cache_position'] = pruning_pars['cache_position'] - kwargs['position_embeddings'] = pruning_pars['position_embeddings'] + def read_parameter_hook(module, args, kwargs, pruning_paras): + kwargs['position_ids'] = pruning_paras['position_ids'] + kwargs['attention_mask'] = pruning_paras['attention_mask'] + kwargs['cache_position'] = pruning_paras['cache_position'] + kwargs['position_embeddings'] = pruning_paras['position_embeddings'] return args, kwargs if self.model.__class__.__name__ == 'LlavaHf': self.model.embed_tokens.register_forward_pre_hook( - functools.partial( - input_hook, - pruning_pars=self.pruning_paras - ) + functools.partial(input_hook, pruning_paras=self.pruning_paras) ) elif self.model.__class__.__name__ == 'Llava': - from llava.constants import IMAGE_TOKEN_INDEX - hook_fn = input_hook_llava( - self.model.vlm_model.prepare_inputs_labels_for_multimodal, - self.pruning_paras - ) self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType( - hook_fn, self.model.vlm_model + input_hook_llava( + self.model.vlm_model.prepare_inputs_labels_for_multimodal, + self.pruning_paras + ), self.model.vlm_model ) if self.model.__class__.__name__ == 'LlavaHf': @@ -374,9 +362,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars): elif self.model.__class__.__name__ == 'Llava': llama_model = self.model.model.model llama_model.register_forward_pre_hook( - functools.partial( - register_module_pars, - pruning_pars=self.pruning_paras), + functools.partial(register_module_pars, pruning_paras=self.pruning_paras), with_kwargs=True ) @@ -389,7 +375,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars): self.blocks[block_idx].register_forward_pre_hook( functools.partial( update_output_attentions_hook, - pruning_pars=self.pruning_paras, + pruning_paras=self.pruning_paras, layer_idx=block_idx, ), with_kwargs=True @@ -398,7 +384,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars): self.blocks[block_idx].register_forward_pre_hook( functools.partial( update_kwargs_hook, - pruning_pars=self.pruning_paras, + pruning_paras=self.pruning_paras, layer_idx=block_idx, ), with_kwargs=True @@ -406,7 +392,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars): self.blocks[block_idx].self_attn.register_forward_hook( functools.partial( get_attn_logits_hook, - pruning_pars=self.pruning_paras, + pruning_paras=self.pruning_paras, layer_idx=block_idx, ), with_kwargs=True @@ -414,7 +400,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars): self.blocks[block_idx].register_forward_hook( functools.partial( decoder_attn_hook, - pruning_pars=self.pruning_paras, + pruning_paras=self.pruning_paras, layer_idx=block_idx ), with_kwargs=True @@ -423,7 +409,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars): self.blocks[block_idx].register_forward_pre_hook( functools.partial( 
read_parameter_hook, - pruning_pars=self.pruning_paras + pruning_paras=self.pruning_paras ), with_kwargs=True ) diff --git a/llmc/compression/token_reduction/token_reduction_module.py b/llmc/compression/token_reduction/token_reduction_module.py index 37e4483a..5bd72676 100644 --- a/llmc/compression/token_reduction/token_reduction_module.py +++ b/llmc/compression/token_reduction/token_reduction_module.py @@ -1,7 +1,5 @@ -import time -import torch -from loguru import logger +from functools import wraps class TokenReductionModule: @@ -35,3 +33,20 @@ def set_sparse_config(self): def register_reduction_modules(self): pass + + def vtoken_length_for_llava_hook(self, fn, pruning_paras): + @wraps(fn) + def wrapper(self, *args, **kwargs): + if args[0].shape[1] == 1: + return fn(*args, **kwargs) + + message = ( + 'To obtain the vision_token_length for LLaVA-1.6, you should append ' + '`image_features.shape[1]` to the return value of the function ' + '`prepare_inputs_labels_for_multimodal`, and modify the related code accordingly.' + ) + outs = fn(*args, **kwargs) + assert len(outs) == 7, message + pruning_paras['vision_token_length'] = outs[-1] + return outs + return wrapper diff --git a/llmc/compression/token_reduction/utils.py b/llmc/compression/token_reduction/utils.py index 3a43c8e1..cf657e22 100755 --- a/llmc/compression/token_reduction/utils.py +++ b/llmc/compression/token_reduction/utils.py @@ -1,3 +1,5 @@ +import ast +import re from functools import wraps from typing import Any, List, Optional, Tuple, Union @@ -98,3 +100,105 @@ def wrapped_fn(*args, **kwargs): return post_hook_fn(result, pruning_paras) model.get_2dPool = wrapped_fn + + +def select_best_resolution(original_size, possible_resolutions): + + original_width, original_height = original_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float('inf') + + for width, height in possible_resolutions: + # Calculate the downscaled size to keep the aspect ratio + scale = min(width / original_width, height / original_height) + downscaled_width = int(original_width * scale) + downscaled_height = int(original_height * scale) + + # Calculate effective and wasted resolutions + effective_resolution = min( + downscaled_width * downscaled_height, + original_width * original_height + ) + wasted_resolution = (width * height) - effective_resolution + + if (effective_resolution > max_effective_resolution) or ( + effective_resolution == max_effective_resolution and + wasted_resolution < min_wasted_resolution + ): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (width, height) + + return best_fit + + +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """Calculate the shape of the image patch grid after the preprocessing for + images of any resolution. + + Args: + image_size (tuple): The size of the input image in the format (width, height). + grid_pinpoints (str): A string representation of a list of possible resolutions. + patch_size (int): The size of each image patch. + + Returns: + tuple: The shape of the image patch grid in the format (width, height). 
+ """ + if isinstance(grid_pinpoints, str) and 'x' in grid_pinpoints: + assert patch_size in [224, 336, 384, 448, 512], ( + 'patch_size should be in [224, 336, 384, 448, 512]' + ) + # Use regex to extract the range from the input string + matches = re.findall(r'\((\d+)x(\d+)\)', grid_pinpoints) + range_start = tuple(map(int, matches[0])) + range_end = tuple(map(int, matches[-1])) + # Generate a matrix of tuples + # from (range_start[0], range_start[1]) to (range_end[0], range_end[1]) + grid_pinpoints = [ + (i, j) + for i in range(range_start[0], range_end[0] + 1) + for j in range(range_start[1], range_end[1] + 1) + ] + # Multiply all elements by patch_size + grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints] + if type(grid_pinpoints) is list: + possible_resolutions = grid_pinpoints + else: + possible_resolutions = ast.literal_eval(grid_pinpoints) + width, height = select_best_resolution(image_size, possible_resolutions) + return width // patch_size, height // patch_size + + +def unpad_image(tensor, original_size): + """Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format. + original_size (tuple): The original size of the image (height, width). + + Returns: + torch.Tensor: The unpadded image tensor. + """ + original_width, original_height = original_size + current_height, current_width = tensor.shape[1:] + + # Compute aspect ratios + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + # Determine padding size and direction + if original_aspect_ratio > current_aspect_ratio: + # Padding was added to the height + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding: current_height - padding, :] + else: + # Padding was added to the width + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding: current_width - padding] + + return unpadded_tensor diff --git a/llmc/compression/token_reduction/vispruner.py b/llmc/compression/token_reduction/vispruner.py index f97bf1b9..ddcabfa6 100644 --- a/llmc/compression/token_reduction/vispruner.py +++ b/llmc/compression/token_reduction/vispruner.py @@ -1,10 +1,13 @@ import functools +from functools import wraps +from types import MethodType import torch from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY from .token_reduction_module import TokenReductionModule +from .utils import get_anyres_image_grid_shape, unpad_image @TOKEN_REDUCTION_REGISTRY.register('VisPruner') @@ -26,8 +29,39 @@ def add_sparse_config(self): def register_reduction_modules(self): + def change_images_hook(fn, pruning_paras): + @wraps(fn) + def wrapper(self, *args, **kwargs): + images = args[5] + input_ids = args[0] + vision_tower = self.get_vision_tower() + + if vision_tower is None or images is None or input_ids.shape[1] == 1: + return fn(*args, **kwargs) + + if images.ndim == 5: + args = list(args) + concat_images = torch.cat([image for image in images], dim=0) + args[5] = concat_images.unsqueeze(dim=0).unsqueeze(dim=0) + pruning_paras['image_sizes'] = kwargs['image_sizes'] + pruning_paras['num_patches_per_side'] = vision_tower.num_patches_per_side + if hasattr(vision_tower, 'image_size'): + pruning_paras['vision_tower_image_size'] = 
vision_tower.image_size + else: + pruning_paras['vision_tower_image_size'] = None + pruning_paras['image_newline'] = self.model.image_newline + + return fn(*tuple(args), **kwargs) + else: + return fn(*args, **kwargs) + return wrapper + def update_output_attentions_hook(module, args, kwargs): + args = list(args) + if args[0].ndim == 6: + args[0] = args[0].squeeze(dim=0).squeeze(dim=0) kwargs['output_attentions'] = True + return tuple(args), kwargs def store_attention_hook(module, inps, outs, pruning_paras): image_attentions = outs.attentions[pruning_paras['select_layer']] @@ -35,7 +69,7 @@ def store_attention_hook(module, inps, outs, pruning_paras): image_attentions = image_attentions[:, :, 0, 1:] elif pruning_paras['select_feature'] == 'cls_patch': image_attentions = image_attentions - raise ValueError(f'Unexpected select feature: {self.select_feature}') + raise ValueError(f"Unexpected select feature: {pruning_paras['select_feature']}") pruning_paras['image_attentions'] = image_attentions.to(inps[0].dtype) @@ -96,10 +130,127 @@ def get_index_masks_hook(module, args, pruning_paras): pruning_paras['index_masks'] = index_masks - def prune_hook(module, inputs, outputs, pruning_paras): + def prune_hook(module, inputs, outputs, pruning_paras, model_config): image_features = outputs index_masks = pruning_paras['index_masks'] - return image_features[index_masks].unsqueeze(0) + + if image_features.shape[0] == 1: + return image_features[index_masks].unsqueeze(0) + + image_sizes = pruning_paras['image_sizes'] + split_sizes = [image_features.shape[0]] + image_features = torch.split(image_features, split_sizes, dim=0) + index_masks = torch.split(index_masks, split_sizes, dim=0) + # 'spatial_unpad', 'anyres' + mm_patch_merge_type = getattr(model_config, 'mm_patch_merge_type', 'flat') + mm_patch_merge_type = mm_patch_merge_type.replace('_unpad', '') + image_aspect_ratio = getattr(model_config, 'image_aspect_ratio', 'square') + + if mm_patch_merge_type == 'flat': + image_features = [x.flatten(0, 1) for x in image_features] + index_masks = [x.flatten(0, 1) for x in index_masks] + image_features = [x[m] for x, m in zip(image_features, index_masks)] + elif mm_patch_merge_type.startswith('spatial'): + new_image_features = [] + for image_idx, (image_feature, index_mask) in enumerate( + zip(image_features, index_masks) + ): + if image_feature.shape[0] > 1: + base_image_feature, base_index_mask = image_feature[0], index_mask[0] + image_feature, index_mask = image_feature[1:], index_mask[1:] + height = width = pruning_paras['num_patches_per_side'] + assert height * width == base_image_feature.shape[0] + + if image_aspect_ratio == 'anyres': + if pruning_paras['vision_tower_image_size'] is not None: + vision_tower_image_size = pruning_paras['vision_tower_image_size'] + else: + raise ValueError( + 'vision_tower_image_size is not found in the vision tower.' 
+ ) + try: + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + image_sizes[image_idx], + model_config.image_grid_pinpoints, + vision_tower_image_size + ) + except Exception: + num_patch_width, num_patch_height = 2, 2 + image_feature = image_feature.view( + num_patch_height, num_patch_width, height, width, -1 + ) + index_mask = index_mask.view( + num_patch_height, num_patch_width, height, width + ) + else: + raise NotImplementedError + + if 'maxpool2x2' in mm_patch_merge_type: + raise NotImplementedError + elif 'unpad' in mm_patch_merge_type and 'anyres_max' in image_aspect_ratio: + raise NotImplementedError + elif 'unpad' in mm_patch_merge_type: + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + image_feature = torch.cat(( + image_feature, + pruning_paras['image_newline'][:, None, None].expand( + *image_feature.shape[:-1], 1 + ).to(image_feature.device) + ), dim=-1) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + index_mask = index_mask.permute(0, 2, 1, 3).contiguous().unsqueeze(0) + index_mask = index_mask.flatten(1, 2).flatten(2, 3) + index_mask = unpad_image(index_mask, image_sizes[image_idx]) + index_mask = torch.cat(( + index_mask, + torch.ones( + *index_mask.shape[:-1], 1, dtype=torch.bool + ).to(index_mask.device) + ), dim=-1) + index_mask = index_mask.flatten(1, 2).squeeze(0) + image_feature = image_feature[index_mask] + else: + image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() + image_feature = image_feature.flatten(0, 3) + index_mask = index_mask.permute(0, 2, 1, 3).contiguous() + index_mask = index_mask.flatten(0, 3) + image_feature = image_feature[index_mask] + if 'nobase' in mm_patch_merge_type: + raise NotImplementedError + else: + base_image_feature = base_image_feature[base_index_mask] + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + else: # single image operations + image_feature = image_feature[0] + index_mask = index_mask[0] + if 'unpad' in mm_patch_merge_type: + image_feature = torch.cat(( + image_feature, + pruning_paras['image_newline'][None] + ), dim=0) + index_mask = torch.cat(( + index_mask, + torch.ones(1, dtype=torch.bool).to(index_mask.device) + ), dim=0) + image_feature = image_feature[index_mask] + new_image_features.append(image_feature) + image_features = new_image_features + else: + raise ValueError( + f'Unexpected mm_patch_merge_type: {model_config.mm_patch_merge_type}' + ) + return image_features + + self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType( + change_images_hook( + self.model.vlm_model.prepare_inputs_labels_for_multimodal, + self.pruning_paras + ), + self.model.vlm_model + ) self.model.vision_model.vision_tower.register_forward_pre_hook( update_output_attentions_hook, @@ -115,5 +266,9 @@ def prune_hook(module, inputs, outputs, pruning_paras): ) self.model.vision_projector.register_forward_hook( - functools.partial(prune_hook, pruning_paras=self.pruning_paras), + functools.partial( + prune_hook, + pruning_paras=self.pruning_paras, + model_config=self.model.vlm_model_config + ), ) diff --git a/llmc/eval/eval_vqa.py b/llmc/eval/eval_vqa.py index ac829cc4..b8fd6503 100755 --- a/llmc/eval/eval_vqa.py +++ b/llmc/eval/eval_vqa.py @@ -89,10 +89,10 @@ def eval( datetime_str: str = get_datetime_str(), cli_args=None, ): - import argparse - cli_args = 
argparse.Namespace( - process_with_media=True, - ) + # import argparse + # cli_args = argparse.Namespace( + # process_with_media=True, + # ) model = llmc_model.eval_name model_args = 'pretrained=' + self.model_path + ',device_map=auto' diff --git a/llmc/models/llava.py b/llmc/models/llava.py index 7e8efcea..4a71d24f 100644 --- a/llmc/models/llava.py +++ b/llmc/models/llava.py @@ -44,11 +44,11 @@ def build_model(self): self.model_path, trust_remote_code=True ) logger.info(f'self.vlm_model_config : {self.vlm_model_config}') - + model_name = get_model_name_from_path(self.model_path) self.tokenizer, self.vlm_model, self.image_processor, context_len = load_pretrained_model( self.model_path, None, - get_model_name_from_path(self.model_path), + model_name, device_map='cpu', attn_implementation='sdpa' ) @@ -71,6 +71,8 @@ def build_model(self): 'IMAGE_TOKEN_INDEX': IMAGE_TOKEN_INDEX, # for llava 'vision_token_start_index': 35, } + if 'v1.6' in model_name.lower(): + self.pruning_config['image_token_length'] = None self.processor = None self.first_turn_question = True
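
Note on applying this patch: the new TokenReductionModule.vtoken_length_for_llava_hook (used by the DART and FastV changes above) asserts that LLaVA's prepare_inputs_labels_for_multimodal returns seven values, with image_features.shape[1] appended last; the patch does not modify the LLaVA package itself, so that return-value change has to be made separately, as the assertion message in token_reduction_module.py explains. The sketch below only illustrates how the wrapper consumes the appended value; the stand-in function, the token count 2880, and the simplified signature (no self argument, no decode-phase short-circuit) are hypothetical and not part of the patch.

    from functools import wraps

    def vtoken_length_for_llava_hook(fn, pruning_paras):
        # Simplified form of the hook added in token_reduction_module.py:
        # capture the vision token count appended as the 7th return value.
        @wraps(fn)
        def wrapper(*args, **kwargs):
            outs = fn(*args, **kwargs)
            assert len(outs) == 7, 'append image_features.shape[1] to the return value'
            pruning_paras['vision_token_length'] = outs[-1]
            return outs
        return wrapper

    def prepare_inputs_labels_for_multimodal_stub(*args, **kwargs):
        # Stand-in for the manually modified LLaVA method: its usual six
        # outputs plus the (hypothetical) vision token count.
        return (None, None, None, None, None, None, 2880)

    pruning_paras = {}
    wrapped = vtoken_length_for_llava_hook(
        prepare_inputs_labels_for_multimodal_stub, pruning_paras
    )
    wrapped()
    assert pruning_paras['vision_token_length'] == 2880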