diff --git a/configs/sparsification/methods/FastV/fastv.yml b/configs/sparsification/methods/FastV/fastv.yml index 7c89968de..b8968ea7b 100644 --- a/configs/sparsification/methods/FastV/fastv.yml +++ b/configs/sparsification/methods/FastV/fastv.yml @@ -17,7 +17,7 @@ sparse: special: method: FastV pruning_loc: 3 - rate: 0.5 + rate: 0.778 save: save_trans: False save_fake: False diff --git a/configs/sparsification/methods/FasterVLM/fastervlm.yml b/configs/sparsification/methods/FasterVLM/fastervlm.yml index f96d123e7..0b2aaf43d 100644 --- a/configs/sparsification/methods/FasterVLM/fastervlm.yml +++ b/configs/sparsification/methods/FasterVLM/fastervlm.yml @@ -16,7 +16,7 @@ sparse: method: TokenReduction special: method: FasterVLM - rate: 0.75 + rate: 0.778 save: save_trans: False save_fake: False diff --git a/configs/sparsification/methods/VisionZip/visionzip.yml b/configs/sparsification/methods/VisionZip/visionzip.yml index d02d59f30..ff639f4a6 100644 --- a/configs/sparsification/methods/VisionZip/visionzip.yml +++ b/configs/sparsification/methods/VisionZip/visionzip.yml @@ -13,11 +13,12 @@ eval: bs: 1 inference_per_block: False sparse: - method: TokenReduction - special: - method: VisionZip - dominant: 191 - contextual: 30 + vision: + method: TokenReduction + special: + method: VisionZip + dominant: 191 # visual_tokens = dominan_tokens + 1(cls_token) + contextual: 30 save: save_trans: False save_fake: False diff --git a/llmc/compression/token_reduction/fastervlm.py b/llmc/compression/token_reduction/fastervlm.py index a65ba8b5b..635595b5a 100644 --- a/llmc/compression/token_reduction/fastervlm.py +++ b/llmc/compression/token_reduction/fastervlm.py @@ -24,7 +24,9 @@ def add_sparse_config(self): special_config['select_feature'] = self.model.pruning_config['select_feature'] special_config['image_token_index'] = self.model.pruning_config['image_token_index'] - self.model.model.parameters = special_config + special_config['image_attentions_list'] = [] + + self.pruning_paras = special_config def register_reduction_modules(self): @@ -32,20 +34,29 @@ def update_output_attentions_hook(module, args, kwargs): kwargs['output_attentions'] = True return args, kwargs - def store_attention_hook(m, x, image_forward_outs, pruning_pars): - image_attentions = image_forward_outs.attentions[pruning_pars['select_layer']] - if pruning_pars['select_feature'] == 'default': # patch - image_attentions = image_attentions[:, :, 0, 1:] - elif pruning_pars['select_feature'] == 'full': - image_attentions = image_attentions + def clear_attentions_hook(m, x, pruning_paras): + pruning_paras['image_attentions_list'].clear() + + def store_attention_hook(m, x, image_forward_outs, pruning_paras): + image_attentions = image_forward_outs.attentions[pruning_paras['select_layer']] + if pruning_paras['select_feature'] in ('default', 'patch'): + image_attention = image_attentions[:, :, 0, 1:] + elif pruning_paras['select_feature'] in ('full', 'cls_patch'): + image_attention = image_attentions else: raise ValueError(f'Unexpected select feature: {self.select_feature}') - pruning_pars['image_attentions'] = image_attentions + pruning_paras['image_attentions_list'].append(image_attention.to(x[0].dtype)) + + def update_attentions_hook(m, x, outs, pruning_paras): + if len(pruning_paras['image_attentions_list']) == 1: + pruning_paras['image_attentions'] = pruning_paras['image_attentions_list'][0] + else: + pruning_paras['image_attentions'] = pruning_paras['image_attentions_list'] - def pruning_hook(module, args, kwargs, pruning_pars): + def 
pruning_hook(module, args, kwargs, pruning_paras): image_features = args[0] - image_attentions = pruning_pars['image_attentions'] + image_attentions = pruning_paras['image_attentions'] # image_attentions = image_attentions.max(dim=1)[0] # (B, N) = (1, 576) image_attentions = image_attentions.mean(dim=1) # (B, N) = (1, 576) @@ -66,22 +77,22 @@ def pruning_hook(module, args, kwargs, pruning_pars): index_mask = torch.zeros(B, N, dtype=torch.bool, device=image_features.device) # (B, N) index_mask.scatter_(1, token_indices, True) # (B, N) - pruning_pars['index_mask'] = index_mask - pruning_pars['image_attentions'] = image_attentions + pruning_paras['index_mask'] = index_mask + pruning_paras['image_attentions'] = image_attentions return (image_features,), kwargs - def get_image_mask_hook(module, args, kwargs, pruning_pars): - pruning_pars['image_mask'] = ( - kwargs['input_ids'] == pruning_pars['image_token_index'] + def get_image_mask_hook(module, args, kwargs, pruning_paras): + pruning_paras['image_mask'] = ( + kwargs['input_ids'] == pruning_paras['image_token_index'] ) # (B, len) - def prepare_inputs_for_llm_hook(module, args, kwargs, pruning_pars): + def prepare_inputs_for_llm_hook(module, args, kwargs, pruning_paras): # Only batch size 1 is currently supported. inputs_embeds = kwargs['inputs_embeds'] - image_mask = pruning_pars['image_mask'][0] - index_mask = pruning_pars['index_mask'][0] + image_mask = pruning_paras['image_mask'][0] + index_mask = pruning_paras['index_mask'][0] B, L = inputs_embeds.shape[:2] device = inputs_embeds.device @@ -109,28 +120,67 @@ def prepare_inputs_for_llm_hook(module, args, kwargs, pruning_pars): return args, kwargs - self.model.vision_model.register_forward_pre_hook( - update_output_attentions_hook, - with_kwargs=True - ) + def prepare_inputs_hook(module, inputs, outputs, pruning_paras): - self.model.vision_model.register_forward_hook( - functools.partial(store_attention_hook, pruning_pars=self.model.model.parameters), - ) + image_features = outputs + index_masks = pruning_paras['index_mask'] + # image_attentions = pruning_paras['image_attentions'] + new_image_features = [] + for image_feature, index_mask in zip(image_features, index_masks): + image_feature = image_feature[index_mask] + new_image_features.append(image_feature) + image_features = torch.stack(new_image_features, dim=0) + + outputs = image_features + pruning_paras['image_features_shape'] = image_features[0].shape[0] + + return outputs + + if self.model.__class__.__name__ == 'LlavaHf': + self.model.vision_model.register_forward_pre_hook( + update_output_attentions_hook, + with_kwargs=True + ) + + self.model.vision_model.register_forward_hook( + functools.partial(store_attention_hook, pruning_paras=self.pruning_paras), + ) + elif self.model.__class__.__name__ == 'Llava': + self.model.vision_model.register_forward_pre_hook( + functools.partial(clear_attentions_hook, pruning_paras=self.pruning_paras), + ) + + self.model.vision_model.register_forward_hook( + functools.partial(update_attentions_hook, pruning_paras=self.pruning_paras), + ) + + self.model.vision_model.vision_tower.register_forward_pre_hook( + update_output_attentions_hook, + with_kwargs=True + ) + + self.model.vision_model.vision_tower.register_forward_hook( + functools.partial(store_attention_hook, pruning_paras=self.pruning_paras), + ) self.model.vision_projector.register_forward_pre_hook( - functools.partial(pruning_hook, pruning_pars=self.model.model.parameters), + functools.partial(pruning_hook, pruning_paras=self.pruning_paras), 
with_kwargs=True ) - self.model.vlm_model.register_forward_pre_hook( - functools.partial(get_image_mask_hook, pruning_pars=self.model.model.parameters), - with_kwargs=True - ) + if self.model.__class__.__name__ == 'LlavaHf': + self.model.vlm_model.register_forward_pre_hook( + functools.partial(get_image_mask_hook, pruning_paras=self.pruning_paras), + with_kwargs=True + ) - self.model.model.register_forward_pre_hook( - functools.partial( - prepare_inputs_for_llm_hook, pruning_pars=self.model.model.parameters - ), - with_kwargs=True - ) + self.model.model.register_forward_pre_hook( + functools.partial( + prepare_inputs_for_llm_hook, pruning_paras=self.pruning_paras + ), + with_kwargs=True + ) + elif self.model.__class__.__name__ == 'Llava': + self.model.vision_projector.register_forward_hook( + functools.partial(prepare_inputs_hook, pruning_paras=self.pruning_paras), + ) diff --git a/llmc/compression/token_reduction/fastv.py b/llmc/compression/token_reduction/fastv.py index e288ad6c3..2fff71c8e 100644 --- a/llmc/compression/token_reduction/fastv.py +++ b/llmc/compression/token_reduction/fastv.py @@ -1,4 +1,5 @@ import functools +from types import MethodType import torch @@ -8,6 +9,8 @@ from .token_reduction_module import TokenReductionModule from .utils import prefill_wrapper +IMAGE_TOKEN_INDEX = -200 + @TOKEN_REDUCTION_REGISTRY.register('FastV') class FastV(TokenReductionModule): @@ -23,33 +26,53 @@ def add_sparse_config(self): self.model.pruning_config['image_token_length'] self.special_config['attn_scores'] = None - self.model.model.parameters = self.special_config + self.pruning_paras = self.special_config def register_reduction_modules(self): @prefill_wrapper - def input_hook(module, input_args, pruning_pars): + def input_hook(module, input_args, pruning_paras): input_ids = input_args[0] image_token_idxs = (input_ids[0] == - pruning_pars['vision_token_index']).nonzero(as_tuple=True)[0] - pruning_pars['image_token_start_index'] = image_token_idxs[0].item() + pruning_paras['vision_token_index']).nonzero(as_tuple=True)[0] + pruning_paras['image_token_start_index'] = image_token_idxs[0].item() return input_args + def make_hook_prepare_inputs_labels_for_multimodal(pruning_paras): + def hook_prepare_inputs_labels_for_multimodal( + self, + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + images, + image_sizes + ): + if 'image_token_start_index' not in pruning_paras: + token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX + pruning_paras['image_token_start_index'] = torch.where(token_indices)[0].item() + return self._original_prepare_inputs_labels_for_multimodal( + input_ids, position_ids, attention_mask, + past_key_values, labels, images, image_sizes + ) + return hook_prepare_inputs_labels_for_multimodal + def update_output_attentions_hook(module, args, kwargs): kwargs['output_attentions'] = True return args, kwargs - def store_attention_hook(m, x, layer_outputs, pruning_pars): + def store_attention_hook(m, x, layer_outputs, pruning_paras): layer_attention = layer_outputs[1] - pruning_pars['attn_scores'] = layer_attention + pruning_paras['attn_scores'] = layer_attention @prefill_wrapper - def fastv_pruning_hook(module, args, kwargs, pruning_pars): + def fastv_pruning_hook(module, args, kwargs, pruning_paras): - rate = pruning_pars['rate'] - image_token_start_index = pruning_pars['image_token_start_index'] - image_token_length = pruning_pars['image_token_length'] + rate = pruning_paras['rate'] + image_token_start_index = 
pruning_paras['image_token_start_index'] + image_token_length = pruning_paras['image_token_length'] hidden_states = args[0] causal_mask = kwargs['attention_mask'] @@ -57,7 +80,7 @@ def fastv_pruning_hook(module, args, kwargs, pruning_pars): device = hidden_states.device # last_layer_attention = layer_outputs[1] - last_layer_attention = pruning_pars['attn_scores'] + last_layer_attention = pruning_paras['attn_scores'] # compute average attention over different head last_layer_attention_avg = torch.mean(last_layer_attention, dim=1)[0] # generate new attention mask based on the average attention, @@ -98,25 +121,34 @@ def fastv_pruning_hook(module, args, kwargs, pruning_pars): kwargs['cache_position'] = cache_position[:new_seq_length] kwargs['position_ids'] = position_ids kwargs['position_embeddings'] = None - pruning_pars['attention_mask'] = causal_mask - pruning_pars['cache_position'] = cache_position[:new_seq_length] - pruning_pars['position_ids'] = position_ids - pruning_pars['position_embeddings'] = None + pruning_paras['attention_mask'] = causal_mask + pruning_paras['cache_position'] = cache_position[:new_seq_length] + pruning_paras['position_ids'] = position_ids + pruning_paras['position_embeddings'] = None return (hidden_states,), kwargs @prefill_wrapper - def read_parameter_hook(module, args, kwargs, pruning_pars): - kwargs['attention_mask'] = pruning_pars['attention_mask'] - kwargs['cache_position'] = pruning_pars['cache_position'] - kwargs['position_ids'] = pruning_pars['position_ids'] - kwargs['position_embeddings'] = pruning_pars['position_embeddings'] + def read_parameter_hook(module, args, kwargs, pruning_paras): + kwargs['attention_mask'] = pruning_paras['attention_mask'] + kwargs['cache_position'] = pruning_paras['cache_position'] + kwargs['position_ids'] = pruning_paras['position_ids'] + kwargs['position_embeddings'] = pruning_paras['position_embeddings'] return args, kwargs - self.model.embed_tokens.register_forward_pre_hook( - functools.partial(input_hook, pruning_pars=self.model.model.parameters) - ) + if self.model.__class__.__name__ == 'LlavaHf': + self.model.embed_tokens.register_forward_pre_hook( + functools.partial(input_hook, pruning_paras=self.pruning_paras) + ) + elif self.model.__class__.__name__ == 'Llava': + hook_fn = make_hook_prepare_inputs_labels_for_multimodal(self.pruning_paras) + self.model.vlm_model._original_prepare_inputs_labels_for_multimodal = ( + self.model.vlm_model.prepare_inputs_labels_for_multimodal + ) + self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType( + hook_fn, self.model.vlm_model + ) self.blocks[self.pruning_loc - 1].register_forward_pre_hook( update_output_attentions_hook, @@ -124,16 +156,16 @@ def read_parameter_hook(module, args, kwargs, pruning_pars): ) self.blocks[self.pruning_loc - 1].register_forward_hook( - functools.partial(store_attention_hook, pruning_pars=self.model.model.parameters), + functools.partial(store_attention_hook, pruning_paras=self.pruning_paras), ) self.blocks[self.pruning_loc].register_forward_pre_hook( - functools.partial(fastv_pruning_hook, pruning_pars=self.model.model.parameters), + functools.partial(fastv_pruning_hook, pruning_paras=self.pruning_paras), with_kwargs=True ) for idx in range(self.pruning_loc + 1, len(self.blocks)): self.blocks[idx].register_forward_pre_hook( - functools.partial(read_parameter_hook, pruning_pars=self.model.model.parameters), + functools.partial(read_parameter_hook, pruning_paras=self.pruning_paras), with_kwargs=True ) diff --git 
a/llmc/compression/token_reduction/utils.py b/llmc/compression/token_reduction/utils.py index 496ff6345..727f91107 100755 --- a/llmc/compression/token_reduction/utils.py +++ b/llmc/compression/token_reduction/utils.py @@ -93,6 +93,7 @@ def apply_info(model, dominant_num, contextual_num): for module in model.modules(): if isinstance(module, CLIPEncoderLayer): module.self_attn.k_proj._info = model._info + module.self_attn.k_proj.metric = None def add_post_hook_to_get_2dPool(model, post_hook_fn, pruning_paras): diff --git a/llmc/compression/token_reduction/visionzip.py b/llmc/compression/token_reduction/visionzip.py index 962aa2a3f..f72580a5b 100755 --- a/llmc/compression/token_reduction/visionzip.py +++ b/llmc/compression/token_reduction/visionzip.py @@ -231,128 +231,55 @@ def visionzip_forward( ) -class CLIPVisionTower_VisionZip(nn.Module): - - @torch.no_grad() - def forward(self, images): - - if type(images) is list: - image_features = [] - for image in images: - image_forward_out = self.vision_tower( - image.to(device=self.device, dtype=self.dtype).unsqueeze(0), - output_hidden_states=True, - output_attentions=True, - ) - image_feature = self.feature_select(image_forward_out).to(image.dtype) - image_features.append(image_feature) - else: - - image_forward_outs = self.vision_tower( - images.to(device=self.device, dtype=self.dtype), - output_hidden_states=True, - output_attentions=True, - ) - attn_weights = image_forward_outs.attentions[-2] - hidden_states = image_forward_outs.hidden_states[-2] - metric = self.vision_tower.vision_model.encoder.layers[-2].metric - dominant_num = self.vision_tower._info['dominant'] - contextual_num = self.vision_tower._info['contextual'] - - # Dominant Visual Tokens - cls_idx = 0 - cls_attention = attn_weights[:, :, cls_idx, cls_idx + 1:] - cls_attention_sum = cls_attention.sum(dim=1) - topk_indices = cls_attention_sum.topk(dominant_num, dim=1).indices + 1 - all_indices = torch.cat( - [ - torch.zeros( - (hidden_states.shape[0], 1), - dtype=topk_indices.dtype, - device=topk_indices.device, - ), - topk_indices, - ], - dim=1, - ) - - mask = torch.ones_like( - hidden_states[:, :, 0], dtype=torch.bool, device=metric.device - ).scatter_(1, all_indices, False) - dominant_tokens = hidden_states.masked_select(~mask.unsqueeze(-1)).view( - hidden_states.shape[0], dominant_num + 1, hidden_states.shape[2] - ) - - # Filter - metric_filtered = metric[mask].view( - hidden_states.shape[0], - hidden_states.shape[1] - (dominant_num + 1), - metric.shape[2], - ) - - hidden_states_filtered = hidden_states.masked_select( - mask.unsqueeze(-1) - ).view( - hidden_states.shape[0], - hidden_states.shape[1] - (dominant_num + 1), - hidden_states.shape[2], - ) - - metric_normalized = metric_filtered / metric_filtered.norm( - dim=-1, keepdim=True - ) +def CLIP_EncoderLayer_forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, +) -> Tuple[torch.FloatTensor]: + # docformatter: off + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer + `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` + `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. + See `attentions` under + returned tensors for more detail. 
+ """ + # docformatter: on + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + metric = self.self_attn.k_proj.metric - # Contextual Visual Tokens - step = max(1, metric_normalized.shape[1] // contextual_num) - target_indices = torch.arange( - 0, metric_normalized.shape[1], step, device=metric_normalized.device - )[:contextual_num] - target_tokens = metric_normalized[:, target_indices, :] + hidden_states = residual + hidden_states - tokens_to_merge = metric_normalized[ - :, - ~torch.isin( - torch.arange( - metric_normalized.shape[1], device=metric_normalized.device - ), - target_indices, - ), - :, - ] - similarity = torch.bmm(tokens_to_merge, target_tokens.transpose(1, 2)) - assign_one_hot = torch.zeros( - tokens_to_merge.shape[0], - tokens_to_merge.shape[1], - contextual_num, - dtype=hidden_states_filtered.dtype, - device=metric_normalized.device, - ) - assign_one_hot.scatter_(2, similarity.argmax(dim=2).unsqueeze(-1), 1) - counts = assign_one_hot.sum(dim=1).clamp(min=1).unsqueeze(-1) - hidden_to_merge = hidden_states_filtered[ - :, - ~torch.isin( - torch.arange( - hidden_states_filtered.shape[1], - device=hidden_states_filtered.device, - ), - target_indices, - ), - :, - ] - aggregated_hidden = ( - torch.bmm(assign_one_hot.transpose(1, 2), hidden_to_merge) / counts - ) - target_hidden = hidden_states_filtered[:, target_indices, :] + r = self.self_attn.k_proj._info['r'].pop(0) + if r > 0: + self.metric = metric + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states - contextual_tokens = target_hidden + aggregated_hidden + outputs = (hidden_states,) - # Merge with target hidden states and concatenate - hidden_states_save = torch.cat( - [dominant_tokens, contextual_tokens], dim=1 - ).to(images.dtype) + if output_attentions: + outputs += (attn_weights,) - return hidden_states_save, all_indices + return outputs @TOKEN_REDUCTION_REGISTRY.register('VisionZip') @@ -487,23 +414,32 @@ def update_output_attentions_hook(module, args, kwargs): kwargs['output_attentions'] = True return args, kwargs + if self.model.__class__.__name__ == 'LlavaHf': + vision_tower = self.model.vlm_model.vision_tower + elif self.model.__class__.__name__ == 'Llava': + vision_tower = self.model.vlm_model.model.vision_tower.vision_tower + apply_info( - self.model.vlm_model.vision_tower, + vision_tower, dominant_num=self.dominant, contextual_num=self.contextual, ) - self.model.vlm_model.__class__.forward = visionzip_forward - self.model.vlm_model.vision_tower.register_forward_pre_hook( + if self.model.__class__.__name__ == 'LlavaHf': + self.model.vlm_model.__class__.forward = visionzip_forward + elif self.model.__class__.__name__ == 'Llava': + from transformers.models.clip.modeling_clip import CLIPEncoderLayer + CLIPEncoderLayer.forward = CLIP_EncoderLayer_forward + + vision_tower.register_forward_pre_hook( update_output_attentions_hook, with_kwargs=True ) - self.blocks = self.model.vlm_model.vision_tower.vision_model.encoder.layers - r = self.model.vlm_model.vision_tower.r + r = vision_tower.r for idx, block in enumerate(self.blocks): if r[idx]: block.self_attn.k_proj.num_heads = block.self_attn.num_heads block.self_attn.k_proj.head_dim = block.self_attn.head_dim 
                block.self_attn.k_proj.register_forward_hook(store_key_hook)
 
-        self.model.vlm_model.vision_tower.register_forward_hook(visionzip_hook)
+        vision_tower.register_forward_hook(visionzip_hook)
diff --git a/llmc/models/llava.py b/llmc/models/llava.py
index 3256b65b6..818fd313e 100644
--- a/llmc/models/llava.py
+++ b/llmc/models/llava.py
@@ -40,9 +40,9 @@ def build_model(self):
         self.vlm_model_config = AutoConfig.from_pretrained(
             self.model_path, trust_remote_code=True
         )
-        if not self.use_cache:
-            self.llava_config.use_cache = False
-            self.vlm_model_config.use_cache = False
+        # LLaVA requires use_cache to be enabled
+        self.llava_config.use_cache = True
+        self.vlm_model_config.use_cache = True
         logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
         self.tokenizer, self.vlm_model, image_processor, context_len = load_pretrained_model(
             self.model_path,
@@ -59,7 +59,6 @@ def build_model(self):
         ori_forward = self.vlm_model.forward
 
         def safe_forward(*args, **kwargs):
-            kwargs['use_cache'] = False
             kwargs.pop('cache_position', None)
             return ori_forward(*args, **kwargs)
         self.vlm_model.forward = safe_forward
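Note on the hook wiring used throughout this patch: every pruning entry point is attached via register_forward_pre_hook(..., with_kwargs=True) together with functools.partial, so all hooks share a single pruning_paras dict. The minimal sketch below is illustrative only and not part of the patch; it assumes PyTorch >= 2.0, substitutes a toy nn.Linear for the real projector/decoder layer, and drops trailing tokens by position rather than performing FastV/FasterVLM's attention-based top-k selection. The round(...) mapping from rate: 0.778 to roughly 128 of 576 visual tokens is likewise an assumption and may not match llmc's exact rounding.

import functools

import torch
import torch.nn as nn

# Shared state, analogous to self.pruning_paras in the modules above.
pruning_paras = {'rate': 0.778}


def pruning_pre_hook(module, args, kwargs, pruning_paras):
    # With with_kwargs=True the hook receives (module, args, kwargs) and may
    # return a replacement (args, kwargs) pair that the module is then called with.
    hidden_states = args[0]
    keep = max(1, round(hidden_states.shape[1] * (1 - pruning_paras['rate'])))
    # Positional truncation only -- a stand-in for the attention-based selection.
    return (hidden_states[:, :keep],), kwargs


layer = nn.Linear(16, 16)  # toy stand-in for a decoder block / projector
layer.register_forward_pre_hook(
    functools.partial(pruning_pre_hook, pruning_paras=pruning_paras),
    with_kwargs=True,
)

out = layer(torch.randn(1, 576, 16))
print(out.shape)  # torch.Size([1, 128, 16]) -- ~22.2% of the 576 tokens survive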