 import functools
 import math
-from functools import wraps
-from types import MethodType
 
 import torch
 
@@ -19,95 +17,43 @@ def __init__(self, config, model, blocks):
         self.register_reduction_modules()
 
     def add_sparse_config(self):
-
         self.pruning_loc = self.special_config['pruning_loc']
-        self.special_config['image_token_length'] = \
-            self.model.pruning_config['image_token_length']
-        self.special_config['IMAGE_TOKEN_INDEX'] = \
-            self.model.pruning_config['IMAGE_TOKEN_INDEX']
 
         self.pruning_paras = self.special_config
 
     def register_reduction_modules(self):
 
-        def input_hook_llava(fn, pruning_paras):
-            @wraps(fn)
-            def wrapper(self, *args, **kwargs):
-                if len(args) == 0:
-                    return fn(*args, **kwargs)
-                input_args = args[0]
-                if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
-                    return fn(*args, **kwargs)
-
-                input_ids = args[0]
-                attention_mask = args[2]
-                token_indices = (
-                    input_ids[0][attention_mask[0]] == pruning_paras['IMAGE_TOKEN_INDEX']
-                )
-                pruning_paras['image_token_start_index'] = torch.where(token_indices)[0][0].item()
+        @prefill_wrapper
+        def vtoken_length_hook(module, input_args, pruning_paras):
 
-                outputs = fn(*args, **kwargs)
-                return outputs
-            return wrapper
+            input_ids = input_args[0]
+            token_indices = torch.where(
+                input_ids[0] == pruning_paras['vision_token_index']
+            )[0]
+            pruning_paras['vision_token_length'] = token_indices.shape[0]
 
-        def get_seq_len_hook(module, args, kwargs, pruning_paras):
-            if kwargs['input_ids'] is not None:
-                pruning_paras['seq_len'] = kwargs['input_ids'].shape[1]
-            elif kwargs['inputs_embeds'] is not None:
-                pruning_paras['seq_len'] = kwargs['inputs_embeds'].shape[1]
-            else:
-                raise ValueError('You have to specify either input_ids or inputs_embeds')
+            return input_args
 
+        @prefill_wrapper
         def get_any_states_hook(module, args, kwargs, layer_outs, pruning_paras, layer_idx):
-            from transformers.models.llama.modeling_llama import (
-                apply_rotary_pos_emb, repeat_kv)
-            if len(kwargs['position_ids'][0]) == 1:
-                return layer_outs
 
-            hidden_states = kwargs['hidden_states']
-            position_embeddings = kwargs['position_embeddings']
-            position_ids = kwargs['position_ids']
-            past_key_value = layer_outs[2]
-
-            bsz, q_len, _ = hidden_states.size()
-            query_states = module.q_proj(hidden_states)
-            key_states = module.k_proj(hidden_states)
-            value_states = module.v_proj(hidden_states)
-            query_states = query_states.view(
-                bsz, q_len, module.num_heads, module.head_dim
-            ).transpose(1, 2)
-            key_states = key_states.view(
-                bsz, q_len, module.num_key_value_heads, module.head_dim
-            ).transpose(1, 2)
-            value_states = value_states.view(
-                bsz, q_len, module.num_key_value_heads, module.head_dim
-            ).transpose(1, 2)
-
-            if position_embeddings is None:
-                cos, sin = module.rotary_emb(value_states, position_ids)
-            else:
-                cos, sin = position_embeddings
-            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-            if past_key_value is not None:
-                key_states = past_key_value.key_cache[layer_idx]
-                value_states = past_key_value.value_cache[layer_idx]
-            key_states = repeat_kv(key_states, module.num_key_value_groups)
-            value_states = repeat_kv(value_states, module.num_key_value_groups)
-
-            pruning_paras['any_states'] = (query_states, key_states, value_states)
+            past_key_value = kwargs['past_key_value']
+            if past_key_value is None:
+                raise ValueError('DART needs past_key_value but got None.')
+            pruning_paras['any_states'] = past_key_value.key_cache[layer_idx]
 
             return layer_outs
 
         @prefill_wrapper
         def pruning_hook(module, args, kwargs, pruning_paras, normlayer):
 
