1010from llmc .utils .registry_factory import TOKEN_REDUCTION_REGISTRY
1111
1212from .token_reduction_module import TokenReductionModule
13+ from .utils import prefill_wrapper
1314
1415
1516@TOKEN_REDUCTION_REGISTRY .register ('PyramidDrop' )
@@ -20,38 +21,21 @@ def __init__(self, config, model, blocks):
2021 self .register_reduction_modules ()
2122
def add_sparse_config(self):
    """Populate the shared sparse-pruning configuration for PyramidDrop.

    Reads `layer_list` and `image_token_ratio_list` from `self.special_config`,
    prepends a keep-everything ratio of 1.0 (the stage before the first pruning
    layer keeps all vision tokens), records the tokenizer padding side, and
    publishes the config dict on the wrapped language model so the forward
    hooks can read it at inference time.
    """
    cfg = self.special_config

    # Layers at which token pruning is applied.
    self.pruning_loc = cfg['layer_list']

    # Stage 0 (before the first pruning location) retains 100% of the tokens.
    ratios = cfg['image_token_ratio_list']
    ratios.insert(0, 1.0)
    cfg['image_token_ratio_list'] = ratios

    # Padding side matters when locating vision tokens in the padded batch;
    # fall back to 'right' when the model config does not declare one.
    cfg['tokenizer_padding_side'] = getattr(
        self.model.vlm_model.language_model.model.config,
        'tokenizer_padding_side',
        'right',
    )

    # Expose the (shared, mutable) config dict on the inner model object.
    self.model.model.parameters = cfg
5436
37+ def register_reduction_modules (self ):
38+ @prefill_wrapper
5539 def pruning_hook (module , args , kwargs , pruning_pars , cur_num , layer_idx ):
5640
5741 if layer_idx == self .pruning_loc [0 ]:
@@ -315,10 +299,9 @@ def pruning_hook(module, args, kwargs, pruning_pars, cur_num, layer_idx):
315299
316300 return (new_input_embeds ,), kwargs
317301
302+ @prefill_wrapper
318303 def input_hook (module , input_args , pruning_pars ):
319- # for the decoding stage
320- if input_args [0 ].shape [1 ] == 1 :
321- return input_args
304+
322305 input_ids = input_args [0 ]
323306 pre_prompt_length_list = []
324307 image_token_posi = []
@@ -338,9 +321,8 @@ def input_hook(module, input_args, pruning_pars):
338321
339322 return input_args
340323
324+ @prefill_wrapper
341325 def read_parameter_hook (module , args , kwargs , pruning_pars ):
342- if args [0 ].shape [1 ] == 1 :
343- return args , kwargs
344326 kwargs ['attention_mask' ] = pruning_pars ['attention_mask' ]
345327 # kwargs['cache_position'] = pruning_pars['cache_position']
346328 kwargs ['position_ids' ] = pruning_pars ['position_ids' ]
0 commit comments