Commit 5f2895f

support qwen2.5vl(fastv,dart,visionzip) (#423)

1 parent 343847a commit 5f2895f

14 files changed: +708 -228 lines changed

configs/sparsification/methods/DART/dart.yml

Lines changed: 1 addition & 2 deletions

@@ -16,9 +16,8 @@ sparse:
     method: TokenReduction
     special:
         method: DART
-        pruning_loc: 2
+        pruning_loc: 5
         reduction_ratio: 0.778
-        max_num_trunction: 128
         pivot_image_token: 4
         pivot_text_token : 4
 save:
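
With max_num_trunction gone, the kept-token budget in dart.py is driven entirely by reduction_ratio and the pivot counts. A rough sketch of how this config translates into a per-pivot top-k, using an illustrative vision-token count (Qwen2.5-VL's actual count depends on image resolution):

# Illustrative only: image_token_length is a made-up example value.
image_token_length = 1024
reduction_ratio = 0.778
pivot_image_token = 4
pivot_text_token = 4

# Mirrors the updated get_retained_image_token() in dart.py.
token_topk = int(image_token_length * (1 - reduction_ratio)
                 / (pivot_image_token + pivot_text_token))
print(token_topk)  # 28 tokens selected per pivot token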

configs/sparsification/methods/FastV/fastv.yml

Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@ sparse:
     special:
         method: FastV
         pruning_loc: 3
-        rate: 0.778
+        rate: 0.778 # prune_rate
 save:
     save_trans: False
     save_fake: False
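
The added comment flags rate as the prune rate: the fraction of vision tokens dropped after layer pruning_loc. A quick sketch of the implied budget, assuming the usual FastV top-k selection and an illustrative token count:

# Illustrative only: image_token_length is a made-up example value.
image_token_length = 1024
rate = 0.778   # prune_rate from fastv.yml

kept = round(image_token_length * (1 - rate))
print(kept)    # ~227 vision tokens kept past layer pruning_loc = 3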

configs/sparsification/methods/VisionZip/visionzip.yml

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ sparse:
     vision:
         method: TokenReduction
     special:
-        method: VisionZip
+        method: VisionZip # retain
         dominant: 191 # visual_tokens = dominant_tokens + 1 (cls_token)
         contextual: 30
 save:
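
Reading the config together with its comment, the retained visual budget is the dominant tokens (plus the cls token) plus the merged contextual tokens. The sketch below just spells out that arithmetic; it is an interpretation of the comment, not output from the code:

dominant = 191    # dominant tokens (per the comment, +1 cls token on top)
contextual = 30   # merged contextual tokens

print(dominant + 1 + contextual)  # 222 visual tokens if both sets reach the LLM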

llmc/compression/token_reduction/dart.py

Lines changed: 31 additions & 95 deletions

@@ -1,7 +1,5 @@
 import functools
 import math
-from functools import wraps
-from types import MethodType
 
 import torch
 
@@ -19,95 +17,43 @@ def __init__(self, config, model, blocks):
         self.register_reduction_modules()
 
     def add_sparse_config(self):
-
         self.pruning_loc = self.special_config['pruning_loc']
-        self.special_config['image_token_length'] = \
-            self.model.pruning_config['image_token_length']
-        self.special_config['IMAGE_TOKEN_INDEX'] = \
-            self.model.pruning_config['IMAGE_TOKEN_INDEX']
 
         self.pruning_paras = self.special_config
 
     def register_reduction_modules(self):
 
-        def input_hook_llava(fn, pruning_paras):
-            @wraps(fn)
-            def wrapper(self, *args, **kwargs):
-                if len(args) == 0:
-                    return fn(*args, **kwargs)
-                input_args = args[0]
-                if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
-                    return fn(*args, **kwargs)
-
-                input_ids = args[0]
-                attention_mask = args[2]
-                token_indices = (
-                    input_ids[0][attention_mask[0]] == pruning_paras['IMAGE_TOKEN_INDEX']
-                )
-                pruning_paras['image_token_start_index'] = torch.where(token_indices)[0][0].item()
+        @prefill_wrapper
+        def vtoken_length_hook(module, input_args, pruning_paras):
 
-                outputs = fn(*args, **kwargs)
-                return outputs
-            return wrapper
+            input_ids = input_args[0]
+            token_indices = torch.where(
+                input_ids[0] == pruning_paras['vision_token_index']
+            )[0]
+            pruning_paras['vision_token_length'] = token_indices.shape[0]
 
-        def get_seq_len_hook(module, args, kwargs, pruning_paras):
-            if kwargs['input_ids'] is not None:
-                pruning_paras['seq_len'] = kwargs['input_ids'].shape[1]
-            elif kwargs['inputs_embeds'] is not None:
-                pruning_paras['seq_len'] = kwargs['inputs_embeds'].shape[1]
-            else:
-                raise ValueError('You have to specify either input_ids or inputs_embeds')
+            return input_args
 
+        @prefill_wrapper
         def get_any_states_hook(module, args, kwargs, layer_outs, pruning_paras, layer_idx):
-            from transformers.models.llama.modeling_llama import (
-                apply_rotary_pos_emb, repeat_kv)
-            if len(kwargs['position_ids'][0]) == 1:
-                return layer_outs
 
-            hidden_states = kwargs['hidden_states']
-            position_embeddings = kwargs['position_embeddings']
-            position_ids = kwargs['position_ids']
-            past_key_value = layer_outs[2]
-
-            bsz, q_len, _ = hidden_states.size()
-            query_states = module.q_proj(hidden_states)
-            key_states = module.k_proj(hidden_states)
-            value_states = module.v_proj(hidden_states)
-            query_states = query_states.view(
-                bsz, q_len, module.num_heads, module.head_dim
-            ).transpose(1, 2)
-            key_states = key_states.view(
-                bsz, q_len, module.num_key_value_heads, module.head_dim
-            ).transpose(1, 2)
-            value_states = value_states.view(
-                bsz, q_len, module.num_key_value_heads, module.head_dim
-            ).transpose(1, 2)
-
-            if position_embeddings is None:
-                cos, sin = module.rotary_emb(value_states, position_ids)
-            else:
-                cos, sin = position_embeddings
-            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-            if past_key_value is not None:
-                key_states = past_key_value.key_cache[layer_idx]
-                value_states = past_key_value.value_cache[layer_idx]
-                key_states = repeat_kv(key_states, module.num_key_value_groups)
-                value_states = repeat_kv(value_states, module.num_key_value_groups)
-
-            pruning_paras['any_states'] = (query_states, key_states, value_states)
+            past_key_value = kwargs['past_key_value']
+            if past_key_value is None:
+                raise ValueError('DART needs past_key_value but got None.')
+            pruning_paras['any_states'] = past_key_value.key_cache[layer_idx]
 
             return layer_outs
 
         @prefill_wrapper
        def pruning_hook(module, args, kwargs, pruning_paras, normlayer):
 
-            image_token_start_index = pruning_paras['image_token_start_index']
-            image_token_length = pruning_paras['image_token_length']
-            any_states = pruning_paras['any_states'][-2]
-            seq_length = pruning_paras['seq_len']
+            image_token_start_index = pruning_paras['vision_token_start_index']
+            image_token_length = pruning_paras['vision_token_length']
+            any_states = pruning_paras['any_states']
 
             hidden_states = args[0]
             attention_mask = kwargs['attention_mask']
+            seq_length = hidden_states.shape[1]
             device = hidden_states.device
             last_layer_state = normlayer(hidden_states)
 
@@ -140,27 +86,20 @@ def pruning_hook(module, args, kwargs, pruning_paras, normlayer):
             kwargs['position_ids'].resize_as_(position_ids).copy_(position_ids.clone())
 
             position_embeddings = kwargs['position_embeddings']
-            new_pe0 = position_embeddings[0][:, keep_indexs, :].clone()
-            new_pe1 = position_embeddings[1][:, keep_indexs, :].clone()
+            index_dim = 1 if position_embeddings[0].dim() == 3 else 2
+            new_pe0 = position_embeddings[0].index_select(index_dim, keep_indexs).clone()
+            new_pe1 = position_embeddings[1].index_select(index_dim, keep_indexs).clone()
             position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
             position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1)
 
             return (hidden_states,), kwargs
 
-        hook_fn = input_hook_llava(
-            self.model.vlm_model.prepare_inputs_labels_for_multimodal,
-            self.pruning_paras
-        )
-        self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
-            hook_fn, self.model.vlm_model
-        )
-
-        self.model.model.model.register_forward_pre_hook(
-            functools.partial(get_seq_len_hook, pruning_paras=self.pruning_paras),
-            with_kwargs=True
-        )
+        if self.special_config['vision_token_length'] is None:
+            self.model.embed_tokens.register_forward_pre_hook(
+                functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
+            )
 
-        self.blocks[self.pruning_loc - 1].self_attn.register_forward_hook(
+        self.blocks[self.pruning_loc - 1].register_forward_hook(
             functools.partial(
                 get_any_states_hook,
                 pruning_paras=self.pruning_paras,
@@ -173,24 +112,21 @@ def pruning_hook(module, args, kwargs, pruning_paras, normlayer):
             functools.partial(
                 pruning_hook,
                 pruning_paras=self.pruning_paras,
-                normlayer=self.model.model.model.norm
+                normlayer=self.model.language_model.norm
             ),
             with_kwargs=True
         )
 
 
 def get_retained_image_token(pruning_paras, last_layer_state, any_states):
-    image_token_start_index = pruning_paras['image_token_start_index']
-    image_token_length = pruning_paras['image_token_length']
-    MAX_NUM_TRUNCTION = pruning_paras['max_num_trunction']
+    image_token_start_index = pruning_paras['vision_token_start_index']
+    image_token_length = pruning_paras['vision_token_length']
    pivot_image_token = pruning_paras['pivot_image_token']
    pivot_text_token = pruning_paras['pivot_text_token']
    reduction_ratio = pruning_paras['reduction_ratio']
-    TOKEN_TOPK = math.ceil(
-        (
-            MAX_NUM_TRUNCTION if MAX_NUM_TRUNCTION is not None
-            else (image_token_length * (1 - reduction_ratio))
-        ) // (pivot_image_token + pivot_text_token))
+    TOKEN_TOPK = int(
+        image_token_length * (1 - reduction_ratio) / (pivot_image_token + pivot_text_token)
+    )
    device = last_layer_state.device
 
    any_states = any_states.permute(0, 2, 1, 3)
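
The common thread in the dart.py and fastv.py changes is the position-embedding slicing: LLaMA/LLaVA rotary caches are (bsz, seq, head_dim), while Qwen2.5-VL's multimodal RoPE adds a leading axis for its three rope sections, so the sequence axis moves from dim 1 to dim 2. A standalone sketch of the index_select dispatch (shapes are small stand-ins, not taken from either model):

import torch

bsz, seq, head_dim = 1, 10, 8
cos_llama = torch.randn(bsz, seq, head_dim)         # dim() == 3 -> sequence axis is 1
cos_qwen25vl = torch.randn(3, bsz, seq, head_dim)   # dim() == 4 -> sequence axis is 2
keep_indexs = torch.tensor([0, 1, 2, 5, 9])         # tokens that survive pruning

for cos in (cos_llama, cos_qwen25vl):
    index_dim = 1 if cos.dim() == 3 else 2
    print(cos.index_select(index_dim, keep_indexs).shape)
# torch.Size([1, 5, 8])
# torch.Size([3, 1, 5, 8])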

llmc/compression/token_reduction/fastv.py

Lines changed: 13 additions & 46 deletions

@@ -1,6 +1,4 @@
 import functools
-from functools import wraps
-from types import MethodType
 
 import torch
 
@@ -18,46 +16,22 @@ def __init__(self, config, model, blocks):
         self.register_reduction_modules()
 
     def add_sparse_config(self):
-
         self.pruning_loc = self.special_config['pruning_loc']
-        self.special_config['image_token_length'] = \
-            self.model.pruning_config['image_token_length']
-        self.special_config['IMAGE_TOKEN_INDEX'] = \
-            self.model.pruning_config['IMAGE_TOKEN_INDEX']
-        self.special_config['attn_scores'] = None
 
         self.pruning_paras = self.special_config
 
     def register_reduction_modules(self):
 
         @prefill_wrapper
-        def input_hook(module, input_args, pruning_paras):
+        def vtoken_length_hook(module, input_args, pruning_paras):
             input_ids = input_args[0]
-            image_token_idxs = (input_ids[0] ==
-                                pruning_paras['vision_token_index']).nonzero(as_tuple=True)[0]
-            pruning_paras['image_token_start_index'] = image_token_idxs[0].item()
-
+            token_indices = torch.where(
+                input_ids[0] == pruning_paras['vision_token_index']
+            )[0]
+            pruning_paras['vision_token_length'] = token_indices.shape[0]
             return input_args
 
-        def input_hook_llava(fn, pruning_paras):
-            @wraps(fn)
-            def wrapper(self, *args, **kwargs):
-                if len(args) == 0:
-                    return fn(*args, **kwargs)
-                input_args = args[0]
-                if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
-                    return fn(*args, **kwargs)
-
-                input_ids = args[0]
-                attention_mask = args[2]
-                token_indices = \
-                    input_ids[0][attention_mask[0]] == pruning_paras['IMAGE_TOKEN_INDEX']
-                pruning_paras['image_token_start_index'] = torch.where(token_indices)[0][0].item()
-
-                outputs = fn(*args, **kwargs)
-                return outputs
-            return wrapper
-
+        @prefill_wrapper
         def update_output_attentions_hook(module, args, kwargs, pruning_paras):
             kwargs['output_attentions'] = True
             pruning_paras['attn_scores'] = module.__class__.forward(module, *args, **kwargs)[1]
@@ -68,8 +42,8 @@ def update_output_attentions_hook(module, args, kwargs, pruning_paras):
         def fastv_pruning_hook(module, args, kwargs, pruning_paras):
 
             rate = pruning_paras['rate']
-            image_token_start_index = pruning_paras['image_token_start_index']
-            image_token_length = pruning_paras['image_token_length']
+            image_token_start_index = pruning_paras['vision_token_start_index']
+            image_token_length = pruning_paras['vision_token_length']
 
             hidden_states = args[0]
             causal_mask = kwargs['attention_mask']
@@ -121,24 +95,17 @@ def fastv_pruning_hook(module, args, kwargs, pruning_paras):
             kwargs['position_ids'].resize_as_(position_ids).copy_(position_ids.clone())
 
             position_embeddings = kwargs['position_embeddings']
-            new_pe0 = position_embeddings[0][:, keep_indexs, :].clone()
-            new_pe1 = position_embeddings[1][:, keep_indexs, :].clone()
+            index_dim = 1 if position_embeddings[0].dim() == 3 else 2
+            new_pe0 = position_embeddings[0].index_select(index_dim, keep_indexs).clone()
+            new_pe1 = position_embeddings[1].index_select(index_dim, keep_indexs).clone()
             position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
             position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1)
 
             return (hidden_states,), kwargs
 
-        if self.model.__class__.__name__ == 'LlavaHf':
+        if self.special_config['vision_token_length'] is None:
             self.model.embed_tokens.register_forward_pre_hook(
-                functools.partial(input_hook, pruning_paras=self.pruning_paras)
-            )
-        elif self.model.__class__.__name__ == 'Llava':
-            hook_fn = input_hook_llava(
-                self.model.vlm_model.prepare_inputs_labels_for_multimodal,
-                self.pruning_paras
-            )
-            self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
-                hook_fn, self.model.vlm_model
+                functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
             )
 
         self.blocks[self.pruning_loc - 1].register_forward_pre_hook(
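
Both DART and FastV now fall back to the same embed_tokens pre-hook whenever the model cannot report a fixed image_token_length up front (Qwen2.5-VL's count varies with image resolution). Stripped of the hook plumbing, the counting logic is just this; the placeholder id below is a hypothetical example:

import torch

vision_token_index = 151655   # hypothetical image-placeholder id
input_ids = torch.tensor([[101, 151655, 151655, 151655, 2023, 2003, 102]])

token_indices = torch.where(input_ids[0] == vision_token_index)[0]
vision_token_length = token_indices.shape[0]
print(vision_token_length)    # 3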

llmc/compression/token_reduction/token_reduction_module.py

Lines changed: 9 additions & 6 deletions

@@ -23,12 +23,15 @@ def set_sparse_config(self):
                 'video_token_length'
             ]
         else:
-            self.special_config['vision_token_index'] = self.model.pruning_config[
-                'image_token_index'
-            ]
-            self.special_config['vision_token_length'] = self.model.pruning_config[
-                'image_token_length'
-            ]
+            self.special_config['vision_token_index'] = self.model.pruning_config.get(
+                'image_token_index', None
+            )
+            self.special_config['vision_token_start_index'] = self.model.pruning_config.get(
+                'vision_token_start_index', None
+            )
+            self.special_config['vision_token_length'] = self.model.pruning_config.get(
+                'image_token_length', None
+            )
 
     def register_reduction_modules(self):
         pass
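
Switching from item access to .get(..., None) lets a model register only the pruning keys it actually knows; anything missing becomes None and is resolved at runtime by the embed_tokens hooks above. A small sketch with two illustrative pruning_config dicts (the values are made up):

def build_special_config(pruning_config):
    # Mirrors the .get() fallback in set_sparse_config().
    return {
        'vision_token_index': pruning_config.get('image_token_index', None),
        'vision_token_start_index': pruning_config.get('vision_token_start_index', None),
        'vision_token_length': pruning_config.get('image_token_length', None),
    }

llava_like = {'image_token_index': 32000, 'image_token_length': 576,
              'vision_token_start_index': 35}
qwen25vl_like = {'image_token_index': 151655, 'vision_token_start_index': 15}

print(build_special_config(llava_like)['vision_token_length'])     # 576
print(build_special_config(qwen25vl_like)['vision_token_length'])  # None -> filled by hook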

llmc/compression/token_reduction/utils.py

Lines changed: 3 additions & 9 deletions

@@ -63,17 +63,12 @@ def make_tome_class(transformer_class):
     class VisionZipTransformer(transformer_class):
         """
         Modifications:
-        - Initialize r, token size, and token sources.
+        - Initialize r
         """
-
-        def forward(self, *args, **kwdargs) -> torch.Tensor:
+        def forward(self, *args, **kwargs) -> torch.Tensor:
             self._info['r'] = parse_r(len(self.vision_model.encoder.layers), self.r)
             # self._info["r"] = self.r
-
-            self._info['size'] = None
-            self._info['source'] = None
-
-            return super().forward(*args, **kwdargs)
+            return super().forward(*args, **kwargs)
 
     return VisionZipTransformer
 
@@ -93,7 +88,6 @@ def apply_info(model, dominant_num, contextual_num):
     for module in model.modules():
         if isinstance(module, CLIPEncoderLayer):
             module.self_attn.k_proj._info = model._info
-            module.self_attn.k_proj.metric = None
 
 
 def add_post_hook_to_get_2dPool(model, post_hook_fn, pruning_paras):
