diff --git a/llmc/compression/token_reduction/holitom.py b/llmc/compression/token_reduction/holitom.py
index 0208d6c6e..5b48691fe 100644
--- a/llmc/compression/token_reduction/holitom.py
+++ b/llmc/compression/token_reduction/holitom.py
@@ -35,7 +35,7 @@ def SigLipEncoder_forward(
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
-) -> Union[Tuple]:
+):
     output_attentions = (
         output_attentions
         if output_attentions is not None
@@ -934,7 +934,10 @@ def prepare_inputs_labels_for_multimodal(
     new_input_embeds = []
     new_labels = []
 
-    if os.getenv('HOLITOM_k') is not None and os.getenv('HOLITOM_r') is not None:
+    if (
+        self.pruning_paras.get('HOLITOM_k', None) is not None
+        and self.pruning_paras.get('HOLITOM_r', None) is not None
+    ):
         # [modified]
         image_token_posi = []
         prompt_len = []
@@ -942,8 +945,8 @@
     # rank_print("Inserting Images embedding")
     for batch_idx, cur_input_ids in enumerate(input_ids):
         if (
-            os.getenv('HOLITOM_k') is not None
-            and os.getenv('HOLITOM_r') is not None
+            self.pruning_paras.get('HOLITOM_k', None) is not None
+            and self.pruning_paras.get('HOLITOM_r', None) is not None
         ):
             # [modified]
             # record image position for further dropping
@@ -1036,7 +1039,10 @@
         new_input_embeds.append(cur_new_input_embeds)
         new_labels.append(cur_new_labels)
 
-    if os.getenv('HOLITOM_k') is not None and os.getenv('HOLITOM_r') is not None:
+    if (
+        self.pruning_paras.get('HOLITOM_k', None) is not None
+        and self.pruning_paras.get('HOLITOM_r', None) is not None
+    ):
         # [modified]
         self.model.image_token_posi = image_token_posi
         self.model.prompt_len = prompt_len
@@ -1173,6 +1179,7 @@ def __init__(self, config, model, blocks):
     def add_sparse_config(self):
         special_config = self.config.get('special', {})
         self.model.model.pruning_paras = special_config
+        self.model.model.model.pruning_paras = special_config
 
         if self.model.__class__.__name__ == 'Llava_OneVision':
             SigLipEncoder.forward = SigLipEncoder_forward
@@ -1211,5 +1218,283 @@ def add_sparse_config(self):
                 LlavaMetaForCausalLM_holitom.add_newline_token
             )
+        if (
+            self.special_config.get('HOLITOM_k', None) is not None
+            and self.special_config.get('HOLITOM_r', None) is not None
+        ):
+            from functools import partial
+
+            from transformers.cache_utils import Cache, DynamicCache
+            from transformers.modeling_flash_attention_utils import \
+                FlashAttentionKwargs
+            from transformers.modeling_outputs import \
+                BaseModelOutputWithPast
+            from transformers.processing_utils import Unpack
+
+            def qwen_forward(
+                self,
+                input_ids: Optional[torch.LongTensor] = None,
+                attention_mask: Optional[torch.Tensor] = None,
+                position_ids: Optional[torch.LongTensor] = None,
+                past_key_values: Optional[Cache] = None,
+                inputs_embeds: Optional[torch.FloatTensor] = None,
+                use_cache: Optional[bool] = None,
+                output_attentions: Optional[bool] = None,
+                output_hidden_states: Optional[bool] = None,
+                cache_position: Optional[torch.LongTensor] = None,
+                **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+            ) -> BaseModelOutputWithPast:
+                output_attentions = (
+                    output_attentions
+                    if output_attentions is not None
+                    else self.config.output_attentions
+                )
+                output_hidden_states = (
+                    output_hidden_states
+                    if output_hidden_states is not None
+                    else self.config.output_hidden_states
+                )
+                use_cache = (
+                    use_cache if use_cache is not None else self.config.use_cache
+                )
+
+                if (input_ids is None) ^ (inputs_embeds is not None):
+                    raise ValueError(
+                        'You must specify exactly one of input_ids or inputs_embeds'
+                    )
+
+                if self.gradient_checkpointing and self.training and use_cache:
+                    logger.warning_once(
+                        '`use_cache=True` is incompatible with gradient checkpointing. '
+                        + 'Setting `use_cache=False`.'
+                    )
+                    use_cache = False
+
+                # TODO (joao): remove this exception in v4.56 --
+                # it exists for users that try to pass a legacy cache
+                if not isinstance(past_key_values, (type(None), Cache)):
+                    raise ValueError(
+                        'The `past_key_values` should be either a `Cache` object or `None`.'
+                    )
+
+                if inputs_embeds is None:
+                    inputs_embeds = self.embed_tokens(input_ids)
+
+                if use_cache and past_key_values is None:
+                    past_key_values = DynamicCache()
+
+                if cache_position is None:
+                    past_seen_tokens = (
+                        past_key_values.get_seq_length()
+                        if past_key_values is not None
+                        else 0
+                    )
+                    cache_position = torch.arange(
+                        past_seen_tokens,
+                        past_seen_tokens + inputs_embeds.shape[1],
+                        device=inputs_embeds.device,
+                    )
+
+                if position_ids is None:
+                    position_ids = cache_position.unsqueeze(0)
+
+                causal_mask = self._update_causal_mask(
+                    attention_mask,
+                    inputs_embeds,
+                    cache_position,
+                    past_key_values,
+                    output_attentions,
+                )
+
+                hidden_states = inputs_embeds
+
+                # create position embeddings to be shared across the decoder layers
+                position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+                # decoder layers
+                all_hidden_states = () if output_hidden_states else None
+                all_self_attns = () if output_attentions else None
+
+                HOLITOM_k = self.pruning_paras.get('HOLITOM_k', 3)
+                HOLITOM_r = self.pruning_paras.get('HOLITOM_r', 0.5)
+                HOLITOM_image_token_start_index = self.image_token_posi[0]
+                HOLITOM_image_token_length = self.image_tokens[0]
+                seq_length_with_past = past_seen_tokens + inputs_embeds.shape[1]
+
+                for layer_idx, decoder_layer in enumerate(
+                    self.layers[: self.config.num_hidden_layers]
+                ):
+                    if output_hidden_states:
+                        all_hidden_states += (hidden_states,)
+
+                    if self.gradient_checkpointing and self.training:
+                        layer_outputs = self._gradient_checkpointing_func(
+                            partial(decoder_layer.__call__, **flash_attn_kwargs),
+                            hidden_states,
+                            causal_mask,
+                            position_ids,
+                            past_key_values,
+                            output_attentions,
+                            use_cache,
+                            cache_position,
+                            position_embeddings,
+                        )
+                    else:
+                        if layer_idx < HOLITOM_k:
+                            pass
+                        elif layer_idx == HOLITOM_k and position_ids.size(1) > 1:
+                            # compute pruned tokens, generate fastv sign
+                            last_layer_attention = layer_outputs[1]
+                            # compute average attention over different head
+                            last_layer_attention_avg = torch.mean(
+                                last_layer_attention, dim=1
+                            )[0]
+                            # generate new attention mask based on the average attention,
+                            # sample the top ATTENTION_RANK tokens with highest attention
+                            last_layer_attention_avg_last_tok = (
+                                last_layer_attention_avg[-1]
+                            )
+                            # get the attention in image token
+                            last_layer_attention_avg_last_tok_image = \
+                                last_layer_attention_avg_last_tok[
+                                    HOLITOM_image_token_start_index:
+                                    HOLITOM_image_token_start_index
+                                    + HOLITOM_image_token_length
+                                ]
+                            # get the indexes of the top ATTENTION_RANK tokens
+                            top_attention_rank_index = (
+                                last_layer_attention_avg_last_tok_image.topk(
+                                    round(
+                                        HOLITOM_image_token_length * (1 - HOLITOM_r)
+                                    )
+                                ).indices
+                                + HOLITOM_image_token_start_index
+                            )
+                            # print("Before merge:", HOLITOM_image_token_length, "After merge:",
+                            # round(HOLITOM_image_token_length*(1-HOLITOM_r)))
+
+                            device = hidden_states.device
+                            # [modified]
+                            all_indices = torch.arange(
+                                HOLITOM_image_token_length, device=device
+                            )
+                            non_topk_mask = ~torch.isin(
+                                all_indices,
+                                top_attention_rank_index - HOLITOM_image_token_start_index,
+                            )
+                            non_topk_indices = (
+                                all_indices[non_topk_mask] + HOLITOM_image_token_start_index
+                            )
+                            non_topk_states = hidden_states[
+                                :, non_topk_indices, :
+                            ]  # [batch_size, len(non_topk), hidden_size]
+                            topk_states = hidden_states[
+                                :, top_attention_rank_index, :
+                            ]  # [batch_size, len(topk), hidden_size]
+                            non_topk_norm = torch.norm(
+                                non_topk_states, dim=-1, keepdim=True
+                            )  # [batch_size, len(non_topk), 1]
+                            topk_norm = torch.norm(
+                                topk_states, dim=-1, keepdim=True
+                            )  # [batch_size, len(topk), 1]
+                            dot_product = torch.bmm(
+                                non_topk_states, topk_states.transpose(1, 2)
+                            )  # [batch_size, len(non_topk), len(topk)]
+                            sim_matrix = dot_product / (
+                                non_topk_norm * topk_norm.transpose(1, 2)
+                            )
+                            sim_max, sim_max_index = torch.max(sim_matrix, dim=-1)
+
+                            for b in range(hidden_states.size(0)):
+                                for i in range(len(non_topk_indices)):
+                                    non_topk_idx = non_topk_indices[i]
+                                    most_similar_topk_idx = (
+                                        top_attention_rank_index[
+                                            sim_max_index[b, i]
+                                        ]
+                                    )
+                                    hidden_states[b, most_similar_topk_idx, :] = (
+                                        hidden_states[b, most_similar_topk_idx, :]
+                                        + hidden_states[b, non_topk_idx, :]
+                                    ) / 2
+                            # [modified]
+
+                            # keep index
+                            keep_indexes = torch.cat(
+                                (
+                                    torch.arange(
+                                        HOLITOM_image_token_start_index,
+                                        device=device,
+                                    ),
+                                    top_attention_rank_index,
+                                    torch.arange(
+                                        HOLITOM_image_token_start_index
+                                        + HOLITOM_image_token_length,
+                                        seq_length_with_past,
+                                        device=device,
+                                    ),
+                                )
+                            )
+                            # sort index
+                            keep_indexes = keep_indexes.sort().values
+                            # update seq length
+                            new_seq_length = keep_indexes.shape[0]
+                            # filter hidden states
+
+                            hidden_states = hidden_states[
+                                :, keep_indexes, :
+                            ]
+                            # leads to the cuda error in the
+                            # second iteration of decoding, layer_idx 3
+                            # update position ids
+                            position_ids = keep_indexes.unsqueeze(0)
+
+                            position_embeddings = self.rotary_emb(
+                                hidden_states, position_ids
+                            )
+
+                            cache_position = cache_position[:new_seq_length]
+
+                        if layer_idx == HOLITOM_k - 1:
+                            output_attentions = True
+                        else:
+                            output_attentions = False
+
+                        layer_outputs = decoder_layer(
+                            hidden_states,
+                            attention_mask=causal_mask,
+                            position_ids=position_ids,
+                            past_key_value=past_key_values,
+                            output_attentions=output_attentions,
+                            use_cache=use_cache,
+                            cache_position=cache_position,
+                            position_embeddings=position_embeddings,
+                            **flash_attn_kwargs,
+                        )
+
+                    hidden_states = layer_outputs[0]
+
+                    # if output_attentions:
+                    # all_self_attns += (layer_outputs[1],)
+
+                hidden_states = self.norm(hidden_states)
+
+                # add hidden states from the last decoder layer
+                if output_hidden_states:
+                    all_hidden_states += (hidden_states,)
+
+                return BaseModelOutputWithPast(
+                    last_hidden_state=hidden_states,
+                    past_key_values=past_key_values if use_cache else None,
+                    hidden_states=all_hidden_states,
+                    attentions=all_self_attns,
+                )
+
+            from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
+
+            Qwen2Model.forward = qwen_forward
+
 
     def register_reduction_modules(self):
         pass
diff --git a/llmc/compression/token_reduction/token_reduction_module.py b/llmc/compression/token_reduction/token_reduction_module.py
index 0645ef538..619cf62ad 100644
--- a/llmc/compression/token_reduction/token_reduction_module.py
+++ b/llmc/compression/token_reduction/token_reduction_module.py
@@ -1,3 +1,8 @@
+import time
+
+import torch
+from loguru import logger
+
 class TokenReductionModule:
 
     def __init__(self, config, model, blocks):
diff --git a/llmc/eval/eval_vqa.py b/llmc/eval/eval_vqa.py
index 8c3560f66..39289e70e 100755
--- a/llmc/eval/eval_vqa.py
+++ b/llmc/eval/eval_vqa.py
@@ -1,4 +1,5 @@
 import random
+import time
 from typing import List, Optional, Union
 
 import numpy as np
@@ -20,10 +21,40 @@ def __init__(self, config):
         self.model_path = config.model.path
         self.eval_dataset_name = self.eval_config['name']
         if not isinstance(self.eval_dataset_name, list):
-            self.eval_dataset_name = [self.eval_dataset_name, ]
+            self.eval_dataset_name = [
+                self.eval_dataset_name,
+            ]
         self.eval_dataset_path = self.eval_config['path']
         self.eval_bs = self.eval_config['bs']
+        self.statistics = self.eval_config.get('statistics', False)
+
+    def set_statistics_modules(self, model):
+
+        def start_time_hook(module, args, kwargs):
+            torch.cuda.synchronize()
+            module.start_time = time.time()
+            return args, kwargs
+
+        def end_time_hook(module, inputs, kwargs, layer_outputs):
+            torch.cuda.synchronize()
+            elapsed_prefill = time.time() - module.start_time
+            if kwargs['inputs_embeds'] is not None:
+                module.prefill_count += 1
+                module.prefill_time += elapsed_prefill
+            else:
+                model.decode_count += 1
+                model.decode_time += elapsed_prefill
+
+        model.prefill_count = 0
+        model.prefill_time = 0
+        model.decode_time = 0
+        model.decode_count = 0
+
+        model.register_forward_pre_hook(start_time_hook, with_kwargs=True)
+
+        model.register_forward_hook(end_time_hook, with_kwargs=True)
+
 
     def eval(
         self,
         llmc_model,
@@ -82,7 +113,9 @@ def eval(
         if seed_message:
            logger.info(' | '.join(seed_message))
 
-        assert tasks != [], 'No tasks specified, or no tasks found. Please verify the task names.'
+        assert (
+            tasks != []
+        ), 'No tasks specified, or no tasks found. Please verify the task names.'
 
         if gen_kwargs:
             gen_kwargs = simple_parse_args_string(gen_kwargs)
@@ -98,6 +131,10 @@
 
         task_dict = get_task_dict(tasks, task_manager)
 
+        if self.statistics:
+            self.set_statistics_modules(llmc_model.vlm_model)
+            torch.cuda.reset_peak_memory_stats()
+
         lm = MODEL_REGISTRY[model].create_from_arg_string(
             model_args,
             {
@@ -128,12 +165,15 @@ def _adjust_config(task_dict):
                lm.task_dict[task_name] = task_obj.dataset
                if 'generate_until' in task_obj.get_config('output_type'):
                    if gen_kwargs is not None:
-                        task_obj.set_config(key='generation_kwargs',
-                                            value=gen_kwargs, update=True)
+                        task_obj.set_config(
+                            key='generation_kwargs', value=gen_kwargs, update=True
+                        )
 
                    if predict_only:
-                        logger.info(f'Processing {task_name} in output-only mode. \
-                            Metrics will not be calculated!')
+                        logger.info(
+                            f'Processing {task_name} in output-only mode. \
+                            Metrics will not be calculated!'
+                        )
                        # we have to change the class properties post-hoc. This is pretty hacky.
                        task_obj.override_metric(metric_name='bypass')
 
@@ -142,17 +182,25 @@ def _adjust_config(task_dict):
                # except if tasks have it set to 0 manually in their configs--then
                # we should never overwrite that
                if num_fewshot is not None:
-                    if (default_num_fewshot := task_obj.get_config('num_fewshot')) == 0:
-                        logger.info(f'num_fewshot has been set to 0 for {task_name} \
-                            in its config. Manual configuration will be ignored.')
+                    if (
+                        default_num_fewshot := task_obj.get_config('num_fewshot')
+                    ) == 0:
+                        logger.info(
+                            f'num_fewshot has been set to 0 for {task_name} \
+                            in its config. Manual configuration will be ignored.'
+                        )
                    else:
-                        logger.warning(f'Overwriting default num_fewshot of {task_name} \
-                            from {default_num_fewshot} to {num_fewshot}')
+                        logger.warning(
+                            f'Overwriting default num_fewshot of {task_name} \
+                            from {default_num_fewshot} to {num_fewshot}'
+                        )
                        task_obj.set_config(key='num_fewshot', value=num_fewshot)
                else:
                    # if num_fewshot not provided, and the task does not define a default one,
                    # default to 0
-                    if (default_num_fewshot := task_obj.get_config('num_fewshot')) is None:
+                    if (
+                        default_num_fewshot := task_obj.get_config('num_fewshot')
+                    ) is None:
                        task_obj.set_config(key='num_fewshot', value=0)
                # fewshot_random_seed set for tasks, even with a default num_fewshot
                # (e.g. in the YAML file)
@@ -193,6 +241,19 @@ def _adjust_config(task_dict):
            cli_args=cli_args,
        )
 
+        if self.statistics:
+            prefill = (
+                llmc_model.vlm_model.prefill_time / llmc_model.vlm_model.prefill_count
+            )
+            decode = (
+                llmc_model.vlm_model.decode_time / llmc_model.vlm_model.decode_count
+            )
+            gen_max_mem = torch.cuda.max_memory_allocated() / 1024 / 1024
+
+            logger.info(f'peak memory: {gen_max_mem:.1f} MB.')
+            logger.info(f'prefill average time: {prefill *1000:.1f} ms.')
+            logger.info(f'decode average time: {decode *1000:.1f} ms.')
+
        if hasattr(lm, '_model'):
            del lm._model
            torch.cuda.empty_cache()
@@ -217,8 +278,11 @@ def _adjust_config(task_dict):
        results['config'].update(
            {
                'batch_size': batch_size,
-                'batch_sizes': (list(lm.batch_sizes.values())
-                                if hasattr(lm, 'batch_sizes') else []),
+                'batch_sizes': (
+                    list(lm.batch_sizes.values())
+                    if hasattr(lm, 'batch_sizes')
+                    else []
+                ),
                'device': device,
                'use_cache': use_cache,
                'limit': limit,