Skip to content

Commit 1c37518

Browse files
authored
fix bugs (#383)
1 parent f7e07e7 commit 1c37518

File tree

6 files changed

+84
-39
lines changed

6 files changed

+84
-39
lines changed

llmc/compression/token_reduction/fastv.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
77

88
from .token_reduction_module import TokenReductionModule
9+
from .utils import prefill_wrapper
910

1011

1112
@TOKEN_REDUCTION_REGISTRY.register('FastV')
@@ -16,18 +17,25 @@ def __init__(self, config, model, blocks):
1617
self.register_reduction_modules()
1718

1819
def add_sparse_config(self):
19-
special_config = self.config.get('special', {})
20-
self.pruning_loc = special_config['pruning_loc']
21-
special_config['image_token_start_index'] = \
22-
self.model.pruning_config['image_token_start_index']
23-
special_config['image_token_length'] = \
20+
21+
self.pruning_loc = self.special_config['pruning_loc']
22+
self.special_config['image_token_length'] = \
2423
self.model.pruning_config['image_token_length']
25-
special_config['attn_scores'] = None
24+
self.special_config['attn_scores'] = None
2625

27-
self.model.model.parameters = special_config
26+
self.model.model.parameters = self.special_config
2827

2928
def register_reduction_modules(self):
3029

30+
@prefill_wrapper
31+
def input_hook(module, input_args, pruning_pars):
32+
input_ids = input_args[0]
33+
image_token_idxs = (input_ids[0] ==
34+
pruning_pars['vision_token_index']).nonzero(as_tuple=True)[0]
35+
pruning_pars['image_token_start_index'] = image_token_idxs[0].item()
36+
37+
return input_args
38+
3139
def update_output_attentions_hook(module, args, kwargs):
3240
kwargs['output_attentions'] = True
3341
return args, kwargs
@@ -36,6 +44,7 @@ def store_attention_hook(m, x, layer_outputs, pruning_pars):
3644
layer_attention = layer_outputs[1]
3745
pruning_pars['attn_scores'] = layer_attention
3846

47+
@prefill_wrapper
3948
def fastv_pruning_hook(module, args, kwargs, pruning_pars):
4049

4150
rate = pruning_pars['rate']
@@ -96,6 +105,7 @@ def fastv_pruning_hook(module, args, kwargs, pruning_pars):
96105

97106
return (hidden_states,), kwargs
98107

108+
@prefill_wrapper
99109
def read_parameter_hook(module, args, kwargs, pruning_pars):
100110
kwargs['attention_mask'] = pruning_pars['attention_mask']
101111
kwargs['cache_position'] = pruning_pars['cache_position']
@@ -104,6 +114,10 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
104114

105115
return args, kwargs
106116

117+
self.model.embed_tokens.register_forward_pre_hook(
118+
functools.partial(input_hook, pruning_pars=self.model.model.parameters)
119+
)
120+
107121
self.blocks[self.pruning_loc - 1].register_forward_pre_hook(
108122
update_output_attentions_hook,
109123
with_kwargs=True

llmc/compression/token_reduction/pyramiddrop.py

Lines changed: 12 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
1111

1212
from .token_reduction_module import TokenReductionModule
13+
from .utils import prefill_wrapper
1314

1415

1516
@TOKEN_REDUCTION_REGISTRY.register('PyramidDrop')
@@ -20,38 +21,21 @@ def __init__(self, config, model, blocks):
2021
self.register_reduction_modules()
2122

2223
def add_sparse_config(self):
23-
special_config = self.config.get('special', {})
24-
self.pruning_loc = special_config['layer_list']
25-
image_token_ratio_list = special_config['image_token_ratio_list']
24+
25+
self.pruning_loc = self.special_config['layer_list']
26+
image_token_ratio_list = self.special_config['image_token_ratio_list']
2627
image_token_ratio_list.insert(0, 1.0)
27-
special_config['image_token_ratio_list'] = image_token_ratio_list
28-
special_config['tokenizer_padding_side'] = getattr(
28+
self.special_config['image_token_ratio_list'] = image_token_ratio_list
29+
self.special_config['tokenizer_padding_side'] = getattr(
2930
self.model.vlm_model.language_model.model.config,
3031
'tokenizer_padding_side',
3132
'right',
3233
)
33-
special_config['is_video_model'] = self.model.pruning_config['is_video_model']
34-
35-
# vision_token can be image or video
36-
if special_config['is_video_model']:
37-
special_config['vision_token_index'] = self.model.pruning_config[
38-
'video_token_index'
39-
]
40-
special_config['vision_token_length'] = self.model.pruning_config[
41-
'video_token_length'
42-
]
43-
else:
44-
special_config['vision_token_index'] = self.model.pruning_config[
45-
'image_token_index'
46-
]
47-
special_config['vision_token_length'] = self.model.pruning_config[
48-
'image_token_length'
49-
]
50-
51-
self.model.model.parameters = special_config
5234

53-
def register_reduction_modules(self):
35+
self.model.model.parameters = self.special_config
5436

37+
def register_reduction_modules(self):
38+
@prefill_wrapper
5539
def pruning_hook(module, args, kwargs, pruning_pars, cur_num, layer_idx):
5640

5741
if layer_idx == self.pruning_loc[0]:
@@ -315,10 +299,9 @@ def pruning_hook(module, args, kwargs, pruning_pars, cur_num, layer_idx):
315299

316300
return (new_input_embeds,), kwargs
317301

302+
@prefill_wrapper
318303
def input_hook(module, input_args, pruning_pars):
319-
# for the decoding stage
320-
if input_args[0].shape[1] == 1:
321-
return input_args
304+
322305
input_ids = input_args[0]
323306
pre_prompt_length_list = []
324307
image_token_posi = []
@@ -338,9 +321,8 @@ def input_hook(module, input_args, pruning_pars):
338321

339322
return input_args
340323

324+
@prefill_wrapper
341325
def read_parameter_hook(module, args, kwargs, pruning_pars):
342-
if args[0].shape[1] == 1:
343-
return args, kwargs
344326
kwargs['attention_mask'] = pruning_pars['attention_mask']
345327
# kwargs['cache_position'] = pruning_pars['cache_position']
346328
kwargs['position_ids'] = pruning_pars['position_ids']

llmc/compression/token_reduction/sparsevlm.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
77

88
from .token_reduction_module import TokenReductionModule
9+
from .utils import prefill_wrapper, prefill_wrapper_model
910

1011

1112
@TOKEN_REDUCTION_REGISTRY.register('SparseVLM')
@@ -29,7 +30,7 @@ def add_sparse_config(self):
2930
self.model.model.parameters = special_config
3031

3132
def register_reduction_modules(self):
32-
33+
@prefill_wrapper
3334
def input_hook(module, input_args, pruning_pars):
3435
input_ids = input_args[0]
3536
pre_prompt_length_list = []
@@ -51,6 +52,7 @@ def input_hook(module, input_args, pruning_pars):
5152

5253
return input_args
5354

55+
@prefill_wrapper_model
5456
def register_module_pars(module, args, kwargs, pruning_pars):
5557
pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
5658
inputs_embeds = kwargs['inputs_embeds']
@@ -92,6 +94,7 @@ def update_output_attentions_hook(module, args, kwargs, pruning_pars, layer_idx)
9294
kwargs['position_embeddings'] = pruning_pars['position_embeddings']
9395
return args, kwargs
9496

97+
@prefill_wrapper
9598
def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):
9699

97100
attn_logits = layer_outputs[1]
@@ -195,6 +198,7 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer
195198

196199
return new_output
197200

201+
@prefill_wrapper
198202
def read_parameter_hook(module, args, kwargs, pruning_pars):
199203
kwargs['position_ids'] = pruning_pars['position_ids']
200204
kwargs['cache_position'] = pruning_pars['cache_position']

llmc/compression/token_reduction/token_reduction_module.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,26 @@ def __init__(self, config, model, blocks):
44
self.config = config
55
self.model = model
66
self.blocks = blocks
7+
self.set_sparse_config()
8+
9+
def set_sparse_config(self):
    """Populate ``self.special_config`` with the vision-token settings.

    Reads the ``'special'`` section of ``self.config`` (an empty dict when
    the section is absent) and copies into it, from
    ``self.model.pruning_config``, the ``is_video_model`` flag plus the
    token index and token length of the active modality, stored under the
    modality-neutral keys ``vision_token_index`` / ``vision_token_length``.
    """
    self.special_config = self.config.get('special', {})
    special = self.special_config
    special['is_video_model'] = self.model.pruning_config['is_video_model']

    # vision_token can be image or video
    if special['is_video_model']:
        key_map = {
            'vision_token_index': 'video_token_index',
            'vision_token_length': 'video_token_length',
        }
    else:
        key_map = {
            'vision_token_index': 'image_token_index',
            'vision_token_length': 'image_token_length',
        }

    # Index is copied before length, as in the straight-line version.
    for target_key, source_key in key_map.items():
        special[target_key] = self.model.pruning_config[source_key]
727

828
def register_reduction_modules(self):
929
pass

llmc/compression/token_reduction/utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,35 @@
1+
from functools import wraps
12
from typing import Any, List, Optional, Tuple, Union
23

34
import torch
45
import torch.nn as nn
56
from transformers.models.clip.modeling_clip import CLIPEncoderLayer
67

78

9+
def prefill_wrapper(func):
    """Decorate a module hook so it is skipped during the decoding stage.

    The wrapper looks at the second positional argument of the hook call
    (the hooked module's positional inputs) and takes its first element.
    When that element exposes a ``shape`` whose second dimension is 1, the
    call is treated as a decoding step: the hook is not run and ``None`` is
    returned. In every other case the call is forwarded to ``func``
    unchanged.

    Args:
        func: The hook callable to wrap.

    Returns:
        The wrapped hook, carrying ``func``'s metadata via ``functools.wraps``.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        # for the decoding stage
        if len(args) > 1:
            input_args = args[1]
            # The positional inputs can be empty (module invoked with keyword
            # arguments only); that must not crash the hook.
            try:
                first_input = input_args[0]
            except (IndexError, KeyError, TypeError):
                first_input = None
            shape = getattr(first_input, 'shape', None)
            # Require at least two dimensions before reading shape[1].
            if shape is not None and len(shape) > 1 and shape[1] == 1:
                return None
        return func(*args, **kwargs)
    return wrapper
19+
20+
21+
def prefill_wrapper_model(func):
    """Decorate a model-level hook so it is skipped during the decoding stage.

    The wrapper reads ``'inputs_embeds'`` from the third positional argument
    of the hook call (the keyword-argument dict of the hooked module). When
    that value exposes a ``shape`` whose second dimension is 1, the call is
    treated as a decoding step: the hook is not run and ``None`` is
    returned. In every other case the call is forwarded to ``func``
    unchanged.

    Args:
        func: The hook callable to wrap.

    Returns:
        The wrapped hook, carrying ``func``'s metadata via ``functools.wraps``.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        # for the decoding stage
        # args[2] is read below, so at least three positional arguments are
        # required; a missing 'inputs_embeds' entry simply means "not decoding".
        if len(args) > 2 and isinstance(args[2], dict):
            inputs_embeds = args[2].get('inputs_embeds')
            shape = getattr(inputs_embeds, 'shape', None)
            if shape is not None and len(shape) > 1 and shape[1] == 1:
                return None
        return func(*args, **kwargs)
    return wrapper
31+
32+
833
def parse_r(num_layers: int, r: Union[List[int], Tuple[int, float], int]) -> List[int]:
934
"""Copy from the TOME. https://github.com/facebookresearch/ToMe.
1035

llmc/models/llava.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def safe_prepare_inputs_for_generation(
9696
self.model = self.vlm_model
9797
self.model_config = self.vlm_model_config.text_config
9898
self.pruning_config = {
99-
'image_token_start_index': 5,
99+
'is_video_model': False,
100100
'image_token_length': self.vlm_model_config.image_seq_length,
101101
'select_layer': self.vlm_model_config.vision_feature_layer,
102102
'select_feature': self.vlm_model_config.vision_feature_select_strategy,

0 commit comments

Comments
 (0)