diff --git a/configs/sparsification/methods/SparseVLM/sparsevlm_multi_turn.yml b/configs/sparsification/methods/SparseVLM/sparsevlm_multi_turn.yml
new file mode 100644
index 00000000..2fc436d7
--- /dev/null
+++ b/configs/sparsification/methods/SparseVLM/sparsevlm_multi_turn.yml
@@ -0,0 +1,29 @@
+base:
+    seed: &seed 42
+model:
+    type: Llava
+    path: model path
+    torch_dtype: auto
+eval:
+    eval_pos: [transformed] # transformed
+    name: custom_gen
+    type: just_infer
+    download: False
+    path: /data/nvme1/yongyang/projects/llmc_plus/general_custom_data
+    apply_chat_template: True
+    bs: 1
+    inference_per_block: False
+    max_new_tokens: 512
+    statistics: False
+sparse:
+    method: TokenReduction
+    special:
+        method: SparseVLM
+        pruning_loc: [2, 6, 15]
+        retained_tokens: 192
+        prune_flag: True
+        merge_flag: True
+save:
+    save_trans: False
+    save_fake: False
+    save_path: /path/to/save/
diff --git a/llmc/compression/token_reduction/fastv.py b/llmc/compression/token_reduction/fastv.py
index ecafde5a..8a699845 100644
--- a/llmc/compression/token_reduction/fastv.py
+++ b/llmc/compression/token_reduction/fastv.py
@@ -90,6 +90,12 @@ def fastv_pruning_hook(module, args, kwargs, pruning_paras):
             top_attention_rank_index = \
                 last_layer_attention_avg_last_tok_image.topk(
                     round(image_token_length * (1 - rate))).indices + image_token_start_index
+
+            if self.model.first_turn_question:
+                module.register_buffer('top_attention_rank_index', top_attention_rank_index)
+            else:
+                top_attention_rank_index = module.top_attention_rank_index
+
             # keep index
             keep_indexs = torch.cat(
                 (
diff --git a/llmc/compression/token_reduction/random.py b/llmc/compression/token_reduction/random.py
index 9d71084f..d6dfde1d 100644
--- a/llmc/compression/token_reduction/random.py
+++ b/llmc/compression/token_reduction/random.py
@@ -76,23 +76,22 @@ def random_pruning_hook(module, args, kwargs, pruning_paras):
 
             device = hidden_states.device
 
+            vision_indexes = torch.arange(
+                image_token_start_index,
+                image_token_start_index + image_token_length,
+                device=device,
+            )
             if self.model.first_turn_question:
-                logger.info(' -----first_turn_question-----')
-                vision_indexes = torch.arange(
-                    image_token_start_index,
-                    image_token_start_index + image_token_length,
-                    device=device,
-                )
                 num_keep = round(image_token_length * (1 - rate))
                 rand_idx = torch.randperm(image_token_length, device=device)[:num_keep]
                 vision_indexes = vision_indexes[rand_idx]
 
-                # save vision_indexes to module
-                module.register_buffer('vision_indexes', vision_indexes)
+                # save rand_idx to module
+                module.register_buffer('rand_idx', rand_idx)
             else:
-                logger.info(' -----not first_turn_question-----')
                 # load vision_indexes from module (prompt cache)
-                vision_indexes = module.vision_indexes
+                rand_idx = module.rand_idx
+                vision_indexes = vision_indexes[rand_idx]
 
             # keep index
             keep_indexs = torch.cat(
diff --git a/llmc/compression/token_reduction/sparsevlm.py b/llmc/compression/token_reduction/sparsevlm.py
index 22269cf7..f91c9564 100755
--- a/llmc/compression/token_reduction/sparsevlm.py
+++ b/llmc/compression/token_reduction/sparsevlm.py
@@ -251,6 +251,13 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer
                 text_token_start = prompt_length + image_shape
                 policy[batch, text_token_start:] = 1
 
+            if self.model.first_turn_question:
+                vision_mask = policy[:, v_token_start:v_token_start + v_token_num]
+                module.register_buffer('vision_mask', vision_mask)
+            else:
+                vision_mask = module.vision_mask
+                policy[:, v_token_start:v_token_start + v_token_num] = vision_mask
+
             total_sparse_token_idx = torch.where(policy == 0)[1].unsqueeze(0)
 
             # merge and cluster
diff --git a/llmc/compression/token_reduction/visionzip.py b/llmc/compression/token_reduction/visionzip.py
index f72580a5..5109b478 100755
--- a/llmc/compression/token_reduction/visionzip.py
+++ b/llmc/compression/token_reduction/visionzip.py
@@ -323,6 +323,12 @@ def visionzip_hook(m, images, image_forward_outs):
             mask = torch.ones_like(
                 hidden_states[:, :, 0], dtype=torch.bool, device=metric.device
             ).scatter_(1, all_indices, False)
+
+            if self.model.first_turn_question:
+                m.register_buffer('mask', mask)
+            else:
+                mask = m.mask
+
             dominant_tokens = hidden_states.masked_select(~mask.unsqueeze(-1)).view(
                 hidden_states.shape[0], dominant_num + 1, hidden_states.shape[2]
             )
diff --git a/llmc/models/llava.py b/llmc/models/llava.py
index c1bfbce3..8b74812a 100644
--- a/llmc/models/llava.py
+++ b/llmc/models/llava.py
@@ -75,6 +75,7 @@ def build_model(self):
             'IMAGE_TOKEN_INDEX': IMAGE_TOKEN_INDEX,  # for llava
         }
         self.processor = None
+        self.first_turn_question = True
 
     def get_extra_rot_module_besides_embed_layers(self):
         return [self.vision_projector[2]]
@@ -163,8 +164,6 @@ def load_images(image_files):
                 out.append(image)
             return out
 
-        self.first_turn_question = True
-
         for data_idx, questions in enumerate(img_qas):
             self.first_turn_question = True
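Every hook touched by this patch applies the same trick: on the first turn of a multi-turn conversation the pruning decision (top-k attention indices, random indices, the dominant-token mask, or the vision-token policy) is stored on the hooked module via register_buffer, and on later turns it is read back so the same vision tokens are kept as in the cached prompt. The sketch below illustrates that caching pattern in isolation, loosely following the random-pruning hook; DummyLayer, make_pruning_hook, and the state dict are hypothetical stand-ins (not code from this patch) for the real decoder layer, llmc's hook registration, and the model's first_turn_question flag.

import torch
import torch.nn as nn

# Minimal sketch (not part of the patch above): cache the first turn's pruning
# decision on the hooked module with register_buffer and reuse it on later turns.
# DummyLayer, make_pruning_hook and `state` are illustrative assumptions.


class DummyLayer(nn.Module):
    def forward(self, hidden_states):
        return hidden_states


def make_pruning_hook(image_token_start, image_token_length, rate, state):
    def hook(module, args):
        hidden_states = args[0]
        device = hidden_states.device
        vision_indexes = torch.arange(
            image_token_start, image_token_start + image_token_length, device=device
        )
        if state['first_turn_question']:
            # First turn: sample which vision tokens to keep and cache the choice.
            num_keep = round(image_token_length * (1 - rate))
            rand_idx = torch.randperm(image_token_length, device=device)[:num_keep]
            module.register_buffer('rand_idx', rand_idx)
        else:
            # Later turns: reuse the cached choice so the kept tokens match.
            rand_idx = module.rand_idx
        keep_indexs = torch.cat(
            (
                torch.arange(image_token_start, device=device),  # prompt tokens
                vision_indexes[rand_idx],                        # kept vision tokens
                torch.arange(                                    # text tokens
                    image_token_start + image_token_length,
                    hidden_states.shape[1],
                    device=device,
                ),
            )
        ).sort().values
        return (hidden_states[:, keep_indexs, :],)

    return hook


layer = DummyLayer()
state = {'first_turn_question': True}
layer.register_forward_pre_hook(make_pruning_hook(5, 16, 0.5, state))

x = torch.randn(1, 30, 8)   # 5 prompt + 16 vision + 9 text tokens
first = layer(x)            # first turn: rand_idx is sampled and cached
state['first_turn_question'] = False
second = layer(x)           # later turn: the cached rand_idx is reused
assert first.shape == second.shape == (1, 22, 8)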