diff --git a/configs/sparsification/methods/PyramidDrop/pyramidrop.yml b/configs/sparsification/methods/PyramidDrop/pyramidrop.yml
new file mode 100644
index 000000000..39694ee41
--- /dev/null
+++ b/configs/sparsification/methods/PyramidDrop/pyramidrop.yml
@@ -0,0 +1,24 @@
+base:
+    seed: &seed 42
+model:
+    type: Llava
+    path: model path
+    torch_dtype: auto
+eval:
+    eval_pos: [transformed]
+    type: vqa
+    name: [mme]
+    download: False
+    path: MME dataset path
+    bs: 1
+    inference_per_block: False
+sparse:
+    method: TokenReduction
+    special:
+        method: PyramidDrop
+        image_token_ratio_list: [0.5, 0.25, 0.125]
+        layer_list: [8, 16, 24]
+save:
+    save_trans: False
+    save_fake: False
+    save_path: /path/to/save/
diff --git a/llmc/compression/token_reduction/__init__.py b/llmc/compression/token_reduction/__init__.py
index d18c7eedf..aaa5bfc65 100755
--- a/llmc/compression/token_reduction/__init__.py
+++ b/llmc/compression/token_reduction/__init__.py
@@ -1,6 +1,7 @@
 from .base_blockwise_token_reduction import TokenReduction
 from .fastervlm import FasterVLM
 from .fastv import FastV
+from .pyramiddrop import PyramidDrop
 from .sparsevlm import SparseVLM
 from .tome import ToMe
 from .visionzip import VisionZip
diff --git a/llmc/compression/token_reduction/pyramiddrop.py b/llmc/compression/token_reduction/pyramiddrop.py
new file mode 100644
index 000000000..9be07ad99
--- /dev/null
+++ b/llmc/compression/token_reduction/pyramiddrop.py
@@ -0,0 +1,349 @@
+import functools
+import math
+
+import torch
+from torch import nn
+from transformers.modeling_attn_mask_utils import \
+    _prepare_4d_causal_attention_mask
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
+
+from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
+
+from .token_reduction_module import TokenReductionModule
+
+
+@TOKEN_REDUCTION_REGISTRY.register('PyramidDrop')
+class PyramidDrop(TokenReductionModule):
+    def __init__(self, config, model, blocks):
+        super().__init__(config, model, blocks)
+        self.add_sparse_config()
+        self.register_reduction_modules()
+
+    def add_sparse_config(self):
+        special_config = self.config.get('special', {})
+        self.pruning_loc = special_config['layer_list']
+        image_token_ratio_list = special_config['image_token_ratio_list']
+        image_token_ratio_list.insert(0, 1.0)
+        special_config['image_token_ratio_list'] = image_token_ratio_list
+        special_config['tokenizer_padding_side'] = getattr(
+            self.model.vlm_model.language_model.model.config,
+            'tokenizer_padding_side',
+            'right',
+        )
+        special_config['image_token_index'] = self.model.pruning_config[
+            'image_token_index'
+        ]
+        self.model.model.parameters = special_config
+
+    def register_reduction_modules(self):
+
+        def pruning_hook(module, args, kwargs, pruning_pars, cur_num, layer_idx):
+
+            if layer_idx == self.pruning_loc[0]:
+                position_ids = kwargs['position_ids']
+                attention_mask = kwargs['attention_mask']
+                position_embeddings = kwargs['position_embeddings']
+            else:
+                attention_mask = pruning_pars['attention_mask']
+                position_ids = pruning_pars['position_ids']
+                position_embeddings = pruning_pars['position_embeddings']
+
+            features = args[0]
+            _position_ids = position_ids
+            _attention_mask = attention_mask
+            prompt_len = pruning_pars['prompt_len']
+            image_tokens_list = pruning_pars['image_tokens']
+            image_token_posi = pruning_pars['image_token_posi']
+            image_token_ratio_list = pruning_pars['image_token_ratio_list']
+
+            if position_ids is None:
+                position_ids = torch.arange(
+                    0, features.shape[1], dtype=torch.long, device=features.device
+                ).unsqueeze(0)
+
+            if pruning_pars['tokenizer_padding_side'] == 'right':
+
+                batch_size = features.shape[0]
+                image_tokens = [
+                    int(cur_image_token * image_token_ratio_list[cur_num])
+                    for cur_image_token in image_tokens_list
+                ]
+                keep_length = [
+                    int(cur_image_token * image_token_ratio_list[cur_num + 1])
+                    for cur_image_token in image_tokens_list
+                ]
+
+                features_list = []
+                attention_mask_list = []
+
+                if attention_mask is None:
+                    attention_mask = torch.ones(
+                        (batch_size, features.shape[1]),
+                        dtype=torch.bool,
+                        device=features.device,
+                    )
+                else:
+                    attention_mask = attention_mask.bool()
+
+                # obtain query_states and key_states to calculate attention map
+                hidden_states = features.clone().detach()
+                self_attn = module.self_attn
+                hidden_states = module.input_layernorm(hidden_states)
+
+                num_heads = self_attn.num_heads
+                num_key_value_heads = self_attn.num_key_value_heads
+                head_dim = self_attn.head_dim
+
+                bsz, q_len, _ = hidden_states.size()
+
+                query_states = self_attn.q_proj(hidden_states)
+                key_states = self_attn.k_proj(hidden_states)
+                value_states = self_attn.v_proj(hidden_states)
+
+                query_states = query_states.view(
+                    bsz, q_len, num_heads, head_dim
+                ).transpose(1, 2)
+                key_states = key_states.view(
+                    bsz, q_len, num_key_value_heads, head_dim
+                ).transpose(1, 2)
+                value_states = value_states.view(
+                    bsz, q_len, num_key_value_heads, head_dim
+                ).transpose(1, 2)
+
+                if position_embeddings is None:
+                    cos, sin = self_attn.rotary_emb(value_states, position_ids)
+                else:
+                    cos, sin = position_embeddings
+
+                query_states, key_states = apply_rotary_pos_emb(
+                    query_states, key_states, cos, sin
+                )
+
+                # attention_mask
+                eager_attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask,
+                    (batch_size, q_len),
+                    hidden_states,
+                    past_key_values_length=0,
+                ).to(device=query_states.device)
+
+                # take valid features
+                features = [
+                    cur_features[cur_attention_mask]
+                    for cur_features, cur_attention_mask in zip(
+                        features, attention_mask
+                    )
+                ]
+                attention_mask = [
+                    cur_attention_mask[cur_attention_mask]
+                    for cur_attention_mask in attention_mask
+                ]
+
+                # rank & drop
+                for i in range(batch_size):
+                    image_index = image_token_posi[i]
+                    if image_index == -1:
+                        cur_input_embeds = features[i]
+                        features_list.append(cur_input_embeds)
+                        attention_mask_list.append(attention_mask[i])
+                        continue
+
+                    # obtain current states
+                    cur_key_states = key_states[i]
+                    cur_query_states = query_states[i]
+                    cur_eager_attention_mask = eager_attention_mask[i]
+
+                    prompt_total_len = prompt_len[i] + image_tokens[i]
+                    text_query_states = cur_query_states[
+                        :, prompt_total_len - 1, :
+                    ].unsqueeze(1)
+                    text_eager_attention_mask = cur_eager_attention_mask[
+                        :, prompt_total_len - 1, :
+                    ].unsqueeze(1)
+
+                    # calculate attention map
+                    attn_weights = torch.matmul(
+                        text_query_states, cur_key_states.transpose(1, 2)
+                    ) / math.sqrt(
+                        head_dim
+                    )  # (num_head, text_token, seq_len)
+                    attn_weights = attn_weights + text_eager_attention_mask
+                    attn_weights = nn.functional.softmax(
+                        attn_weights, dim=-1, dtype=torch.float32
+                    ).to(
+                        query_states.dtype
+                    )  # (num_head, text_token, seq_len)
+
+                    attention_avg_head = torch.mean(
+                        attn_weights, dim=0
+                    )  # average across heads
+                    attention_avg_head = attention_avg_head[
+                        :, image_index: image_index + image_tokens[i]
+                    ]  # select image tokens as keys
+                    attention_avg_text = torch.mean(
+                        attention_avg_head, dim=0
+                    )  # (num_image_tokens,)
+
+                    # rank and drop by attention score
+                    top_rank_index = attention_avg_text.topk(keep_length[i]).indices
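+                    # topk indices are relative to the image-token slice; shift them
+                    # back to absolute sequence positions and restore original order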
+                    top_rank_index = top_rank_index + image_index
+                    top_rank_index = top_rank_index.sort().values
+
+                    start_index = image_index + image_tokens[i]
+                    new_input_embeds = torch.cat(
+                        [
+                            features[i][:image_index, :],
+                            features[i][top_rank_index, :],
+                            features[i][start_index:, :],
+                        ],
+                        dim=0,
+                    )
+                    new_attention_mask = torch.cat(
+                        [
+                            attention_mask[i][:image_index],
+                            attention_mask[i][top_rank_index],
+                            attention_mask[i][start_index:],
+                        ],
+                        dim=0,
+                    )
+
+                    features_list.append(new_input_embeds)
+                    attention_mask_list.append(new_attention_mask)
+
+                # Truncate sequences to max length as image embeddings can make the sequence longer
+                tokenizer_model_max_length = getattr(
+                    self.model.vlm_model.language_model.model.config,
+                    'tokenizer_model_max_length',
+                    2048,
+                )
+                if tokenizer_model_max_length is not None:
+                    new_input_embeds = [
+                        x[:tokenizer_model_max_length] for x in features_list
+                    ]
+                    new_attention_mask = [
+                        x[:tokenizer_model_max_length] for x in attention_mask_list
+                    ]
+
+                max_len = max(x.shape[0] for x in new_input_embeds)
+
+                # pad the sequences to form a batch
+                embeds_padded = []
+                attention_mask_padded = []
+                position_ids = torch.zeros(
+                    (batch_size, max_len),
+                    dtype=position_ids.dtype,
+                    device=position_ids.device,
+                )
+                for i, cur_new_embed in enumerate(new_input_embeds):
+                    cur_len_emb = cur_new_embed.shape[0]
+                    dif = max_len - cur_len_emb  # padding to longest seq
+
+                    cur_new_embed = torch.cat(
+                        [
+                            cur_new_embed,
+                            torch.zeros(
+                                (dif, cur_new_embed.shape[1]),
+                                dtype=cur_new_embed.dtype,
+                                device=cur_new_embed.device,
+                            ),
+                        ],
+                        dim=0,
+                    )
+                    cur_attention_mask = new_attention_mask[i]
+                    cur_attention_mask = torch.cat(
+                        [
+                            cur_attention_mask,
+                            torch.full(
+                                (dif,),
+                                False,
+                                dtype=cur_attention_mask.dtype,
+                                device=cur_attention_mask.device,
+                            ),
+                        ],
+                        dim=0,
+                    )
+
+                    embeds_padded.append(cur_new_embed)
+                    attention_mask_padded.append(cur_attention_mask)
+
+                    cur_len = new_attention_mask[i].sum().item()
+                    position_ids[i, :cur_len] = torch.arange(
+                        0, cur_len, dtype=position_ids.dtype, device=position_ids.device
+                    )
+
+                new_input_embeds = torch.stack(embeds_padded, dim=0)
+                new_input_embeds = new_input_embeds.to(features[0].dtype)
+
+                new_attention_mask = torch.stack(attention_mask_padded, dim=0)
+
+            if _position_ids is None:
+                position_ids = None
+
+            if _attention_mask is None:
+                new_attention_mask = None
+            else:
+                new_attention_mask = new_attention_mask.to(
+                    dtype=_attention_mask.dtype
+                )
+
+            kwargs['attention_mask'] = new_attention_mask
+            kwargs['position_ids'] = position_ids
+            kwargs['position_embeddings'] = None
+            pruning_pars['attention_mask'] = new_attention_mask
+            pruning_pars['position_ids'] = position_ids
+            pruning_pars['position_embeddings'] = None
+
+            return (new_input_embeds,), kwargs
+
+        def input_hook(module, input_args, pruning_pars):
+            input_ids = input_args[0]
+            pre_prompt_length_list = []
+            image_token_posi = []
+            image_tokens = []
+            IMAGE_TOKEN_INDEX = pruning_pars['image_token_index']
+
+            # find the position of the first image token
+            for seq in input_ids:
+                image_token_idxs = (seq == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0]
+                image_tokens.append(image_token_idxs.shape[0])
+                if image_token_idxs.numel() == 0:
+                    # no image token in this sample; -1 tells pruning_hook to skip it
+                    image_token_posi.append(-1)
+                else:
+                    image_token_posi.append(image_token_idxs[0].item())
+                pre_prompt_length_list.append(seq.shape[0] - image_token_idxs.shape[0])
+
+            pruning_pars['prompt_len'] = pre_prompt_length_list
+            pruning_pars['image_token_posi'] = image_token_posi
+            pruning_pars['image_tokens'] = image_tokens
+
+            return input_args
+
+        def read_parameter_hook(module, args, kwargs, pruning_pars):
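+            # non-pruning layers past the first pruning location reuse the mask and
+            # position ids cached by the most recent pruning_hook call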
+            kwargs['attention_mask'] = pruning_pars['attention_mask']
+            # kwargs['cache_position'] = pruning_pars['cache_position']
+            kwargs['position_ids'] = pruning_pars['position_ids']
+            kwargs['position_embeddings'] = pruning_pars['position_embeddings']
+
+            return args, kwargs
+
+        self.model.embed_tokens.register_forward_pre_hook(
+            functools.partial(input_hook, pruning_pars=self.model.model.parameters)
+        )
+
+        for layer_idx in range(self.pruning_loc[0], len(self.blocks)):
+            if layer_idx in self.pruning_loc:
+                stage = self.pruning_loc.index(layer_idx)
+                self.blocks[layer_idx].register_forward_pre_hook(
+                    functools.partial(
+                        pruning_hook,
+                        pruning_pars=self.model.model.parameters,
+                        cur_num=stage,
+                        layer_idx=layer_idx,
+                    ),
+                    with_kwargs=True,
+                )
+            else:
+                self.blocks[layer_idx].register_forward_pre_hook(
+                    functools.partial(
+                        read_parameter_hook, pruning_pars=self.model.model.parameters
+                    ),
+                    with_kwargs=True,
+                )
diff --git a/llmc/eval/eval_vqa.py b/llmc/eval/eval_vqa.py
index f4cb9e0f5..8c3560f66 100755
--- a/llmc/eval/eval_vqa.py
+++ b/llmc/eval/eval_vqa.py
@@ -19,8 +19,8 @@ def __init__(self, config):
         self.eval_config = config.eval
         self.model_path = config.model.path
         self.eval_dataset_name = self.eval_config['name']
-        if not isinstance(self.dataset, list):
-            self.eval_dataset_name = [self.dataset, ]
+        if not isinstance(self.eval_dataset_name, list):
+            self.eval_dataset_name = [self.eval_dataset_name, ]
         self.eval_dataset_path = self.eval_config['path']
         self.eval_bs = self.eval_config['bs']
@@ -62,7 +62,7 @@ def eval(
         model = llmc_model.eval_name
         model_args = 'pretrained=' + self.model_path + ',device_map=auto'
         batch_size = self.eval_bs
-        tasks = self.dataset
+        tasks = self.eval_dataset_name
         num_fewshot = 0
         seed_message = []
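
For reference, the per-sample rank-and-drop step implemented by pruning_hook above reduces to the short, self-contained sketch below: score each image token by the attention the final prompt token pays to it (averaged over heads), keep the top fraction, and splice the survivors back into the sequence in their original order. This is an illustrative sketch only, not part of the patch; the function name, tensor shapes, and toy inputs are assumptions, and it omits the padding and causal-mask handling done by the real hook.

import torch

def drop_image_tokens(query_states, key_states, features,
                      image_index, num_image_tokens, keep_ratio):
    """Minimal sketch of PyramidDrop's per-sample rank-and-drop step.

    query_states / key_states: (num_heads, seq_len, head_dim) after RoPE.
    features: (seq_len, hidden_dim) hidden states entering the block.
    """
    head_dim = query_states.shape[-1]
    # the last prompt token queries every position in the sequence
    text_query = query_states[:, -1:, :]                              # (H, 1, D)
    attn = (text_query @ key_states.transpose(1, 2)) / head_dim ** 0.5
    attn = torch.softmax(attn, dim=-1).mean(dim=(0, 1))               # (seq_len,)

    # score only the image-token span and keep the top fraction
    image_scores = attn[image_index: image_index + num_image_tokens]
    keep = max(1, int(num_image_tokens * keep_ratio))
    kept = image_scores.topk(keep).indices.sort().values + image_index

    # splice the kept image tokens back between the text segments
    return torch.cat([
        features[:image_index],
        features[kept],
        features[image_index + num_image_tokens:],
    ], dim=0)

# toy usage with random tensors
H, L, D, HID = 8, 32, 16, 64
q = torch.randn(H, L, D)
k = torch.randn(H, L, D)
feats = torch.randn(L, HID)
out = drop_image_tokens(q, k, feats, image_index=4,
                        num_image_tokens=20, keep_ratio=0.5)
print(out.shape)  # torch.Size([22, 64]) -> 4 text + 10 kept image + 8 text tokens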