
Commit 93afb24

update sparsevlm for llava1.6 (#427)
1 parent 774e2c5 commit 93afb24

File tree

2 files changed: +36 -13 lines changed

llmc/compression/token_reduction/sparsevlm.py

Lines changed: 35 additions & 12 deletions
@@ -17,6 +17,9 @@
 sparse_token_list_192 = []
 sparse_token_list_128 = []
 sparse_token_list_64 = []
+sparse_token_list_640 = []
+sparse_token_list_320 = []
+sparse_token_list_160 = []
 sparse_token_dict = {}
 
 
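The three new module-level lists mirror the existing 192/128/64 budgets but are keyed for LLaVA-1.6, whose higher-resolution (anyres) inputs produce roughly 2880 visual tokens rather than the 576 of LLaVA-1.5. Below is a minimal sketch of how a configured `reduction_ratio` is expected to select one of the new keys, based on the `round((1 - reduction_ratio) * 2880)` lookup added later in this diff; the 2880 base and the example ratios are assumptions, not values from any config.

```python
# Sketch only: how a reduction_ratio maps to the new 640/320/160 dictionary keys.
BASE_VISION_TOKENS = 2880  # assumed LLaVA-1.6 (anyres) visual token count

def budget_key(reduction_ratio: float) -> int:
    """Reproduce the key computation used in attn_postprocess_topk below."""
    return round((1 - reduction_ratio) * BASE_VISION_TOKENS)

for ratio in (0.7778, 0.8889, 0.9444):     # hypothetical example ratios
    print(ratio, '->', budget_key(ratio))  # -> 640, 320, 160
```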
@@ -55,7 +58,7 @@ def input_hook(module, args, pruning_paras):
         pre_prompt_length_list.append(0)
     pruning_paras['pre_prompt_length_list'] = pre_prompt_length_list
 
-def input_hook_llava(fn, pruning_paras):
+def input_hook_llava(fn, pruning_paras, llava_next=False):
     @wraps(fn)
     def wrapper(self, *args, **kwargs):
         if args[0].shape[1] == 1:
@@ -81,11 +84,14 @@ def wrapper(self, *args, **kwargs):
 
         pruning_paras['pre_prompt_length_list'] = pre_prompt_length_list
 
-        return fn(*args, **kwargs)
+        outs = fn(*args, **kwargs)
+        if llava_next:
+            pruning_paras['vision_token_length'] = outs[-1]
+        return outs
     return wrapper
 
 @prefill_wrapper_model
-def register_module_pars(module, args, kwargs, pruning_paras):
+def register_module_paras(module, args, kwargs, pruning_paras):
     pre_prompt_length_list = pruning_paras['pre_prompt_length_list']
     hidden_states = kwargs['inputs_embeds']
     if hidden_states is None:
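The wrapper now distinguishes LLaVA-1.6: instead of returning `fn(...)` directly, it keeps the outputs and, when `llava_next` is set, records the last element as the vision token length. A small sketch of the contract this relies on, assuming the user has patched `prepare_inputs_labels_for_multimodal` to append that length (the placeholder return values below are illustrative only):

```python
# Sketch of the expected contract for the llava_next branch above.
def patched_prepare_inputs_sketch(*args, **kwargs):
    # ... original prepare_inputs_labels_for_multimodal logic would run here ...
    usual_outputs = (None, None, None, None, None, None)  # placeholder tuple
    vision_token_length = 2880  # e.g. image_features[0].shape[0] (see second file)
    return (*usual_outputs, vision_token_length)

pruning_paras = {}
outs = patched_prepare_inputs_sketch()
pruning_paras['vision_token_length'] = outs[-1]  # what the wrapper stores
```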
@@ -227,7 +233,8 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_paras, laye
             text_token_start,
             t_token_idx,
             layer_idx,
-            retained_tokens
+            retained_tokens,
+            pruning_paras['reduction_ratio']
         )
         if not prune_flag:
             pred_score_vis = torch.zeros_like(relation_vis_text, dtype=bool)
@@ -353,7 +360,8 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):
             self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
                 input_hook_llava(
                     self.model.vlm_model.prepare_inputs_labels_for_multimodal,
-                    self.pruning_paras
+                    self.pruning_paras,
+                    llava_next=self.special_config['vision_token_length'] is None
                 ), self.model.vlm_model
             )
 
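`llava_next` is enabled only when the config leaves `vision_token_length` unset, so the length is read at runtime rather than being fixed up front. A hypothetical config fragment; the key names come from this diff, the values are purely illustrative:

```python
# Hypothetical settings; only the key names are taken from this diff.
special_config = {
    'vision_token_length': None,  # unknown ahead of time for LLaVA-1.6 anyres inputs
    'reduction_ratio': 0.7778,    # assumed knob; would map to the 640-token budget
}
llava_next = special_config['vision_token_length'] is None  # True -> wrapper records it
```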
@@ -362,7 +370,7 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):
         elif self.model.__class__.__name__ == 'Llava':
             llama_model = self.model.model.model
             llama_model.register_forward_pre_hook(
-                functools.partial(register_module_pars, pruning_paras=self.pruning_paras),
+                functools.partial(register_module_paras, pruning_paras=self.pruning_paras),
                 with_kwargs=True
             )
 
@@ -417,6 +425,7 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):
 
 def update_list():
     global sparse_token_list_192, sparse_token_list_128, sparse_token_list_64
+    global sparse_token_list_640, sparse_token_list_320, sparse_token_list_160
     global prune_flag, merge_flag, sparse_token_dict
 
     if layer_dict == {2: 0, 6: 1, 15: 2}:  # 2*576 4*300 10*200 16*110
@@ -428,10 +437,16 @@ def update_list():
         sparse_token_list_192 = [180]
         sparse_token_list_128 = [114]
         sparse_token_list_64 = [48]
+        sparse_token_list_640 = [0.1979]
+        sparse_token_list_320 = [0.0833]
+        sparse_token_list_160 = [0.0261]
     elif prune_flag:
         sparse_token_list_192 = [192]
         sparse_token_list_128 = [128]
         sparse_token_list_64 = [64]
+        sparse_token_list_640 = [0.2222]
+        sparse_token_list_320 = [0.1111]
+        sparse_token_list_160 = [0.0555]
     elif merge_flag:
         sparse_token_list_192 = [149]
         sparse_token_list_128 = [78]
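Unlike the existing integer entries, the new per-layer values are written as fractions of the visible token count, since the absolute number of LLaVA-1.6 visual tokens varies with image resolution. A sketch of how such an entry is resolved into an absolute budget, mirroring the `retained_tokens_prune < 1` branch added further down (the token counts are assumptions):

```python
# Sketch: entries >= 1 are absolute counts, entries < 1 are keep ratios.
def resolve_budget(entry: float, v_token_num: int) -> int:
    return round(entry * v_token_num) if entry < 1 else int(entry)

print(resolve_budget(0.2222, 2880))  # -> 640, assuming ~2880 LLaVA-1.6 tokens
print(resolve_budget(192, 576))      # -> 192, the LLaVA-1.5 absolute budget
```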
@@ -444,7 +459,10 @@ def update_list():
     sparse_token_dict = {
         192: sparse_token_list_192,
         128: sparse_token_list_128,
-        64: sparse_token_list_64
+        64: sparse_token_list_64,
+        640: sparse_token_list_640,
+        320: sparse_token_list_320,
+        160: sparse_token_list_160
     }
 
 
@@ -455,7 +473,8 @@ def attn_postprocess_topk(
         text_token_start,
         t_token_idx,
         layer_idx,
-        retained_tokens):
+        retained_tokens,
+        reduction_ratio):
     '''
     self_attn_weights: [B, H, L, L]
     '''
@@ -470,13 +489,17 @@
 
     relation_vis = relation_vis_text
     s_flag = True  # s_flag controls whether token merge is needed.
-
-    sparse_token_list = sparse_token_dict[retained_tokens]
-
+    if retained_tokens in [192, 128, 64]:
+        sparse_token_list = sparse_token_dict[retained_tokens]
+    else:
+        sparse_token_list = sparse_token_dict[round((1 - reduction_ratio) * 2880)]
+    retained_tokens_prune = sparse_token_list[layer_dict[layer_idx]]
+    if retained_tokens_prune < 1:
+        retained_tokens_prune = round(retained_tokens_prune * v_token_num)
     if v_token_num != 0:
         mask = torch.zeros_like(relation_vis, dtype=bool)
         _, indices = torch.topk(relation_vis, min(
-            sparse_token_list[layer_dict[layer_idx]], v_token_num - 1), dim=1)
+            retained_tokens_prune, v_token_num - 1), dim=1)
         mask[0][indices] = 1
     else:
         mask = torch.ones_like(relation_vis_text, dtype=bool)
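Putting the pieces together, the selection step first picks the budget list (directly for the legacy 192/128/64 keys, via `reduction_ratio` otherwise), converts a fractional entry into a token count, and then keeps the top-scoring visual tokens. A self-contained sketch of that top-k masking, with assumed shapes and values:

```python
import torch

# Assumed shapes/values; mirrors the masking logic in attn_postprocess_topk.
v_token_num = 2880                          # assumed LLaVA-1.6 visual token count
relation_vis = torch.rand(1, v_token_num)   # [B=1, v_token_num] relevance scores

retained = round(0.2222 * v_token_num)      # fractional budget -> 640 tokens
mask = torch.zeros_like(relation_vis, dtype=bool)
_, indices = torch.topk(relation_vis, min(retained, v_token_num - 1), dim=1)
mask[0][indices] = 1                        # True for visual tokens that survive
print(mask.sum().item())                    # -> 640
```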

llmc/compression/token_reduction/token_reduction_module.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ def wrapper(self, *args, **kwargs):
 
         message = (
             'To obtain the vision_token_length for LLaVA-1.6, you should append '
-            '`image_features.shape[1]` to the return value of the function '
+            '`image_features[0].shape[0]` to the return value of the function '
             '`prepare_inputs_labels_for_multimodal`, and modify the related code accordingly.'
         )
         outs = fn(*args, **kwargs)
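The corrected hint reflects the shape difference between the two model versions: in LLaVA-1.5 the image features form a single batched tensor, while in LLaVA-1.6 (anyres) they are assumed to arrive as a list of per-sample tensors whose first dimension is the token count. Illustrative shapes only:

```python
import torch

# Illustrative shapes only (hidden size and token counts are assumptions).
llava15_image_features = torch.zeros(1, 576, 4096)   # [B, N, C] batched tensor
llava16_image_features = [torch.zeros(2880, 4096)]   # list of [N_i, C] per sample

print(llava15_image_features.shape[1])      # 576  -> old hint: image_features.shape[1]
print(llava16_image_features[0].shape[0])   # 2880 -> new hint: image_features[0].shape[0]
```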