-            image_token_start_index = pruning_paras['image_token_start_index']
-            image_token_length = pruning_paras['image_token_length']
-            any_states = pruning_paras['any_states'][-2]
-            seq_length = pruning_paras['seq_len']
+            image_token_start_index = pruning_paras['vision_token_start_index']
+            image_token_length = pruning_paras['vision_token_length']
+            any_states = pruning_paras['any_states']
 
             hidden_states = args[0]
             attention_mask = kwargs['attention_mask']
+            seq_length = hidden_states.shape[1]
             device = hidden_states.device
             last_layer_state = normlayer(hidden_states)
 
@@ -140,27 +86,20 @@ def pruning_hook(module, args, kwargs, pruning_paras, normlayer):
             kwargs['position_ids'].resize_as_(position_ids).copy_(position_ids.clone())
 
             position_embeddings = kwargs['position_embeddings']
-            new_pe0 = position_embeddings[0][:, keep_indexs, :].clone()
-            new_pe1 = position_embeddings[1][:, keep_indexs, :].clone()
+            index_dim = 1 if position_embeddings[0].dim() == 3 else 2
+            new_pe0 = position_embeddings[0].index_select(index_dim, keep_indexs).clone()
+            new_pe1 = position_embeddings[1].index_select(index_dim, keep_indexs).clone()
             position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
             position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1)
 
             return (hidden_states,), kwargs
 
-        hook_fn = input_hook_llava(
-            self.model.vlm_model.prepare_inputs_labels_for_multimodal,
-            self.pruning_paras
-        )
-        self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
-            hook_fn, self.model.vlm_model
-        )
-
-        self.model.model.model.register_forward_pre_hook(
-            functools.partial(get_seq_len_hook, pruning_paras=self.pruning_paras),
-            with_kwargs=True
-        )
+        if self.special_config['vision_token_length'] is None:
+            self.model.embed_tokens.register_forward_pre_hook(
+                functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
+            )
 
-        self.blocks[self.pruning_loc - 1].self_attn.register_forward_hook(
+        self.blocks[self.pruning_loc - 1].register_forward_hook(
             functools.partial(
                 get_any_states_hook,
                 pruning_paras=self.pruning_paras,
@@ -173,24 +112,21 @@ def pruning_hook(module, args, kwargs, pruning_paras, normlayer):
             functools.partial(
                 pruning_hook,
                 pruning_paras=self.pruning_paras,
-                normlayer=self.model.model.model.norm
+                normlayer=self.model.language_model.norm
             ),
             with_kwargs=True
         )
 
 
 def get_retained_image_token(pruning_paras, last_layer_state, any_states):
-    image_token_start_index = pruning_paras['image_token_start_index']
-    image_token_length = pruning_paras['image_token_length']
-    MAX_NUM_TRUNCTION = pruning_paras['max_num_trunction']
+    image_token_start_index = pruning_paras['vision_token_start_index']
+    image_token_length = pruning_paras['vision_token_length']
     pivot_image_token = pruning_paras['pivot_image_token']
     pivot_text_token = pruning_paras['pivot_text_token']
     reduction_ratio = pruning_paras['reduction_ratio']
-    TOKEN_TOPK = math.ceil(
-        (
-            MAX_NUM_TRUNCTION if MAX_NUM_TRUNCTION is not None
-            else (image_token_length * (1 - reduction_ratio))
-        ) // (pivot_image_token + pivot_text_token))
+    TOKEN_TOPK = int(
+        image_token_length * (1 - reduction_ratio) / (pivot_image_token + pivot_text_token)
+    )
     device = last_layer_state.device
 
     any_states = any_states.permute(0, 2, 1, 3)