8 changes: 4 additions & 4 deletions .github/workflows/main.yml
@@ -30,7 +30,7 @@ jobs:

- name: Download dataset
run: |
# pwd # /home/runner/work/llmc/llmc
# pwd # /home/runner/work/LightCompress/LightCompress
cd tools
python download_calib_dataset.py --save_path ../check/datasets/calib --dataset_name pileval
python download_eval_dataset.py --save_path ../check/datasets/eval --dataset_name wikitext2
@@ -46,17 +46,17 @@ jobs:

- name: Preparation for check.
run: |
cd ci_check # /home/runner/work/llmc/llmc/ci_check
cd ci_check # /home/runner/work/LightCompress/LightCompress/ci_check
python change_files.py

- name: Run awq check
run: |
cd ci_check # /home/runner/work/llmc/llmc/ci_check
cd ci_check # /home/runner/work/LightCompress/LightCompress/ci_check
bash run_awq.sh

- name: Run gptq check
run: |
cd ci_check # /home/runner/work/llmc/llmc/ci_check
cd ci_check # /home/runner/work/LightCompress/LightCompress/ci_check
bash run_gptq.sh

- name: Check success
8 changes: 4 additions & 4 deletions ci_check/awq_w4a16_fakequant_eval.yml
@@ -2,12 +2,12 @@ base:
seed: &seed 42
model:
type: Opt
path: /home/runner/work/llmc/llmc/ci_check/opt-125m
path: /home/runner/work/LightCompress/LightCompress/ci_check/opt-125m

Review comment (medium): The file path is hardcoded. Consider using an environment variable that can be substituted by the CI system to improve portability.
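A minimal sketch of one way to do this, assuming the existing ci_check/change_files.py step (or a similar helper) is allowed to rewrite the configs before the checks run; LIGHTCOMPRESS_ROOT is a hypothetical variable the workflow would export, with GITHUB_WORKSPACE (set by GitHub Actions) as a fallback:

import os
from pathlib import Path

# Replace the hardcoded runner path with whatever workspace the CI system provides.
root = os.environ.get('LIGHTCOMPRESS_ROOT') or os.environ.get('GITHUB_WORKSPACE', '.')
configs_dir = Path(__file__).resolve().parent  # assumes the helper lives in ci_check/
for cfg in configs_dir.glob('*.yml'):
    text = cfg.read_text()
    cfg.write_text(text.replace('/home/runner/work/LightCompress/LightCompress', root))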

torch_dtype: auto
calib:
name: pileval
download: False
path: /home/runner/work/llmc/llmc/check/datasets/calib/pileval
path: /home/runner/work/LightCompress/LightCompress/check/datasets/calib/pileval
n_samples: 4 # 128
bs: -1
seq_len: 16 # 512
@@ -17,7 +17,7 @@ eval:
eval_pos: [pretrain, transformed, fake_quant]
name: wikitext2
download: False
path: /home/runner/work/llmc/llmc/check/datasets/eval/wikitext2
path: /home/runner/work/LightCompress/LightCompress/check/datasets/eval/wikitext2
bs: 1
seq_len: 16 # 2048
eval_token_consist: True
@@ -35,4 +35,4 @@ quant:
clip_sym: False
save:
save_trans: False
save_path: /home/runner/work/llmc/llmc/save/opt-125m_awq_w4a16
save_path: /home/runner/work/LightCompress/LightCompress/save/opt-125m_awq_w4a16
8 changes: 4 additions & 4 deletions ci_check/gptq_w_only.yml
@@ -2,13 +2,13 @@ base:
seed: &seed 0
model:
type: Opt
path: /home/runner/work/llmc/llmc/ci_check/opt-125m
path: /home/runner/work/LightCompress/LightCompress/ci_check/opt-125m

Review comment (medium): The file path is hardcoded. Consider using an environment variable that can be substituted by the CI system to improve portability.

torch_dtype: auto
calib:
name: wikitext2
download: False
n_samples: 4
path: /home/runner/work/llmc/llmc/check/datasets/eval/wikitext2
path: /home/runner/work/LightCompress/LightCompress/check/datasets/eval/wikitext2
bs: 1
seq_len: 16
preproc: wikitext2_gptq
@@ -17,7 +17,7 @@ eval:
eval_pos: [fake_quant]
name: wikitext2
download: False
path: /home/runner/work/llmc/llmc/check/datasets/eval/wikitext2
path: /home/runner/work/LightCompress/LightCompress/check/datasets/eval/wikitext2
bs: 1
seq_len: 16
inference_per_block: False
@@ -40,4 +40,4 @@ quant:
quant_out: True
save:
save_fake: False
save_path: /home/runner/work/llmc/llmc/save/opt-125m_gptq_w4a16
save_path: /home/runner/work/LightCompress/LightCompress/save/opt-125m_gptq_w4a16
1 change: 1 addition & 0 deletions configs/sparsification/methods/SparseVLM/sparsevlm.yml
@@ -19,6 +19,7 @@ sparse:
pruning_loc: [2] # [2, 6, 15]
retained_tokens: 192
init_token_total_shape: 668
merge_flag: False

Review comment (medium): Consider the implications of setting merge_flag to False. Ensure this aligns with the intended behavior of disabling token merging, as it might impact performance or accuracy.

save:
save_trans: False
save_fake: False
10 changes: 4 additions & 6 deletions llmc/compression/token_reduction/dart.py
@@ -44,7 +44,7 @@ def wrapper(self, *args, **kwargs):
token_indices = (
input_ids[0][attention_mask[0]] == pruning_paras['IMAGE_TOKEN_INDEX']
)
pruning_paras['image_token_start_index'] = torch.where(token_indices)[0].item()
pruning_paras['image_token_start_index'] = torch.where(token_indices)[0][0].item()

Review comment (high): Ensure that torch.where(token_indices)[0] contains at least one element before indexing it with [0]. Otherwise, this could raise an IndexError if no image tokens are found. A guarded form (using a hypothetical helper variable) could look like:

image_token_positions = torch.where(token_indices)[0]
pruning_paras['image_token_start_index'] = (
    image_token_positions[0].item() if image_token_positions.numel() > 0 else 0
)


outputs = fn(*args, **kwargs)
return outputs
@@ -67,7 +67,7 @@ def get_any_states_hook(module, args, kwargs, layer_outs, pruning_paras, layer_i
hidden_states = kwargs['hidden_states']
position_embeddings = kwargs['position_embeddings']
position_ids = kwargs['position_ids']
past_key_value = kwargs['past_key_value']
past_key_value = layer_outs[2]

bsz, q_len, _ = hidden_states.size()
query_states = module.q_proj(hidden_states)
@@ -193,10 +193,8 @@ def get_retained_image_token(pruning_paras, last_layer_state, any_states):
) // (pivot_image_token + pivot_text_token))
device = last_layer_state.device

any_states = (
any_states.permute(0, 2, 1, 3)
.reshape(any_states.shape[0], any_states.shape[1], -1)
)
any_states = any_states.permute(0, 2, 1, 3)
any_states = any_states.reshape(any_states.shape[0], any_states.shape[1], -1)

k_states_image_token = any_states[0][
image_token_start_index:image_token_start_index + image_token_length, :
2 changes: 1 addition & 1 deletion llmc/compression/token_reduction/fastv.py
@@ -52,7 +52,7 @@ def wrapper(self, *args, **kwargs):
attention_mask = args[2]
token_indices = \
input_ids[0][attention_mask[0]] == pruning_paras['IMAGE_TOKEN_INDEX']
pruning_paras['image_token_start_index'] = torch.where(token_indices)[0].item()
pruning_paras['image_token_start_index'] = torch.where(token_indices)[0][0].item()

Review comment (high): Ensure that torch.where(token_indices)[0] contains at least one element before indexing it with [0]. Otherwise, this could raise an IndexError if no image tokens are found. A guarded form (using a hypothetical helper variable) could look like:

image_token_positions = torch.where(token_indices)[0]
pruning_paras['image_token_start_index'] = (
    image_token_positions[0].item() if image_token_positions.numel() > 0 else 0
)


outputs = fn(*args, **kwargs)
return outputs
71 changes: 42 additions & 29 deletions llmc/compression/token_reduction/sparsevlm.py
@@ -12,6 +12,8 @@
from .token_reduction_module import TokenReductionModule
from .utils import prefill_wrapper, prefill_wrapper_model

layer_dict = {}

Review comment (high): Introducing a global layer_dict can lead to unexpected behavior if multiple SparseVLM instances are used concurrently. Refactor to use an instance variable to ensure thread safety and proper encapsulation.
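A minimal sketch of the instance-variable alternative, assuming layer_dict is only consumed by hooks created from this SparseVLM instance:

def add_sparse_config(self):
    special_config = self.config.get('special', {})
    self.pruning_loc = special_config.get('pruning_loc', [2, 6, 15])
    # Per-instance mapping: pruning layer -> its position within pruning_loc,
    # instead of mutating a module-level global.
    self.layer_dict = {layer: idx for idx, layer in enumerate(self.pruning_loc)}
    ...

Hooks that currently read the global would then receive self.layer_dict explicitly, for example through the same functools.partial pattern this file already uses for pruning_pars, keeping concurrent SparseVLM instances independent.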



@TOKEN_REDUCTION_REGISTRY.register('SparseVLM')
class SparseVLM(TokenReductionModule):
@@ -24,6 +26,8 @@ def add_sparse_config(self):
special_config = self.config.get('special', {})

self.pruning_loc = special_config.get('pruning_loc', [2, 6, 15])
global layer_dict
layer_dict = {layer: idx for idx, layer in enumerate(self.pruning_loc)}
special_config['retained_tokens'] = special_config.get('retained_tokens', 192)
special_config['init_token_total_shape'] = special_config.get('init_token_total_shape', 668)
special_config['generate_process_count'] = 0
@@ -44,7 +48,8 @@ def input_hook(module, input_args, pruning_pars):
# find the position of the first image token
for seq in input_ids:
image_token_index = (
seq == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0]
seq == IMAGE_TOKEN_INDEX
).nonzero(as_tuple=True)[0]
if len(image_token_index) > 0:
pre_prompt_length_list.append(image_token_index[0].item())
else:
@@ -95,33 +100,31 @@ def wrapper(self, *args, **kwargs):
@prefill_wrapper_model
def register_module_pars(module, args, kwargs, pruning_pars):
pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
inputs_embeds = kwargs['inputs_embeds']
if inputs_embeds is None:
inputs_embeds = module.embed_tokens(kwargs['input_ids'])
hidden_states = inputs_embeds # shape: (B, L, C)
hidden_states = kwargs['inputs_embeds']
if hidden_states is None:
hidden_states = module.embed_tokens(kwargs['input_ids'])

B, L, _ = hidden_states.shape
pruning_pars['B'] = B
init_n = pruning_pars['init_token_total_shape'] + \
pruning_pars['generate_process_count'] # 668
pruning_pars['generate_process_count'] # 668
pruning_pars['prev_decision'] = torch.ones(
B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)
pruning_pars['policy'] = torch.ones(
B, init_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)

pruning_pars['v_token_start'] = pre_prompt_length_list[0] if len(
pre_prompt_length_list) != 0 else 0 # 35
v_token_start = pruning_pars['v_token_start']
pruning_pars['text_token_start'] = pruning_pars['v_token_start'] + \
pruning_pars['image_shape'] # 35 + 576 = 611
text_token_start = pruning_pars['text_token_start']
v_token_start = pre_prompt_length_list[0] if len(
pre_prompt_length_list) != 0 else 0
text_token_start = v_token_start + pruning_pars['image_shape']
pruning_pars['v_token_start'] = v_token_start # 35
pruning_pars['text_token_start'] = text_token_start # 611
pruning_pars['v_token_num'] = pruning_pars['image_shape'] # 576

if (len(pre_prompt_length_list) != 0 and hidden_states.shape[1] != 1):
v_t = hidden_states[:, v_token_start: text_token_start, :]
t_t = hidden_states[:, text_token_start:, :]
m_v_t = v_t @ t_t.transpose(1, 2) # [1, 576, 53] # 52?
m_v_t = m_v_t.softmax(2).mean(1) # [1, 53]
m_v_t = v_t @ t_t.transpose(1, 2) # [1, 576, 52]
m_v_t = m_v_t.softmax(2).mean(1) # [1, 52]
pruning_pars['t_token_idx'] = torch.where(m_v_t > m_v_t.mean())

return args, kwargs
@@ -134,10 +137,20 @@ def update_output_attentions_hook(module, args, kwargs, pruning_pars, layer_idx)
kwargs['position_embeddings'] = pruning_pars['position_embeddings']
return args, kwargs

def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):
def update_kwargs_hook(module, args, kwargs, pruning_pars, layer_idx):

if len(kwargs['position_ids'][0]) == 1:
return args, kwargs
if layer_idx != self.pruning_loc[0]:
kwargs['position_ids'] = pruning_pars['position_ids']
kwargs['cache_position'] = pruning_pars['cache_position']
kwargs['position_embeddings'] = pruning_pars['position_embeddings']
return args, kwargs

def get_attn_logits_hook(module, args, kwargs, layer_outs, pruning_pars, layer_idx):

if len(kwargs['position_ids'][0]) == 1:
return layer_outs

from transformers.models.llama.modeling_llama import \
apply_rotary_pos_emb
@@ -150,8 +163,7 @@ def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):
hidden_states = kwargs['hidden_states']
position_embeddings = kwargs['position_embeddings']
position_ids = kwargs['position_ids']
past_key_value = kwargs['past_key_value']
cache_position = kwargs['cache_position']
past_key_value = layer_outs[2]
attention_mask = kwargs['attention_mask']

t_token_idx = pruning_pars['t_token_idx']
@@ -179,12 +191,8 @@ def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
temp_cache = copy.deepcopy(past_key_value)
cache_kwargs = {'sin': sin, 'cos': cos, 'cache_position': cache_position}
key_states, value_states = temp_cache.update(
key_states, value_states,
layer_idx, cache_kwargs
)
key_states = past_key_value.key_cache[layer_idx]
value_states = past_key_value.value_cache[layer_idx]
t_token_idx = t_token_idx[1] + v_token_start + v_token_num
L, S = query_states.size(-2), key_states.size(-2)
Review comment on lines 196 to 197 (medium): This change removes parentheses around any_states.permute and any_states.reshape. Confirm that this change does not affect the order of operations or the intended result.

scale_factor = 1 / math.sqrt(query_states.size(-1))
@@ -201,19 +209,16 @@ def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):

pruning_pars['attn_logits'] = attn_logits

return args, kwargs
return layer_outs

@prefill_wrapper
def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):

# pruning_pars['attn_logits'] has a bug when running with llavaHf;
# running llavaHf with layer_outputs[1] works, but the accuracy does not match
# llava:attn_logits = pruning_pars['attn_logits']
# llavahf:attn_logits = layer_outputs[1]
if 'attn_logits' not in pruning_pars:
attn_logits = layer_outputs[1]
else:
attn_logits = pruning_pars['attn_logits']
merge_flag = pruning_pars['merge_flag']
v_token_start = pruning_pars['v_token_start']
v_token_num = pruning_pars['v_token_num']
text_token_start = pruning_pars['text_token_start']
@@ -255,7 +260,7 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer
total_sparse_token_idx = torch.where(policy == 0)[1].unsqueeze(0)

# merge and cluster
if s_flag and total_sparse_token_idx.shape[1] > 0:
if s_flag and merge_flag and total_sparse_token_idx.shape[1] > 0:
total_sparse_token = batch_index_select(layer_outputs[0], total_sparse_token_idx)

merge_token_idx_stage1 = torch.where(pred_score_vis == 0)[1]
@@ -359,6 +364,14 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
)
elif self.model.__class__.__name__ == 'Llava':
self.blocks[block_idx].self_attn.register_forward_pre_hook(
functools.partial(
update_kwargs_hook,
pruning_pars=self.pruning_paras,
layer_idx=block_idx,
),
with_kwargs=True
)
Review comment on lines 366 to +373 (medium): Consider consolidating the duplicate registration of update_kwargs_hook and get_attn_logits_hook for Llava models to reduce code duplication and improve maintainability.
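A minimal sketch of that consolidation, assuming both hooks keep the signatures used in this file (update_kwargs_hook as a forward pre-hook, get_attn_logits_hook as a forward hook, both registered with with_kwargs=True):

attn = self.blocks[block_idx].self_attn
for hook_fn, register in (
    (update_kwargs_hook, attn.register_forward_pre_hook),
    (get_attn_logits_hook, attn.register_forward_hook),
):
    register(
        functools.partial(
            hook_fn,
            pruning_pars=self.pruning_paras,
            layer_idx=block_idx,
        ),
        with_kwargs=True,
    )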

self.blocks[block_idx].self_attn.register_forward_hook(
functools.partial(
get_attn_logits_hook,
pruning_pars=self.pruning_paras,