29 changes: 29 additions & 0 deletions configs/sparsification/methods/SparseVLM/sparsevlm_multi_turn.yml
@@ -0,0 +1,29 @@
base:
    seed: &seed 42
model:
    type: Llava
    path: model path
    torch_dtype: auto
eval:
    eval_pos: [transformed] # transformed
    name: custom_gen
    type: just_infer
    download: False
    path: /data/nvme1/yongyang/projects/llmc_plus/general_custom_data


Review comment (severity: medium):

The path for the evaluation data is hardcoded to a user-specific absolute path. This reduces the portability of the configuration file and makes it difficult for other developers to run this configuration without modification.

For better reusability, please consider using a placeholder, similar to how model.path and save.save_path are configured in this file (one way to make this concrete is sketched after the config below).

    path: /path/to/general_custom_data

    apply_chat_template: True
    bs: 1
    inference_per_block: False
    max_new_tokens: 512
    statistics: False
sparse:
    method: TokenReduction
    special:
        method: SparseVLM
        pruning_loc: [2, 6, 15]
        retained_tokens: 192
        prune_flag: True
        merge_flag: True
save:
    save_trans: False
    save_fake: False
    save_path: /path/to/save/
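To make the review comment's suggestion concrete: if the config loader expanded environment variables in string values, the hardcoded path could be written as `${DATA_ROOT}/general_custom_data`. A minimal sketch, assuming PyYAML and a hypothetical `DATA_ROOT` variable; llmc's actual config loader may or may not support this:

```python
import os
import yaml

def load_config(path):
    """Load a YAML config, expanding ${VAR}-style environment variables
    in every string value (sketch only, not llmc's real loader)."""
    with open(path) as f:
        cfg = yaml.safe_load(f)

    def expand(node):
        if isinstance(node, dict):
            return {k: expand(v) for k, v in node.items()}
        if isinstance(node, list):
            return [expand(v) for v in node]
        if isinstance(node, str):
            return os.path.expandvars(node)  # leaves unset vars untouched
        return node

    return expand(cfg)

# eval.path could then read: ${DATA_ROOT}/general_custom_data
cfg = load_config('configs/sparsification/methods/SparseVLM/sparsevlm_multi_turn.yml')
```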
6 changes: 6 additions & 0 deletions llmc/compression/token_reduction/fastv.py
@@ -90,6 +90,12 @@ def fastv_pruning_hook(module, args, kwargs, pruning_paras):
        top_attention_rank_index = \
            last_layer_attention_avg_last_tok_image.topk(
                round(image_token_length * (1 - rate))).indices + image_token_start_index

        if self.model.first_turn_question:
            module.register_buffer('top_attention_rank_index', top_attention_rank_index)
        else:
            top_attention_rank_index = module.top_attention_rank_index

        # keep index
        keep_indexs = torch.cat(
            (
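The pattern here, repeated in the other hooks below: on the first turn of a conversation the FastV hook ranks image tokens by last-layer attention and stashes the surviving indices on the module via `register_buffer`; every later turn about the same image replays the cached indices, so the identical tokens are kept. A stripped-down sketch of the mechanism, with illustrative names rather than llmc's API:

```python
import torch
import torch.nn as nn

class CachedTopK(nn.Module):
    """Toy first-turn cache: compute a top-k selection once, replay it later."""

    def forward(self, scores, k, first_turn):
        if first_turn:
            idx = scores.topk(k).indices
            self.register_buffer('cached_idx', idx)  # persists on the module
        else:
            idx = self.cached_idx                    # replay turn-1 decision
        return idx

m = CachedTopK()
scores = torch.rand(100)
first = m(scores, k=10, first_turn=True)    # computes and caches
second = m(scores, k=10, first_turn=False)  # identical indices, no recompute
assert torch.equal(first, second)
```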
19 changes: 9 additions & 10 deletions llmc/compression/token_reduction/random.py
@@ -76,23 +76,22 @@ def random_pruning_hook(module, args, kwargs, pruning_paras):

        device = hidden_states.device

        vision_indexes = torch.arange(
            image_token_start_index,
            image_token_start_index + image_token_length,
            device=device,
        )
        if self.model.first_turn_question:
            logger.info(' -----first_turn_question-----')
            vision_indexes = torch.arange(
                image_token_start_index,
                image_token_start_index + image_token_length,
                device=device,
            )
            num_keep = round(image_token_length * (1 - rate))
            rand_idx = torch.randperm(image_token_length, device=device)[:num_keep]
            vision_indexes = vision_indexes[rand_idx]

            # save vision_indexes to module
            module.register_buffer('vision_indexes', vision_indexes)
            # save rand_idx to module
            module.register_buffer('rand_idx', rand_idx)
        else:
            logger.info(' -----not first_turn_question-----')
            # load vision_indexes from module (prompt cache)
            vision_indexes = module.vision_indexes
            rand_idx = module.rand_idx
            vision_indexes = vision_indexes[rand_idx]

        # keep index
        keep_indexs = torch.cat(
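Note that this diff switches the cached object from the absolute `vision_indexes` to the relative `rand_idx`. One plausible motivation (my reading; the PR does not state it): absolute token positions can shift between turns as the chat history grows, while a permutation over the image span stays valid wherever the span lands. A small sketch of that distinction:

```python
import torch

torch.manual_seed(0)
image_token_length, rate = 8, 0.5
num_keep = round(image_token_length * (1 - rate))
rand_idx = torch.randperm(image_token_length)[:num_keep]  # cached on turn 1

# The vision span may start at a different offset on a later turn.
for image_token_start_index in (5, 9):
    vision_indexes = torch.arange(
        image_token_start_index,
        image_token_start_index + image_token_length,
    )
    # Same relative subset of vision tokens, correct absolute positions.
    print(vision_indexes[rand_idx])
```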
7 changes: 7 additions & 0 deletions llmc/compression/token_reduction/sparsevlm.py
@@ -251,6 +251,13 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer
            text_token_start = prompt_length + image_shape
            policy[batch, text_token_start:] = 1

        if self.model.first_turn_question:
            vision_mask = policy[:, v_token_start:v_token_start + v_token_num]
            module.register_buffer('vision_mask', vision_mask)
        else:
            vision_mask = module.vision_mask
            policy[:, v_token_start:v_token_start + v_token_num] = vision_mask

        total_sparse_token_idx = torch.where(policy == 0)[1].unsqueeze(0)

        # merge and cluster
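Same idea for SparseVLM, except the cached object is a keep/drop mask over the vision span of the `policy` tensor rather than an index list: later turns overwrite their freshly computed vision span with the turn-1 mask before the sparse token indices are derived. An illustrative sketch with toy shapes (llmc's real policy construction differs):

```python
import torch

v_token_start, v_token_num, seq_len = 4, 6, 16

# Turn 1: some scoring decides which vision tokens to drop (1 = keep, 0 = prune).
turn1 = torch.ones(1, seq_len)
turn1[0, v_token_start:v_token_start + v_token_num] = torch.tensor(
    [1., 0., 1., 0., 0., 1.]
)
vision_mask = turn1[:, v_token_start:v_token_start + v_token_num].clone()  # cached

# Later turn: a fresh policy gets its vision span overwritten with the cache.
turn2 = torch.ones(1, seq_len)
turn2[:, v_token_start:v_token_start + v_token_num] = vision_mask
sparse_idx = torch.where(turn2 == 0)[1].unsqueeze(0)  # same tokens pruned as turn 1
print(sparse_idx)
```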
6 changes: 6 additions & 0 deletions llmc/compression/token_reduction/visionzip.py
@@ -323,6 +323,12 @@ def visionzip_hook(m, images, image_forward_outs):
        mask = torch.ones_like(
            hidden_states[:, :, 0], dtype=torch.bool, device=metric.device
        ).scatter_(1, all_indices, False)

        if self.model.first_turn_question:
            m.register_buffer('mask', mask)
        else:
            mask = m.mask

        dominant_tokens = hidden_states.masked_select(~mask.unsqueeze(-1)).view(
            hidden_states.shape[0], dominant_num + 1, hidden_states.shape[2]
        )
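For VisionZip the cached buffer is the boolean drop-mask consumed by `masked_select`. Since this hook runs on the vision tower's output, the sequence length is fixed per image, so a turn-1 mask stays shape-compatible on later turns about the same image (my inference, not stated in the PR). A toy sketch of the select-and-reshape step:

```python
import torch

B, N, C, keep = 1, 6, 4, 3
hidden_states = torch.randn(B, N, C)
mask = torch.ones(B, N, dtype=torch.bool)  # True = drop, False = keep
mask[:, :keep] = False                     # pretend the first 3 tokens are dominant

# ~mask broadcasts over the channel dim; view restores (batch, kept, channels).
dominant = hidden_states.masked_select(~mask.unsqueeze(-1)).view(B, keep, C)
print(dominant.shape)  # torch.Size([1, 3, 4])
```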
3 changes: 1 addition & 2 deletions llmc/models/llava.py
@@ -75,6 +75,7 @@ def build_model(self):
            'IMAGE_TOKEN_INDEX': IMAGE_TOKEN_INDEX,  # for llava
        }
        self.processor = None
        self.first_turn_question = True

    def get_extra_rot_module_besides_embed_layers(self):
        return [self.vision_projector[2]]
@@ -163,8 +164,6 @@ def load_images(image_files):
                out.append(image)
            return out

        self.first_turn_question = True

        for data_idx, questions in enumerate(img_qas):
            self.first_turn_question = True

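Taken together, the llava.py changes define the flag's lifecycle: `first_turn_question` is initialized once in `build_model` and reset to `True` at the top of each sample's question loop, so the one-time initialization before the loop becomes redundant and is removed. The flip to `False` after the first question happens outside the hunks shown here, so it is assumed in this toy sketch of the loop:

```python
# Toy lifecycle of the multi-turn flag; names mirror the diff, logic is illustrative.
class ToyModel:
    def __init__(self):
        self.first_turn_question = True  # mirrors build_model()

toy = ToyModel()
img_qas = [['Q1', 'Q2', 'Q3'], ['Q1', 'Q2']]  # one question list per image
for data_idx, questions in enumerate(img_qas):
    toy.first_turn_question = True  # reset for each new image/sample
    for q in questions:
        mode = 'compute + cache' if toy.first_turn_question else 'reuse cache'
        print(f'image {data_idx}, {q}: {mode}')
        toy.first_turn_question = False  # assumed: flip after the first turn
```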