fix bugs for dart fasyv sparsevlm and update ci #412
File 1 — GPTQ CI config for opt-125m (YAML). Every hardcoded runner path is updated from the old llmc checkout location to the renamed LightCompress repository.

@@ -2,13 +2,13 @@ base:
     seed: &seed 0
 model:
     type: Opt
-    path: /home/runner/work/llmc/llmc/ci_check/opt-125m
+    path: /home/runner/work/LightCompress/LightCompress/ci_check/opt-125m
     torch_dtype: auto
 calib:
     name: wikitext2
     download: False
     n_samples: 4
-    path: /home/runner/work/llmc/llmc/check/datasets/eval/wikitext2
+    path: /home/runner/work/LightCompress/LightCompress/check/datasets/eval/wikitext2
     bs: 1
     seq_len: 16
     preproc: wikitext2_gptq

@@ -17,7 +17,7 @@ eval:
     eval_pos: [fake_quant]
     name: wikitext2
     download: False
-    path: /home/runner/work/llmc/llmc/check/datasets/eval/wikitext2
+    path: /home/runner/work/LightCompress/LightCompress/check/datasets/eval/wikitext2
     bs: 1
     seq_len: 16
     inference_per_block: False

@@ -40,4 +40,4 @@ quant:
     quant_out: True
 save:
     save_fake: False
-    save_path: /home/runner/work/llmc/llmc/save/opt-125m_gptq_w4a16
+    save_path: /home/runner/work/LightCompress/LightCompress/save/opt-125m_gptq_w4a16
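On GitHub-hosted runners the repository is checked out under /home/runner/work/&lt;repo&gt;/&lt;repo&gt;, so the rename from llmc to LightCompress changes every absolute path baked into these configs. A small sketch of how the same locations could be derived from the GITHUB_WORKSPACE variable instead of being spelled out (the fallback value is only illustrative):

```python
import os

# GITHUB_WORKSPACE is set by GitHub Actions to /home/runner/work/<repo>/<repo>.
# Deriving the paths from it would survive another repository rename.
workspace = os.environ.get(
    'GITHUB_WORKSPACE', '/home/runner/work/LightCompress/LightCompress'
)
print(os.path.join(workspace, 'ci_check', 'opt-125m'))
print(os.path.join(workspace, 'check', 'datasets', 'eval', 'wikitext2'))
print(os.path.join(workspace, 'save', 'opt-125m_gptq_w4a16'))
```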
File 2 — SparseVLM token-reduction config (YAML). A new merge_flag option is introduced and disabled by default.

@@ -19,6 +19,7 @@ sparse:
     pruning_loc: [2] # [2, 6, 15]
     retained_tokens: 192
     init_token_total_shape: 668
+    merge_flag: False
 save:
     save_trans: False
     save_fake: False
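merge_flag is read back in the SparseVLM module (see the decoder_attn_hook hunk further down) and gates the merge-and-cluster branch for pruned visual tokens. A minimal sketch of the control flow, using dummy stand-ins for the real tensors:

```python
import torch

pruning_pars = {'merge_flag': False}       # mirrors `merge_flag: False` in the config
policy = torch.tensor([[1, 0, 1, 0, 1]])   # 1 = keep the token, 0 = prune it
s_flag = True                              # sparsification active for this layer

total_sparse_token_idx = torch.where(policy == 0)[1].unsqueeze(0)
if s_flag and pruning_pars['merge_flag'] and total_sparse_token_idx.shape[1] > 0:
    print('pruned tokens would be merged/clustered back into the kept set')
else:
    print('pruned tokens are simply dropped')
```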
File 3 — token-reduction hook module (Python).

@@ -44,7 +44,7 @@ def wrapper(self, *args, **kwargs):
         token_indices = (
             input_ids[0][attention_mask[0]] == pruning_paras['IMAGE_TOKEN_INDEX']
         )
-        pruning_paras['image_token_start_index'] = torch.where(token_indices)[0].item()
+        pruning_paras['image_token_start_index'] = torch.where(token_indices)[0][0].item()

         outputs = fn(*args, **kwargs)
         return outputs
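The extra [0] is the actual bug fix: torch.where returns a tuple of index tensors, and .item() only succeeds on a single-element tensor, so the old line crashed as soon as the mask matched more than one position. Indexing the first element keeps only the first image-token position. A standalone demonstration:

```python
import torch

IMAGE_TOKEN_INDEX = -200
input_ids = torch.tensor([1, 2, IMAGE_TOKEN_INDEX, IMAGE_TOKEN_INDEX, 5])

token_indices = input_ids == IMAGE_TOKEN_INDEX
matches = torch.where(token_indices)[0]      # tensor([2, 3])

try:
    matches.item()                           # old code: fails for multiple matches
except RuntimeError as err:
    print('old behaviour:', err)

image_token_start_index = matches[0].item()  # fixed code: first match as a Python int
print('new behaviour:', image_token_start_index)
```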
@@ -67,7 +67,7 @@ def get_any_states_hook(module, args, kwargs, layer_outs, pruning_paras, layer_i
     hidden_states = kwargs['hidden_states']
     position_embeddings = kwargs['position_embeddings']
     position_ids = kwargs['position_ids']
-    past_key_value = kwargs['past_key_value']
+    past_key_value = layer_outs[2]

     bsz, q_len, _ = hidden_states.size()
     query_states = module.q_proj(hidden_states)
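get_any_states_hook runs as a forward hook, so it also sees the layer outputs (layer_outs). Reading the cache from the output tuple avoids relying on kwargs['past_key_value'], which the decoder layer call may not populate the way the old code expected. A sketch of the assumed output layout; index 2 matches what the fixed code uses, but the exact tuple contents depend on the transformers version and on the output_attentions / use_cache flags:

```python
import torch

# Toy stand-in for a decoder layer's return value when attention weights and the
# updated cache are requested: (hidden_states, attn_weights, present_key_value).
hidden_states = torch.randn(1, 5, 8)
attn_weights = torch.randn(1, 2, 5, 5)
present_key_value = object()        # placeholder for the real cache object

layer_outs = (hidden_states, attn_weights, present_key_value)

# What the fixed hook does: take the freshly written cache from the outputs.
past_key_value = layer_outs[2]
```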
@@ -193,10 +193,8 @@ def get_retained_image_token(pruning_paras, last_layer_state, any_states):
     ) // (pivot_image_token + pivot_text_token))
     device = last_layer_state.device

-    any_states = (
-        any_states.permute(0, 2, 1, 3)
-        .reshape(any_states.shape[0], any_states.shape[1], -1)
-    )
+    any_states = any_states.permute(0, 2, 1, 3)
+    any_states = any_states.reshape(any_states.shape[0], any_states.shape[1], -1)

     k_states_image_token = any_states[0][
         image_token_start_index:image_token_start_index + image_token_length, :
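Splitting the chained call into two statements changes behaviour; it is not a style edit. Inside the chained expression, any_states.shape[1] still refers to the tensor before the permute (the head count), so the reshape collapsed the wrong axes. Assuming the usual (batch, heads, seq_len, head_dim) key/value layout, the difference looks like this:

```python
import torch

B, H, L, D = 1, 8, 16, 64
any_states = torch.randn(B, H, L, D)   # (batch, heads, seq_len, head_dim)

# Old, chained form: shape[0] / shape[1] are read from the ORIGINAL tensor,
# so the target shape is (B, H, -1) and heads/sequence get mixed up.
old = any_states.permute(0, 2, 1, 3).reshape(any_states.shape[0], any_states.shape[1], -1)
print(old.shape)   # torch.Size([1, 8, 1024])

# Fixed, two-step form: shape[1] is read AFTER the permute, giving (B, L, H*D),
# which is what the per-token slicing right below the hunk expects.
new = any_states.permute(0, 2, 1, 3)
new = new.reshape(new.shape[0], new.shape[1], -1)
print(new.shape)   # torch.Size([1, 16, 512])
```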
File 4 — a second token-reduction hook module (Python), receiving the same first-image-token fix.

@@ -52,7 +52,7 @@ def wrapper(self, *args, **kwargs):
         attention_mask = args[2]
         token_indices = \
             input_ids[0][attention_mask[0]] == pruning_paras['IMAGE_TOKEN_INDEX']
-        pruning_paras['image_token_start_index'] = torch.where(token_indices)[0].item()
+        pruning_paras['image_token_start_index'] = torch.where(token_indices)[0][0].item()

         outputs = fn(*args, **kwargs)
         return outputs
File 5 — the SparseVLM token-reduction module (Python).

@@ -12,6 +12,8 @@
 from .token_reduction_module import TokenReductionModule
 from .utils import prefill_wrapper, prefill_wrapper_model

+layer_dict = {}
+

 @TOKEN_REDUCTION_REGISTRY.register('SparseVLM')
 class SparseVLM(TokenReductionModule):
@@ -24,6 +26,8 @@ def add_sparse_config(self):
         special_config = self.config.get('special', {})

         self.pruning_loc = special_config.get('pruning_loc', [2, 6, 15])
+        global layer_dict
+        layer_dict = {layer: idx for idx, layer in enumerate(self.pruning_loc)}
         special_config['retained_tokens'] = special_config.get('retained_tokens', 192)
         special_config['init_token_total_shape'] = special_config.get('init_token_total_shape', 668)
         special_config['generate_process_count'] = 0
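layer_dict maps each pruning location to its position in the pruning schedule, presumably so that hooks bound to a given decoder layer can look up which pruning stage they belong to. For the default schedule the mapping is:

```python
pruning_loc = [2, 6, 15]   # decoder layers at which SparseVLM prunes visual tokens
layer_dict = {layer: idx for idx, layer in enumerate(pruning_loc)}

print(layer_dict)          # {2: 0, 6: 1, 15: 2}
print(layer_dict[6])       # layer 6 is the second pruning stage -> 1
```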
@@ -44,7 +48,8 @@ def input_hook(module, input_args, pruning_pars):
         # find the position of the first image token
         for seq in input_ids:
             image_token_index = (
-                seq == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0]
+                seq == IMAGE_TOKEN_INDEX
+            ).nonzero(as_tuple=True)[0]
             if len(image_token_index) > 0:
                 pre_prompt_length_list.append(image_token_index[0].item())
             else:
@@ -95,33 +100,31 @@ def wrapper(self, *args, **kwargs):
     @prefill_wrapper_model
     def register_module_pars(module, args, kwargs, pruning_pars):
         pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
-        inputs_embeds = kwargs['inputs_embeds']
-        if inputs_embeds is None:
-            inputs_embeds = module.embed_tokens(kwargs['input_ids'])
-        hidden_states = inputs_embeds  # shape: (B, L, C)
+        hidden_states = kwargs['inputs_embeds']
+        if hidden_states is None:
+            hidden_states = module.embed_tokens(kwargs['input_ids'])

         B, L, _ = hidden_states.shape
         pruning_pars['B'] = B
         init_n = pruning_pars['init_token_total_shape'] + \
             pruning_pars['generate_process_count']  # 668
         pruning_pars['prev_decision'] = torch.ones(
             B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)
         pruning_pars['policy'] = torch.ones(
             B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)

-        pruning_pars['v_token_start'] = pre_prompt_length_list[0] if len(
-            pre_prompt_length_list) != 0 else 0  # 35
-        v_token_start = pruning_pars['v_token_start']
-        pruning_pars['text_token_start'] = pruning_pars['v_token_start'] + \
-            pruning_pars['image_shape']  # 35 + 576 = 611
-        text_token_start = pruning_pars['text_token_start']
+        v_token_start = pre_prompt_length_list[0] if len(
+            pre_prompt_length_list) != 0 else 0
+        text_token_start = v_token_start + pruning_pars['image_shape']
+        pruning_pars['v_token_start'] = v_token_start  # 35
+        pruning_pars['text_token_start'] = text_token_start  # 611
         pruning_pars['v_token_num'] = pruning_pars['image_shape']  # 576

         if (len(pre_prompt_length_list) != 0 and hidden_states.shape[1] != 1):
             v_t = hidden_states[:, v_token_start: text_token_start, :]
             t_t = hidden_states[:, text_token_start:, :]
-            m_v_t = v_t @ t_t.transpose(1, 2)  # [1, 576, 53]  # 52?
-            m_v_t = m_v_t.softmax(2).mean(1)  # [1, 53]
+            m_v_t = v_t @ t_t.transpose(1, 2)  # [1, 576, 52]
+            m_v_t = m_v_t.softmax(2).mean(1)  # [1, 52]
             pruning_pars['t_token_idx'] = torch.where(m_v_t > m_v_t.mean())

         return args, kwargs
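The reordered bookkeeping keeps the same arithmetic but computes the local variables first and then publishes them into pruning_pars. With the LLaVA-style numbers hinted at by the inline comments (a 35-token pre-prompt and 576 visual tokens), the layout and the text-relevance score work out as below; the tensor sizes are illustrative:

```python
import torch

# Assumed example numbers from the inline comments: 35 prompt tokens, 576 image tokens.
pre_prompt_length_list = [35]
image_shape = 576

v_token_start = pre_prompt_length_list[0] if len(pre_prompt_length_list) != 0 else 0
text_token_start = v_token_start + image_shape
print(v_token_start, text_token_start)        # 35 611

# Text-relevance score: every visual token is compared (plain dot product) with
# every trailing text token, and text tokens above the mean score are kept as guides.
hidden_states = torch.randn(1, 663, 32)       # (B, L, C); 663 = 35 + 576 + 52 text tokens
v_t = hidden_states[:, v_token_start:text_token_start, :]   # (1, 576, C)
t_t = hidden_states[:, text_token_start:, :]                # (1, 52, C)
m_v_t = v_t @ t_t.transpose(1, 2)             # (1, 576, 52)
m_v_t = m_v_t.softmax(2).mean(1)              # (1, 52)
t_token_idx = torch.where(m_v_t > m_v_t.mean())
print(m_v_t.shape, t_token_idx[1].shape)
```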
@@ -134,10 +137,20 @@ def update_output_attentions_hook(module, args, kwargs, pruning_pars, layer_idx)
         kwargs['position_embeddings'] = pruning_pars['position_embeddings']
         return args, kwargs

-    def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):
+    def update_kwargs_hook(module, args, kwargs, pruning_pars, layer_idx):

         if len(kwargs['position_ids'][0]) == 1:
             return args, kwargs
+        if layer_idx != self.pruning_loc[0]:
+            kwargs['position_ids'] = pruning_pars['position_ids']
+            kwargs['cache_position'] = pruning_pars['cache_position']
+            kwargs['position_embeddings'] = pruning_pars['position_embeddings']
+        return args, kwargs
+
+    def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_idx):
+
+        if len(kwargs['position_ids'][0]) == 1:
+            return layer_outs

         from transformers.models.llama.modeling_llama import \
             apply_rotary_pos_emb
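The old get_attn_logits_hook was written with pre-hook semantics (returning args and kwargs) even though it needs the layer's outputs; the rework splits it into update_kwargs_hook, a true forward pre-hook that rewrites the call arguments, and a forward-hook version of get_attn_logits_hook that receives and returns layer_outs. A minimal sketch of the two PyTorch hook conventions, with a dummy module:

```python
import torch
from torch import nn

class Dummy(nn.Module):
    def forward(self, x, position_ids=None):
        return x + 1

def pre_hook(module, args, kwargs):
    # Forward pre-hook with with_kwargs=True: return (args, kwargs) to rewrite
    # the inputs before the layer runs (what update_kwargs_hook does).
    kwargs['position_ids'] = torch.arange(args[0].shape[-1])
    return args, kwargs

def post_hook(module, args, kwargs, output):
    # Forward hook with with_kwargs=True: receives the output and may return a
    # replacement (what the reworked get_attn_logits_hook does with layer_outs).
    print('output seen by hook:', output.shape)
    return output

module = Dummy()
module.register_forward_pre_hook(pre_hook, with_kwargs=True)
module.register_forward_hook(post_hook, with_kwargs=True)
module(torch.zeros(2, 4), position_ids=None)
```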
@@ -150,8 +163,7 @@ def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):
         hidden_states = kwargs['hidden_states']
         position_embeddings = kwargs['position_embeddings']
         position_ids = kwargs['position_ids']
-        past_key_value = kwargs['past_key_value']
-        cache_position = kwargs['cache_position']
+        past_key_value = layer_outs[2]
         attention_mask = kwargs['attention_mask']

         t_token_idx = pruning_pars['t_token_idx']
@@ -179,12 +191,8 @@ def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):

         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-        if past_key_value is not None:
-            temp_cache = copy.deepcopy(past_key_value)
-            cache_kwargs = {'sin': sin, 'cos': cos, 'cache_position': cache_position}
-            key_states, value_states = temp_cache.update(
-                key_states, value_states,
-                layer_idx, cache_kwargs
-            )
+        key_states = past_key_value.key_cache[layer_idx]
+        value_states = past_key_value.value_cache[layer_idx]
         t_token_idx = t_token_idx[1] + v_token_start + v_token_num
         L, S = query_states.size(-2), key_states.size(-2)

         scale_factor = 1 / math.sqrt(query_states.size(-1))
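Rather than deep-copying the cache and replaying Cache.update just to get the current keys and values, the fixed hook reads the tensors the attention layer has already written. In the transformers versions this targets, the cache object (e.g. DynamicCache) keeps per-layer tensors in key_cache / value_cache lists, so the lookup is a plain index; a small sketch against that API:

```python
import torch
from transformers.cache_utils import DynamicCache

# Build a tiny cache holding one layer's keys/values.
cache = DynamicCache()
key = torch.randn(1, 2, 5, 16)     # (batch, kv_heads, seq_len, head_dim)
value = torch.randn(1, 2, 5, 16)
cache.update(key, value, layer_idx=0)

# What the fixed get_attn_logits_hook does: read the existing entries instead of
# copying the cache and updating it a second time.
key_states = cache.key_cache[0]
value_states = cache.value_cache[0]
print(key_states.shape, value_states.shape)
```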
@@ -201,19 +209,16 @@ def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):
         pruning_pars['attn_logits'] = attn_logits

-        return args, kwargs
+        return layer_outs

     @prefill_wrapper
     def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):

-        # pruning_pars['attn_logits'] has a BUG when running with llavaHf;
-        # using layer_outputs[1] runs llavaHf without problems, but the accuracy does not match
-        # llava: attn_logits = pruning_pars['attn_logits']
-        # llavahf: attn_logits = layer_outputs[1]
         if 'attn_logits' not in pruning_pars:
             attn_logits = layer_outputs[1]
         else:
             attn_logits = pruning_pars['attn_logits']
+        merge_flag = pruning_pars['merge_flag']
         v_token_start = pruning_pars['v_token_start']
         v_token_num = pruning_pars['v_token_num']
         text_token_start = pruning_pars['text_token_start']
@@ -255,7 +260,7 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer
         total_sparse_token_idx = torch.where(policy == 0)[1].unsqueeze(0)

         # merge and cluster
-        if s_flag and total_sparse_token_idx.shape[1] > 0:
+        if s_flag and merge_flag and total_sparse_token_idx.shape[1] > 0:
             total_sparse_token = batch_index_select(layer_outputs[0], total_sparse_token_idx)

             merge_token_idx_stage1 = torch.where(pred_score_vis == 0)[1]
@@ -359,6 +364,14 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
             )
         elif self.model.__class__.__name__ == 'Llava':
+            self.blocks[block_idx].self_attn.register_forward_pre_hook(
+                functools.partial(
+                    update_kwargs_hook,
+                    pruning_pars=self.pruning_paras,
+                    layer_idx=block_idx,
+                ),
+                with_kwargs=True
+            )
             self.blocks[block_idx].self_attn.register_forward_hook(
                 functools.partial(
                     get_attn_logits_hook,
                     pruning_pars=self.pruning_paras,
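For the Llava path the code now installs two hooks per block: update_kwargs_hook as a forward pre-hook and get_attn_logits_hook as a forward hook. Binding the block index with functools.partial lets a single hook function serve every decoder layer while still knowing where it runs. A stripped-down sketch of the registration pattern (ToyBlock stands in for the real self_attn modules):

```python
import functools
import torch
from torch import nn

class ToyBlock(nn.Module):
    def forward(self, x):
        return x

def update_kwargs_hook(module, args, kwargs, pruning_pars, layer_idx):
    print(f'pre-hook running on layer {layer_idx}')
    return args, kwargs

pruning_paras = {}
blocks = nn.ModuleList([ToyBlock() for _ in range(3)])

for block_idx, block in enumerate(blocks):
    block.register_forward_pre_hook(
        functools.partial(
            update_kwargs_hook,
            pruning_pars=pruning_paras,
            layer_idx=block_idx,
        ),
        with_kwargs=True,
    )

for block in blocks:
    block(torch.zeros(1))
```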
Review comment: The file path is hardcoded. Consider using an environment variable that can be substituted by the CI system to improve portability.
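A sketch of one way to follow this suggestion: keep a placeholder in the YAML and expand it from the environment before handing the config to the tool. The ${WORKSPACE} placeholder and the config line are illustrative, not part of this PR.

```python
import os

# Hypothetical config line with a placeholder instead of a hardcoded runner path.
raw_line = 'path: ${WORKSPACE}/ci_check/opt-125m'

# GITHUB_WORKSPACE is provided by GitHub Actions; fall back to the cwd locally.
workspace = os.environ.get('GITHUB_WORKSPACE', os.getcwd())
expanded = raw_line.replace('${WORKSPACE}', workspace)
print(expanded)
```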