
Commit 56afa2d

PyramidDrop and SparseVLM for llava (#396)
1 parent 948c6ed commit 56afa2d

4 files changed: +127 -25 lines changed


llmc/compression/token_reduction/holitom.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def SigLipEncoder_forward(
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
-) -> Union[Tuple, BaseModelOutput]:
+) -> Union[Tuple]:
     output_attentions = (
         output_attentions
         if output_attentions is not None

llmc/compression/token_reduction/pyramiddrop.py

Lines changed: 61 additions & 8 deletions
@@ -1,5 +1,7 @@
 import functools
 import math
+from functools import wraps
+from types import MethodType
 
 import torch
 from torch import nn
@@ -26,13 +28,17 @@ def add_sparse_config(self):
         image_token_ratio_list = self.special_config['image_token_ratio_list']
         image_token_ratio_list.insert(0, 1.0)
         self.special_config['image_token_ratio_list'] = image_token_ratio_list
+        if self.model.__class__.__name__ == 'LlavaHf':
+            llama_model = self.model.vlm_model.language_model.model
+        elif self.model.__class__.__name__ == 'Llava':
+            llama_model = self.model.vlm_model.model
         self.special_config['tokenizer_padding_side'] = getattr(
-            self.model.vlm_model.language_model.model.config,
+            llama_model.config,
             'tokenizer_padding_side',
             'right',
         )
 
-        self.model.model.parameters = self.special_config
+        self.pruning_paras = self.special_config
 
     def register_reduction_modules(self):
         @prefill_wrapper
@@ -214,8 +220,12 @@ def pruning_hook(module, args, kwargs, pruning_pars, cur_num, layer_idx):
                 attention_mask_list.append(new_attention_mask)
 
             # Truncate sequences to max length as image embeddings can make the sequence longer
+            if self.model.__class__.__name__ == 'LlavaHf':
+                llama_model = self.model.vlm_model.language_model.model
+            elif self.model.__class__.__name__ == 'Llava':
+                llama_model = self.model.vlm_model.model
             tokenizer_model_max_length = getattr(
-                self.model.vlm_model.language_model.model.config,
+                llama_model.config,
                 'tokenizer_model_max_length',
                 2048,
             )
@@ -321,6 +331,39 @@ def input_hook(module, input_args, pruning_pars):
 
             return input_args
 
+        def input_hook_llava(fn, pruning_paras):
+            @wraps(fn)
+            def wrapper(self, *args, **kwargs):
+                if len(args) == 0:
+                    return fn(*args, **kwargs)
+                input_args = args[0]
+                if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
+                    return fn(*args, **kwargs)
+
+                input_ids = args[0]
+                attention_mask = args[2]
+
+                image_token_posi = []
+                prompt_len = []
+                vision_tokens = []
+                for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask):
+                    seq = cur_input_ids[cur_attention_mask]
+                    image_index = torch.where(seq == IMAGE_TOKEN_INDEX)[0].tolist()
+                    if image_index == []:
+                        image_token_posi.append(-1)
+                        prompt_len.append(cur_input_ids.shape[0])
+                    else:
+                        image_token_posi.append(image_index[0])
+                        prompt_len.append(cur_input_ids.shape[0] - 1)
+                    vision_tokens.append(pruning_paras['vision_token_length'])
+
+                pruning_paras['image_token_posi'] = image_token_posi
+                pruning_paras['prompt_len'] = prompt_len
+                pruning_paras['image_tokens'] = vision_tokens
+
+                return fn(*args, **kwargs)
+            return wrapper
+
         @prefill_wrapper
         def read_parameter_hook(module, args, kwargs, pruning_pars):
             kwargs['attention_mask'] = pruning_pars['attention_mask']
@@ -330,17 +373,27 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
 
             return args, kwargs
 
-        self.model.embed_tokens.register_forward_pre_hook(
-            functools.partial(input_hook, pruning_pars=self.model.model.parameters)
-        )
+        if self.model.__class__.__name__ == 'LlavaHf':
+            self.model.embed_tokens.register_forward_pre_hook(
+                functools.partial(input_hook, pruning_pars=self.pruning_paras)
+            )
+        elif self.model.__class__.__name__ == 'Llava':
+            from llava.constants import IMAGE_TOKEN_INDEX
+            hook_fn = input_hook_llava(
+                self.model.vlm_model.prepare_inputs_labels_for_multimodal,
+                self.pruning_paras
+            )
+            self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
+                hook_fn, self.model.vlm_model
+            )
 
         for layer_idx in range(self.pruning_loc[0], len(self.blocks)):
             if layer_idx in self.pruning_loc:
                 stage = self.pruning_loc.index(layer_idx)
                 self.blocks[layer_idx].register_forward_pre_hook(
                     functools.partial(
                         pruning_hook,
-                        pruning_pars=self.model.model.parameters,
+                        pruning_pars=self.pruning_paras,
                         cur_num=stage,
                         layer_idx=layer_idx,
                     ),
@@ -349,7 +402,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
             else:
                 self.blocks[layer_idx].register_forward_pre_hook(
                     functools.partial(
-                        read_parameter_hook, pruning_pars=self.model.model.parameters
+                        read_parameter_hook, pruning_pars=self.pruning_paras
                     ),
                     with_kwargs=True,
                 )
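
Note on the pattern: both this file and sparsevlm.py below intercept the non-HF LLaVA path by wrapping the already-bound prepare_inputs_labels_for_multimodal method and rebinding the wrapper onto the model with types.MethodType, so image-token positions and prompt lengths are recorded into pruning_paras before the original method runs. Below is a minimal, self-contained sketch of that wrap-and-rebind pattern on a toy class; ToyLlava, prepare_inputs, and make_recording_wrapper are hypothetical names used only for illustration and are not part of the repo.

# Sketch of the wrap-and-rebind pattern used by input_hook_llava (hypothetical names).
from functools import wraps
from types import MethodType


class ToyLlava:
    def prepare_inputs(self, input_ids, attention_mask):
        # stand-in for LLaVA's real multimodal input preparation
        return input_ids, attention_mask


def make_recording_wrapper(fn, paras):
    # fn is captured as an already-bound method, so the wrapper's own `self`
    # (supplied by MethodType at call time) is used only for binding and is
    # not forwarded to fn.
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        paras['num_calls'] = paras.get('num_calls', 0) + 1  # record before delegating
        return fn(*args, **kwargs)
    return wrapper


model = ToyLlava()
paras = {}
model.prepare_inputs = MethodType(
    make_recording_wrapper(model.prepare_inputs, paras), model
)
model.prepare_inputs([1, 2, 3], [True, True, True])
assert paras['num_calls'] == 1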

llmc/compression/token_reduction/sparsevlm.py

Lines changed: 64 additions & 15 deletions
@@ -1,4 +1,6 @@
 import functools
+from functools import wraps
+from types import MethodType
 
 import einops as ein
 import torch
@@ -27,7 +29,7 @@ def add_sparse_config(self):
         special_config['token_length_list'] = []
         special_config['image_shape'] = self.model.pruning_config['image_token_length']
         special_config['image_token_index'] = self.model.pruning_config['image_token_index']
-        self.model.model.parameters = special_config
+        self.pruning_paras = special_config
 
     def register_reduction_modules(self):
         @prefill_wrapper
@@ -52,16 +54,48 @@ def input_hook(module, input_args, pruning_pars):
 
             return input_args
 
+        def input_hook_llava(fn, pruning_paras):
+            @wraps(fn)
+            def wrapper(self, *args, **kwargs):
+                if len(args) == 0:
+                    return fn(*args, **kwargs)
+                input_args = args[0]
+                if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
+                    return fn(*args, **kwargs)
+
+                input_ids = args[0]
+                attention_mask = args[2]
+
+                pre_prompt_length_list = []
+                for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask):
+                    seq = cur_input_ids[cur_attention_mask]
+                    image_token_index = torch.where(seq == IMAGE_TOKEN_INDEX)[0].tolist()
+                    if len(image_token_index) > 0:
+                        pre_prompt_length_list.append(image_token_index[0])
+                    else:
+                        pre_prompt_length_list.append(0)
+                pruning_paras['pre_prompt_length_list'] = pre_prompt_length_list
+
+                outputs = fn(*args, **kwargs)
+
+                token_length_list = []
+                for cur_attention_mask in outputs[2]:
+                    token_length_list.append(cur_attention_mask.sum().item())
+                pruning_paras['token_length_list'] = token_length_list
+
+                return outputs
+            return wrapper
+
         @prefill_wrapper_model
         def register_module_pars(module, args, kwargs, pruning_pars):
             pre_prompt_length_list = pruning_pars['pre_prompt_length_list']
             inputs_embeds = kwargs['inputs_embeds']
             if inputs_embeds is None:
-                inputs_embeds = self.embed_tokens(kwargs['input_ids'])
+                inputs_embeds = module.embed_tokens(kwargs['input_ids'])
             hidden_states = inputs_embeds  # shape: (B, L, C)
 
-            pruning_pars['B'], L, _ = hidden_states.shape
-            B = pruning_pars['B']
+            B, L, _ = hidden_states.shape
+            pruning_pars['B'] = B
             init_n = pruning_pars['init_token_total_shape'] + \
                 pruning_pars['generate_process_count']  # 668
             pruning_pars['prev_decision'] = torch.ones(
@@ -80,7 +114,7 @@ def register_module_pars(module, args, kwargs, pruning_pars):
             if (len(pre_prompt_length_list) != 0 and hidden_states.shape[1] != 1):
                 v_t = hidden_states[:, v_token_start: text_token_start, :]
                 t_t = hidden_states[:, text_token_start:, :]
-                m_v_t = v_t @ t_t.transpose(1, 2)  # [1, 576, 53]
+                m_v_t = v_t @ t_t.transpose(1, 2)  # [1, 576, 53] # 52?
                 m_v_t = m_v_t.softmax(2).mean(1)  # [1, 53]
                 pruning_pars['t_token_idx'] = torch.where(m_v_t > m_v_t.mean())
 
@@ -206,17 +240,31 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
 
             return args, kwargs
 
-        self.model.embed_tokens.register_forward_pre_hook(
-            functools.partial(
-                input_hook,
-                pruning_pars=self.model.model.parameters
+        if self.model.__class__.__name__ == 'LlavaHf':
+            self.model.embed_tokens.register_forward_pre_hook(
+                functools.partial(
+                    input_hook,
+                    pruning_pars=self.pruning_paras
+                )
+            )
+        elif self.model.__class__.__name__ == 'Llava':
+            from llava.constants import IMAGE_TOKEN_INDEX
+            hook_fn = input_hook_llava(
+                self.model.vlm_model.prepare_inputs_labels_for_multimodal,
+                self.pruning_paras
+            )
+            self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
+                hook_fn, self.model.vlm_model
             )
-        )
 
-        self.model.model.register_forward_pre_hook(
+        if self.model.__class__.__name__ == 'LlavaHf':
+            llama_model = self.model.model
+        elif self.model.__class__.__name__ == 'Llava':
+            llama_model = self.model.model.model
+        llama_model.register_forward_pre_hook(
             functools.partial(
                 register_module_pars,
-                pruning_pars=self.model.model.parameters),
+                pruning_pars=self.pruning_paras),
             with_kwargs=True
         )
 
@@ -228,15 +276,15 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                 self.blocks[block_idx].register_forward_pre_hook(
                     functools.partial(
                         update_output_attentions_hook,
-                        pruning_pars=self.model.model.parameters,
+                        pruning_pars=self.pruning_paras,
                         layer_idx=block_idx,
                     ),
                     with_kwargs=True
                 )
                 self.blocks[block_idx].register_forward_hook(
                     functools.partial(
                         decoder_attn_hook,
-                        pruning_pars=self.model.model.parameters,
+                        pruning_pars=self.pruning_paras,
                         layer_idx=block_idx,
                     ),
                     with_kwargs=True
@@ -245,7 +293,7 @@ def read_parameter_hook(module, args, kwargs, pruning_pars):
                 self.blocks[block_idx].register_forward_pre_hook(
                     functools.partial(
                         read_parameter_hook,
-                        pruning_pars=self.model.model.parameters
+                        pruning_pars=self.pruning_paras
                     ),
                     with_kwargs=True
                 )
@@ -278,6 +326,7 @@ def attn_postprocess_topk(
     self_attn_weights = self_attn_weights.mean(1)  # B, L[Q], L[K]
 
     t_token_idx = t_token_idx[1] + text_token_start
+
    relation_vis_text = self_attn_weights[:, t_token_idx,
                                          v_token_start: v_token_start + v_token_num]  # B, L2, L1
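
The register_module_pars hunk above scores text tokens by their relevance to vision tokens: the dot product of vision and text hidden states is softmaxed over the text axis, averaged over the vision axis, and text tokens scoring above the mean are kept in t_token_idx. A small sketch of that computation, with random tensors standing in for hidden_states and shapes following the comments in the diff:

# Sketch of the vision-text relevance selection, with random stand-in tensors.
import torch

B, num_vision, num_text, C = 1, 576, 53, 4096
v_t = torch.randn(B, num_vision, C)   # vision-token hidden states
t_t = torch.randn(B, num_text, C)     # text-token hidden states

m_v_t = v_t @ t_t.transpose(1, 2)     # [1, 576, 53]: similarity of each vision/text pair
m_v_t = m_v_t.softmax(2).mean(1)      # [1, 53]: average weight each text token receives
t_token_idx = torch.where(m_v_t > m_v_t.mean())  # text tokens with above-average relevance
print(t_token_idx[1])                 # indices of the selected text tokens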

llmc/compression/token_reduction/tome.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ def add_sparse_config(self):
         else:
             raise ValueError('Invalid r format. Expected int or (start, step) tuple.')
 
-        self.model.model.parameters = special_config
+        self.pruning_paras = special_config
 
     def patch_layer(self):
         for idx, block in enumerate(self.blocks):
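
A note on the change common to all four files: the pruning configuration is now stored on self.pruning_paras instead of being assigned to self.model.model.parameters. If self.model.model is an nn.Module, assigning a dict to its parameters attribute shadows the built-in parameters() method; the sketch below, which uses a plain nn.Linear as a stand-in, illustrates the hazard under that assumption.

# Sketch: assigning a dict to `.parameters` on an nn.Module shadows parameters().
from torch import nn

toy = nn.Linear(4, 4)
print(callable(toy.parameters))      # True: the normal parameters() method

toy.parameters = {'ratio': 0.5}      # mirrors the old `self.model.model.parameters = ...`
try:
    list(toy.parameters())           # now raises: the dict instance attribute wins
except TypeError as err:
    print('shadowed:', err)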
