
Commit 2c61449

Optimize code of FastV and fix SparseVLM's bugs related to LLaVA. (#402)
1 parent 3f725ca commit 2c61449

2 files changed: +133 -49 lines changed

llmc/compression/token_reduction/fastv.py

Lines changed: 22 additions & 23 deletions
@@ -1,4 +1,5 @@
 import functools
+from functools import wraps
 from types import MethodType
 
 import torch
@@ -39,26 +40,23 @@ def input_hook(module, input_args, pruning_paras):
 
     return input_args
 
-def make_hook_prepare_inputs_labels_for_multimodal(pruning_paras):
-    def hook_prepare_inputs_labels_for_multimodal(
-        self,
-        input_ids,
-        position_ids,
-        attention_mask,
-        past_key_values,
-        labels,
-        images,
-        modalities=['image'],
-        image_sizes=None,
-    ):
-        if 'image_token_start_index' not in pruning_paras:
-            token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX
-            pruning_paras['image_token_start_index'] = torch.where(token_indices)[0].item()
-        return self._original_prepare_inputs_labels_for_multimodal(
-            input_ids, position_ids, attention_mask,
-            past_key_values, labels, images, modalities, image_sizes
-        )
-    return hook_prepare_inputs_labels_for_multimodal
+def input_hook_llava(fn, pruning_paras):
+    @wraps(fn)
+    def wrapper(self, *args, **kwargs):
+        if len(args) == 0:
+            return fn(*args, **kwargs)
+        input_args = args[0]
+        if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
+            return fn(*args, **kwargs)
+
+        input_ids = args[0]
+        attention_mask = args[2]
+        token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX
+        pruning_paras['image_token_start_index'] = torch.where(token_indices)[0].item()
+
+        outputs = fn(*args, **kwargs)
+        return outputs
+    return wrapper
 
 
 def update_output_attentions_hook(module, args, kwargs, pruning_paras):
     kwargs['output_attentions'] = True
@@ -129,9 +127,10 @@ def fastv_pruning_hook(module, args, kwargs, pruning_paras):
                 functools.partial(input_hook, pruning_paras=self.pruning_paras)
             )
         elif self.model.__class__.__name__ == 'Llava':
-            hook_fn = make_hook_prepare_inputs_labels_for_multimodal(self.pruning_paras)
-            self.model.vlm_model._original_prepare_inputs_labels_for_multimodal = (
-                self.model.vlm_model.prepare_inputs_labels_for_multimodal
+            from llava.constants import IMAGE_TOKEN_INDEX
+            hook_fn = input_hook_llava(
+                self.model.vlm_model.prepare_inputs_labels_for_multimodal,
+                self.pruning_paras
             )
             self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
                 hook_fn, self.model.vlm_model
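
For reference, a minimal sketch of the wrapping pattern the new input_hook_llava relies on: capture the original prepare_inputs_labels_for_multimodal, record the image-token start index into the pruning parameters, then rebind the wrapper with MethodType. The ToyVLM class, its simplified signature, and the placeholder IMAGE_TOKEN_INDEX value are illustrative assumptions rather than the repository's code, and the guard clauses from the commit (empty args, decode-step inputs) are omitted.

import torch
from functools import wraps
from types import MethodType

IMAGE_TOKEN_INDEX = -200  # placeholder for llava.constants.IMAGE_TOKEN_INDEX


class ToyVLM:
    """Stand-in for the LLaVA model; only the wrapped method matters here."""

    def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids,
                                             attention_mask, *rest):
        # The real method expands image tokens; the toy just echoes its inputs.
        return input_ids, position_ids, attention_mask


def input_hook_llava(fn, pruning_paras):
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        input_ids, attention_mask = args[0], args[2]
        # Record where the image placeholder token sits in the prompt.
        token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX
        pruning_paras['image_token_start_index'] = torch.where(token_indices)[0].item()
        return fn(*args, **kwargs)
    return wrapper


model = ToyVLM()
paras = {}
# Wrap the original method and rebind it as a bound method, as the commit does.
model.prepare_inputs_labels_for_multimodal = MethodType(
    input_hook_llava(model.prepare_inputs_labels_for_multimodal, paras), model
)

input_ids = torch.tensor([[1, 2, IMAGE_TOKEN_INDEX, 4, 5]])
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
model.prepare_inputs_labels_for_multimodal(input_ids, None, attention_mask)
print(paras['image_token_start_index'])  # -> 2

Calling the rebound method once on a prompt containing the image placeholder fills paras['image_token_start_index'] before the original method runs, which is all FastV needs in order to know where the visual tokens begin.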

llmc/compression/token_reduction/sparsevlm.py

Lines changed: 111 additions & 26 deletions
@@ -1,4 +1,6 @@
+import copy
 import functools
+import math
 from functools import wraps
 from types import MethodType
 
@@ -66,22 +68,26 @@ def wrapper(self, *args, **kwargs):
         input_ids = args[0]
         attention_mask = args[2]
 
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+        else:
+            attention_mask = attention_mask.bool()
+
         pre_prompt_length_list = []
         for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask):
             seq = cur_input_ids[cur_attention_mask]
-            image_token_index = torch.where(seq == IMAGE_TOKEN_INDEX)[0].tolist()
-            if len(image_token_index) > 0:
-                pre_prompt_length_list.append(image_token_index[0])
-            else:
-                pre_prompt_length_list.append(0)
+            image_token_index = (
+                [-1]
+                + torch.where(seq == IMAGE_TOKEN_INDEX)[0].tolist()
+                + [seq.shape[0]]
+            )
+            pre_prompt_length_list.append(image_token_index[1])
+
         pruning_paras['pre_prompt_length_list'] = pre_prompt_length_list
 
         outputs = fn(*args, **kwargs)
 
-        token_length_list = []
-        for cur_attention_mask in outputs[2]:
-            token_length_list.append(cur_attention_mask.sum().item())
-        pruning_paras['token_length_list'] = token_length_list
+        pruning_paras['token_length_list'] = outputs[2].sum(dim=1).tolist()
 
         return outputs
     return wrapper
@@ -128,14 +134,90 @@ def update_output_attentions_hook(module, args, kwargs, pruning_pars, layer_idx)
     kwargs['position_embeddings'] = pruning_pars['position_embeddings']
     return args, kwargs
 
+def get_attn_logits_hook(module, args, kwargs, pruning_pars, layer_idx):
+
+    if len(kwargs['position_ids'][0]) == 1:
+        return args, kwargs
+
+    from transformers.models.llama.modeling_llama import \
+        apply_rotary_pos_emb
+
+    if layer_idx != self.pruning_loc[0]:
+        kwargs['position_ids'] = pruning_pars['position_ids']
+        kwargs['cache_position'] = pruning_pars['cache_position']
+        kwargs['position_embeddings'] = pruning_pars['position_embeddings']
+
+    hidden_states = kwargs['hidden_states']
+    position_embeddings = kwargs['position_embeddings']
+    position_ids = kwargs['position_ids']
+    past_key_value = kwargs['past_key_value']
+    cache_position = kwargs['cache_position']
+    attention_mask = kwargs['attention_mask']
+
+    t_token_idx = pruning_pars['t_token_idx']
+    v_token_start = pruning_pars['v_token_start']
+    v_token_num = pruning_pars['v_token_num']
+
+    bsz, q_len, _ = hidden_states.size()
+    query_states = module.q_proj(hidden_states)
+    key_states = module.k_proj(hidden_states)
+    value_states = module.v_proj(hidden_states)
+    query_states = query_states.view(
+        bsz, q_len, module.num_heads, module.head_dim
+    ).transpose(1, 2)
+    key_states = key_states.view(
+        bsz, q_len, module.num_key_value_heads, module.head_dim
+    ).transpose(1, 2)
+    value_states = value_states.view(
+        bsz, q_len, module.num_key_value_heads, module.head_dim
+    ).transpose(1, 2)
+
+    if position_embeddings is None:
+        cos, sin = module.rotary_emb(value_states, position_ids)
+    else:
+        cos, sin = position_embeddings
+
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+    if past_key_value is not None:
+        temp_cache = copy.deepcopy(past_key_value)
+        cache_kwargs = {'sin': sin, 'cos': cos, 'cache_position': cache_position}
+        key_states, value_states = temp_cache.update(
+            key_states, value_states,
+            layer_idx, cache_kwargs
+        )
+    t_token_idx = t_token_idx[1] + v_token_start + v_token_num
+    L, S = query_states.size(-2), key_states.size(-2)
+    scale_factor = 1 / math.sqrt(query_states.size(-1))
+    attn_bias = torch.zeros(L, S, dtype=query_states.dtype)
+    if module.is_causal:
+        assert attention_mask is None
+        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+        attn_bias.masked_fill_(temp_mask.logical_not(), float('-inf'))
+        attn_bias.to(query_states.dtype)
+
+    attn_logits = query_states @ key_states.transpose(2, 3) * scale_factor
+    attn_logits += attn_bias.to(query_states.device)
+    attn_logits = torch.softmax(attn_logits, dim=-1)
+
+    pruning_pars['attn_logits'] = attn_logits
+
+    return args, kwargs
+
 @prefill_wrapper
 def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer_idx):
 
-    attn_logits = layer_outputs[1]
+    # pruning_pars['attn_logits'] has a BUG when running with llavaHf;
+    # running llavaHf with layer_outputs[1] works fine, but the accuracy does not match.
+    # llava: attn_logits = pruning_pars['attn_logits']
+    # llavahf: attn_logits = layer_outputs[1]
+    if 'attn_logits' not in pruning_pars:
+        attn_logits = layer_outputs[1]
+    else:
+        attn_logits = pruning_pars['attn_logits']
     v_token_start = pruning_pars['v_token_start']
+    v_token_num = pruning_pars['v_token_num']
     text_token_start = pruning_pars['text_token_start']
     t_token_idx = pruning_pars['t_token_idx']
-    v_token_num = pruning_pars['v_token_num']
     retained_tokens = pruning_pars['retained_tokens']
     B = pruning_pars['B']
     pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
@@ -145,10 +227,7 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer
         pruning_pars['position_ids'] = position_ids
     else:
         position_ids = pruning_pars['position_ids']
-
     hidden_states = inputs[0]  # [B, L, D]
-    pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
-    image_shape = pruning_pars['image_shape']
 
     pred_score_vis, s_flag, relation_vis_text = attn_postprocess_topk(
         attn_logits,
@@ -177,7 +256,6 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer
 
     # merge and cluster
     if s_flag and total_sparse_token_idx.shape[1] > 0:
-        total_sparse_token_idx = torch.where(policy == 0)[1].unsqueeze(0)
         total_sparse_token = batch_index_select(layer_outputs[0], total_sparse_token_idx)
 
         merge_token_idx_stage1 = torch.where(pred_score_vis == 0)[1]
@@ -208,20 +286,17 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_pars, layer
         )
         layer_outputs = (select_and_merge_token, layer_outputs[1])
         position_ids = position_ids[:, :len(select_token_idx[0]) + cluster_num]
-        # prev_decision = policy
         v_token_num = pred_score_vis.sum() + cluster_num
         text_token_start = v_token_start + v_token_num
     else:
         select_token_idx = torch.where(policy == 1)[1].unsqueeze(0)
         layer_outputs = (batch_index_select(layer_outputs[0], select_token_idx),
                          layer_outputs[1])
         position_ids = position_ids[:, :len(select_token_idx[0])]
-        # prev_decision = policy
         v_token_num = pred_score_vis.sum()
         text_token_start = v_token_start + v_token_num
 
     new_output = layer_outputs
-    # hidden_states = layer_outputs[0]
     cache_position = position_ids.detach().clone()
 
     pruning_pars['v_token_num'] = v_token_num
@@ -273,14 +348,24 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
 
         for block_idx in range(sorted_pruning_locs[0], total_layers):
             if block_idx in sorted_pruning_locs:
-                self.blocks[block_idx].register_forward_pre_hook(
-                    functools.partial(
-                        update_output_attentions_hook,
-                        pruning_pars=self.pruning_paras,
-                        layer_idx=block_idx,
-                    ),
-                    with_kwargs=True
-                )
+                if self.model.__class__.__name__ == 'LlavaHf':
+                    self.blocks[block_idx].register_forward_pre_hook(
+                        functools.partial(
+                            update_output_attentions_hook,
+                            pruning_pars=self.pruning_paras,
+                            layer_idx=block_idx,
+                        ),
+                        with_kwargs=True
+                    )
+                elif self.model.__class__.__name__ == 'Llava':
+                    self.blocks[block_idx].self_attn.register_forward_pre_hook(
+                        functools.partial(
+                            get_attn_logits_hook,
+                            pruning_pars=self.pruning_paras,
+                            layer_idx=block_idx,
+                        ),
+                        with_kwargs=True
+                    )
                 self.blocks[block_idx].register_forward_hook(
                     functools.partial(
                         decoder_attn_hook,
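
For reference, a minimal standalone sketch of the attention-probability computation the new get_attn_logits_hook performs inside the LLaVA self-attention module: a scaled dot product with a causal bias followed by softmax, whose result is what gets stored as pruning_pars['attn_logits']. The shapes and random tensors below are illustrative assumptions; the q/k/v projections, rotary embeddings, and KV-cache handling from the commit are omitted.

import math

import torch

# Toy shapes; in the hook these come from the attention module's projections.
bsz, num_heads, seq_len, head_dim = 1, 2, 6, 8
query_states = torch.randn(bsz, num_heads, seq_len, head_dim)
key_states = torch.randn(bsz, num_heads, seq_len, head_dim)

L, S = query_states.size(-2), key_states.size(-2)
scale_factor = 1 / math.sqrt(query_states.size(-1))

# Causal bias: mask out attention to future positions with -inf.
attn_bias = torch.zeros(L, S, dtype=query_states.dtype)
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float('-inf'))

# Scaled dot-product scores, biased and normalized into probabilities.
attn_logits = query_states @ key_states.transpose(2, 3) * scale_factor
attn_logits += attn_bias
attn_logits = torch.softmax(attn_logits, dim=-1)

print(attn_logits.shape)    # torch.Size([1, 2, 6, 6])
print(attn_logits.sum(-1))  # every row sums to 1 after the softmax

SparseVLM then ranks visual tokens with these probabilities (attn_postprocess_topk) to decide which ones to keep, merge, or drop at the pruning layers.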

0 commit comments
