diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 80c27008b..3fa87dc1b 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -30,7 +30,7 @@ jobs:

       - name: Download dataset
         run: |
-          # pwd # /home/runner/work/llmc/llmc
+          # pwd # /home/runner/work/LightCompress/LightCompress
           cd tools
           python download_calib_dataset.py --save_path ../check/datasets/calib --dataset_name pileval
           python download_eval_dataset.py --save_path ../check/datasets/eval --dataset_name wikitext2
@@ -46,17 +46,17 @@ jobs:

       - name: Preparation for check.
         run: |
-          cd ci_check # /home/runner/work/llmc/llmc/ci_check
+          cd ci_check # /home/runner/work/LightCompress/LightCompress/ci_check
           python change_files.py

       - name: Run awq check
         run: |
-          cd ci_check # /home/runner/work/llmc/llmc/ci_check
+          cd ci_check # /home/runner/work/LightCompress/LightCompress/ci_check
           bash run_awq.sh

       - name: Run gptq check
         run: |
-          cd ci_check # /home/runner/work/llmc/llmc/ci_check
+          cd ci_check # /home/runner/work/LightCompress/LightCompress/ci_check
           bash run_gptq.sh

       - name: Check success
diff --git a/ci_check/awq_w4a16_fakequant_eval.yml b/ci_check/awq_w4a16_fakequant_eval.yml
index d3ef88609..e2baef7c5 100644
--- a/ci_check/awq_w4a16_fakequant_eval.yml
+++ b/ci_check/awq_w4a16_fakequant_eval.yml
@@ -2,12 +2,12 @@ base:
     seed: &seed 42
 model:
     type: Opt
-    path: /home/runner/work/llmc/llmc/ci_check/opt-125m
+    path: /home/runner/work/LightCompress/LightCompress/ci_check/opt-125m
     torch_dtype: auto
 calib:
     name: pileval
     download: False
-    path: /home/runner/work/llmc/llmc/check/datasets/calib/pileval
+    path: /home/runner/work/LightCompress/LightCompress/check/datasets/calib/pileval
     n_samples: 4 # 128
     bs: -1
     seq_len: 16 # 512
@@ -17,7 +17,7 @@ eval:
     eval_pos: [pretrain, transformed, fake_quant]
     name: wikitext2
     download: False
-    path: /home/runner/work/llmc/llmc/check/datasets/eval/wikitext2
+    path: /home/runner/work/LightCompress/LightCompress/check/datasets/eval/wikitext2
     bs: 1
     seq_len: 16 # 2048
     eval_token_consist: True
@@ -35,4 +35,4 @@ quant:
         clip_sym: False
 save:
     save_trans: False
-    save_path: /home/runner/work/llmc/llmc/save/opt-125m_awq_w4a16
+    save_path: /home/runner/work/LightCompress/LightCompress/save/opt-125m_awq_w4a16
diff --git a/ci_check/gptq_w_only.yml b/ci_check/gptq_w_only.yml
index 51c0ac43d..03f64a893 100644
--- a/ci_check/gptq_w_only.yml
+++ b/ci_check/gptq_w_only.yml
@@ -2,13 +2,13 @@ base:
     seed: &seed 0
 model:
     type: Opt
-    path: /home/runner/work/llmc/llmc/ci_check/opt-125m
+    path: /home/runner/work/LightCompress/LightCompress/ci_check/opt-125m
     torch_dtype: auto
 calib:
     name: wikitext2
     download: False
     n_samples: 4
-    path: /home/runner/work/llmc/llmc/check/datasets/eval/wikitext2
+    path: /home/runner/work/LightCompress/LightCompress/check/datasets/eval/wikitext2
     bs: 1
     seq_len: 16
     preproc: wikitext2_gptq
@@ -17,7 +17,7 @@ eval:
     eval_pos: [fake_quant]
     name: wikitext2
     download: False
-    path: /home/runner/work/llmc/llmc/check/datasets/eval/wikitext2
+    path: /home/runner/work/LightCompress/LightCompress/check/datasets/eval/wikitext2
     bs: 1
     seq_len: 16
     inference_per_block: False
@@ -40,4 +40,4 @@ quant:
         quant_out: True
 save:
     save_fake: False
-    save_path: /home/runner/work/llmc/llmc/save/opt-125m_gptq_w4a16
+    save_path: /home/runner/work/LightCompress/LightCompress/save/opt-125m_gptq_w4a16
diff --git a/configs/sparsification/methods/SparseVLM/sparsevlm.yml b/configs/sparsification/methods/SparseVLM/sparsevlm.yml
index eb4b62bb7..84d91adaf 100644
--- a/configs/sparsification/methods/SparseVLM/sparsevlm.yml
+++ b/configs/sparsification/methods/SparseVLM/sparsevlm.yml
@@ -19,6 +19,7 @@ sparse:
     pruning_loc: [2] # [2, 6, 15]
     retained_tokens: 192
     init_token_total_shape: 668
+    merge_flag: False
 save:
     save_trans: False
     save_fake: False
diff --git a/llmc/compression/token_reduction/dart.py b/llmc/compression/token_reduction/dart.py
index 8c0be44bd..68c7a4161 100644
--- a/llmc/compression/token_reduction/dart.py
+++ b/llmc/compression/token_reduction/dart.py
@@ -44,7 +44,7 @@ def wrapper(self, *args, **kwargs):
         token_indices = (
             input_ids[0][attention_mask[0]] == pruning_paras['IMAGE_TOKEN_INDEX']
         )
-        pruning_paras['image_token_start_index'] = torch.where(token_indices)[0].item()
+        pruning_paras['image_token_start_index'] = torch.where(token_indices)[0][0].item()

         outputs = fn(*args, **kwargs)
         return outputs
@@ -67,7 +67,7 @@ def get_any_states_hook(module, args, kwargs, layer_outs, pruning_paras, layer_i
     hidden_states = kwargs['hidden_states']
     position_embeddings = kwargs['position_embeddings']
     position_ids = kwargs['position_ids']
-    past_key_value = kwargs['past_key_value']
+    past_key_value = layer_outs[2]

     bsz, q_len, _ = hidden_states.size()
     query_states = module.q_proj(hidden_states)
@@ -193,10 +193,8 @@ def get_retained_image_token(pruning_paras, last_layer_state, any_states):
     ) // (pivot_image_token + pivot_text_token))

     device = last_layer_state.device
-    any_states = (
-        any_states.permute(0, 2, 1, 3)
-        .reshape(any_states.shape[0], any_states.shape[1], -1)
-    )
+    any_states = any_states.permute(0, 2, 1, 3)
+    any_states = any_states.reshape(any_states.shape[0], any_states.shape[1], -1)

     k_states_image_token = any_states[0][
         image_token_start_index:image_token_start_index + image_token_length, :
diff --git a/llmc/compression/token_reduction/fastv.py b/llmc/compression/token_reduction/fastv.py
index 0c2bcbaa4..ecafde5a4 100644
--- a/llmc/compression/token_reduction/fastv.py
+++ b/llmc/compression/token_reduction/fastv.py
@@ -52,7 +52,7 @@ def wrapper(self, *args, **kwargs):
         attention_mask = args[2]
         token_indices = \
             input_ids[0][attention_mask[0]] == pruning_paras['IMAGE_TOKEN_INDEX']
-        pruning_paras['image_token_start_index'] = torch.where(token_indices)[0].item()
+        pruning_paras['image_token_start_index'] = torch.where(token_indices)[0][0].item()

         outputs = fn(*args, **kwargs)
         return outputs
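Note on the dart.py and fastv.py hunks above: torch.where(token_indices)[0] returns a 1-D tensor of every matching position, and .item() only succeeds when that tensor holds exactly one element, so a prompt containing more than one image-token id made the old line raise. Indexing [0][0] keeps just the first image-token position. A minimal, self-contained sketch of the difference (the IMAGE_TOKEN_INDEX value and the toy input_ids below are illustrative, not taken from the repository):

import torch

IMAGE_TOKEN_INDEX = -200  # illustrative placeholder value
input_ids = torch.tensor([1, 15, IMAGE_TOKEN_INDEX, 7, IMAGE_TOKEN_INDEX, 9])
token_indices = input_ids == IMAGE_TOKEN_INDEX

# Old form: torch.where(token_indices)[0].item() raises a RuntimeError once
# there is more than one match, because a multi-element tensor cannot be
# converted to a scalar. New form: take the first matching index, then convert.
image_token_start_index = torch.where(token_indices)[0][0].item()
print(image_token_start_index)  # 2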
diff --git a/llmc/compression/token_reduction/sparsevlm.py b/llmc/compression/token_reduction/sparsevlm.py
index e8c62106e..997a88b71 100755
--- a/llmc/compression/token_reduction/sparsevlm.py
+++ b/llmc/compression/token_reduction/sparsevlm.py
@@ -12,6 +12,8 @@
 from .token_reduction_module import TokenReductionModule
 from .utils import prefill_wrapper, prefill_wrapper_model

+layer_dict = {}
+

 @TOKEN_REDUCTION_REGISTRY.register('SparseVLM')
 class SparseVLM(TokenReductionModule):
@@ -24,6 +26,8 @@ def add_sparse_config(self):
         special_config = self.config.get('special', {})

         self.pruning_loc = special_config.get('pruning_loc', [2, 6, 15])
+        global layer_dict
+        layer_dict = {layer: idx for idx, layer in enumerate(self.pruning_loc)}
         special_config['retained_tokens'] = special_config.get('retained_tokens', 192)
         special_config['init_token_total_shape'] = special_config.get('init_token_total_shape', 668)
         special_config['generate_process_count'] = 0
@@ -44,7 +48,8 @@ def input_hook(module, input_args, pruning_pars):
             # find the position of the first image token
             for seq in input_ids:
                 image_token_index = (
-                    seq == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0]
+                    seq == IMAGE_TOKEN_INDEX
+                ).nonzero(as_tuple=True)[0]
                 if len(image_token_index) > 0:
                     pre_prompt_length_list.append(image_token_index[0].item())
                 else:
@@ -95,33 +100,31 @@ def wrapper(self, *args, **kwargs):
         @prefill_wrapper_model
         def register_module_pars(module, args, kwargs, pruning_pars):
             pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
-            inputs_embeds = kwargs['inputs_embeds']
-            if inputs_embeds is None:
-                inputs_embeds = module.embed_tokens(kwargs['input_ids'])
-            hidden_states = inputs_embeds  # shape: (B, L, C)
+            hidden_states = kwargs['inputs_embeds']
+            if hidden_states is None:
+                hidden_states = module.embed_tokens(kwargs['input_ids'])
             B, L, _ = hidden_states.shape
             pruning_pars['B'] = B

             init_n = pruning_pars['init_token_total_shape'] + \
-                pruning_pars['generate_process_count']  # 668
+                pruning_pars['generate_process_count']    # 668
             pruning_pars['prev_decision'] = torch.ones(
                 B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)
             pruning_pars['policy'] = torch.ones(
                 B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)

-            pruning_pars['v_token_start'] = pre_prompt_length_list[0] if len(
-                pre_prompt_length_list) != 0 else 0  # 35
-            v_token_start = pruning_pars['v_token_start']
-            pruning_pars['text_token_start'] = pruning_pars['v_token_start'] + \
-                pruning_pars['image_shape']  # 35 + 576 = 611
-            text_token_start = pruning_pars['text_token_start']
+            v_token_start = pre_prompt_length_list[0] if len(
+                pre_prompt_length_list) != 0 else 0
+            text_token_start = v_token_start + pruning_pars['image_shape']
+            pruning_pars['v_token_start'] = v_token_start  # 35
+            pruning_pars['text_token_start'] = text_token_start  # 611
             pruning_pars['v_token_num'] = pruning_pars['image_shape']  # 576

             if (len(pre_prompt_length_list) != 0 and hidden_states.shape[1] != 1):
                 v_t = hidden_states[:, v_token_start: text_token_start, :]
                 t_t = hidden_states[:, text_token_start:, :]
-                m_v_t = v_t @ t_t.transpose(1, 2)  # [1, 576, 53] # 52?
-                m_v_t = m_v_t.softmax(2).mean(1)  # [1, 53]
+                m_v_t = v_t @ t_t.transpose(1, 2)  # [1, 576, 52]
+                m_v_t = m_v_t.softmax(2).mean(1)  # [1, 52]
                 pruning_pars['t_token_idx'] = torch.where(m_v_t > m_v_t.mean())

             return args, kwargs
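Note on the register_module_pars hunk above: the hook scores each text token by its relevance to the visual tokens and keeps the above-average ones as t_token_idx; the comment fix ([1, 576, 52] instead of [1, 576, 53]) corrects the recorded text-token count. A shape-level sketch with random tensors (the hidden size and the 576/52 counts follow the comments in the hunk and are otherwise illustrative):

import torch

B, C = 1, 4096                       # illustrative hidden size
v_t = torch.randn(B, 576, C)         # visual-token hidden states
t_t = torch.randn(B, 52, C)          # text-token hidden states after the image span
m_v_t = v_t @ t_t.transpose(1, 2)    # [1, 576, 52]
m_v_t = m_v_t.softmax(2).mean(1)     # [1, 52]: average relevance of each text token
t_token_idx = torch.where(m_v_t > m_v_t.mean())  # keep text tokens scoring above the mean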
@@ -134,10 +137,20 @@ def update_output_attentions_hook(module, args, kwargs, pruning_pars, layer_idx)
             kwargs['position_embeddings'] = pruning_pars['position_embeddings']
             return args, kwargs

-        def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):
+        def update_kwargs_hook(module, args, kwargs, pruning_pars, layer_idx):

             if len(kwargs['position_ids'][0]) == 1:
                 return args, kwargs
+            if layer_idx != self.pruning_loc[0]:
+                kwargs['position_ids'] = pruning_pars['position_ids']
+                kwargs['cache_position'] = pruning_pars['cache_position']
+                kwargs['position_embeddings'] = pruning_pars['position_embeddings']
+            return args, kwargs
+
+        def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_idx):
+
+            if len(kwargs['position_ids'][0]) == 1:
+                return layer_outs

             from transformers.models.llama.modeling_llama import \
                 apply_rotary_pos_emb
@@ -150,8 +163,7 @@ def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):
             hidden_states = kwargs['hidden_states']
             position_embeddings = kwargs['position_embeddings']
             position_ids = kwargs['position_ids']
-            past_key_value = kwargs['past_key_value']
-            cache_position = kwargs['cache_position']
+            past_key_value = layer_outs[2]
             attention_mask = kwargs['attention_mask']

             t_token_idx = pruning_pars['t_token_idx']
@@ -179,12 +191,8 @@ def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):
             query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

             if past_key_value is not None:
-                temp_cache = copy.deepcopy(past_key_value)
-                cache_kwargs = {'sin': sin, 'cos': cos, 'cache_position': cache_position}
-                key_states, value_states = temp_cache.update(
-                    key_states, value_states,
-                    layer_idx, cache_kwargs
-                )
+                key_states = past_key_value.key_cache[layer_idx]
+                value_states = past_key_value.value_cache[layer_idx]
             t_token_idx = t_token_idx[1] + v_token_start + v_token_num
             L, S = query_states.size(-2), key_states.size(-2)
             scale_factor = 1 / math.sqrt(query_states.size(-1))
@@ -201,19 +209,16 @@ def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):

             pruning_pars['attn_logits'] = attn_logits

-            return args, kwargs
+            return layer_outs

         @prefill_wrapper
         def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):
-            # pruning_pars['attn_logits'] has a BUG when running with llavaHf;
-            # using layer_outputs[1] works for llavaHf, but the accuracy does not match
-            # llava:attn_logits = pruning_pars['attn_logits']
-            # llavahf:attn_logits = layer_outputs[1]
             if 'attn_logits' not in pruning_pars:
                 attn_logits = layer_outputs[1]
             else:
                 attn_logits = pruning_pars['attn_logits']
+            merge_flag = pruning_pars['merge_flag']

             v_token_start = pruning_pars['v_token_start']
             v_token_num = pruning_pars['v_token_num']
             text_token_start = pruning_pars['text_token_start']
@@ -255,7 +260,7 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer
             total_sparse_token_idx = torch.where(policy == 0)[1].unsqueeze(0)

             # merge and cluster
-            if s_flag and total_sparse_token_idx.shape[1] > 0:
+            if s_flag and merge_flag and total_sparse_token_idx.shape[1] > 0:
                 total_sparse_token = batch_index_select(layer_outputs[0], total_sparse_token_idx)

                 merge_token_idx_stage1 = torch.where(pred_score_vis == 0)[1]
@@ -359,6 +364,14 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                 )
             elif self.model.__class__.__name__ == 'Llava':
                 self.blocks[block_idx].self_attn.register_forward_pre_hook(
+                    functools.partial(
+                        update_kwargs_hook,
+                        pruning_pars=self.pruning_paras,
+                        layer_idx=block_idx,
+                    ),
+                    with_kwargs=True
+                )
+                self.blocks[block_idx].self_attn.register_forward_hook(
                     functools.partial(
                         get_attn_logits_hook,
                         pruning_pars=self.pruning_paras,
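Note on the sparsevlm.py hook changes: get_attn_logits_hook is now registered as a forward hook (with update_kwargs_hook as the forward pre-hook), so it runs after the decoder layer and takes past_key_value from layer_outs[2], where the layer has already written this layer's keys and values; reading past_key_value.key_cache[layer_idx] / value_cache[layer_idx] directly replaces the old copy.deepcopy of the cache plus a second update() call. A minimal sketch of the pre-hook/forward-hook pattern with with_kwargs=True (available in recent PyTorch versions) on a toy module; the module and the 'scale' kwarg are illustrative, not LightCompress code:

import functools

import torch
import torch.nn as nn


def pre_hook(module, args, kwargs, tag):
    # Runs before forward; may rewrite kwargs (analogous to update_kwargs_hook
    # patching position_ids / cache_position / position_embeddings).
    kwargs['scale'] = 2.0
    return args, kwargs


def post_hook(module, args, kwargs, output, tag):
    # Runs after forward; sees the layer's outputs (analogous to
    # get_attn_logits_hook reading layer_outs) and may return them modified.
    print(tag, output.sum().item())
    return output


class Toy(nn.Module):
    def forward(self, x, scale=1.0):
        return x * scale


layer = Toy()
layer.register_forward_pre_hook(functools.partial(pre_hook, tag='pre'), with_kwargs=True)
layer.register_forward_hook(functools.partial(post_hook, tag='post'), with_kwargs=True)
layer(torch.ones(2, 3), scale=1.0)  # prints: post 12.0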