1- import functools
21from functools import wraps
32from types import MethodType
43
76from llmc .utils .registry_factory import TOKEN_REDUCTION_REGISTRY
87
98from .token_reduction_module import TokenReductionModule
10- from .utils import prefill_wrapper
119
1210
1311def pairwise_cosine_similarity (matrix ):
@@ -22,7 +20,7 @@ def divprune(
2220 cosine_matrix = None ,
2321 threshold_ratio = 0.1 ,
2422):
25- threshold_terms = int ( round (threshold_ratio * image_feature_length ) )
23+ threshold_terms = round (threshold_ratio * image_feature_length )
2624 if cosine_matrix is None :
2725 cosine_matrix = 1.0 - (pairwise_cosine_similarity (visual_feature_vectors ))
2826
@@ -53,22 +51,16 @@ def divprune(
5351 return s , cosine_matrix
5452
5553
56- def divprune_post_hook (
57- input_ids ,
58- position_ids ,
59- attention_mask ,
60- past_key_values ,
61- inputs_embeds ,
62- labels ,
63- pruning_paras = None ,
64- ):
65- rate = pruning_paras ['rate' ]
66- SYS_TOKEN_LEN = pruning_paras ['image_token_start_index' ]
67- img_feature_len = pruning_paras ['image_token_length' ]
54+ def divprune_post_hook (* args , pruning_paras = None ):
55+ args = list (args )
56+ position_ids , attention_mask , inputs_embeds = args [1 ], args [2 ], args [4 ]
57+ rate = pruning_paras ['reduction_ratio' ]
58+ SYS_TOKEN_LEN = pruning_paras ['vision_token_start_index' ]
59+ img_feature_len = pruning_paras ['vision_token_length' ]
6860 device = inputs_embeds .device
6961 visual_tokens = inputs_embeds [0 ][SYS_TOKEN_LEN : SYS_TOKEN_LEN + img_feature_len ]
7062 selected_visual_tokens , cosine_matrix = divprune (
71- visual_tokens , img_feature_len , None , threshold_ratio = rate
63+ visual_tokens , img_feature_len , None , threshold_ratio = 1 - rate
7264 )
7365
7466 selected_visual_tokens += SYS_TOKEN_LEN
@@ -83,20 +75,13 @@ def divprune_post_hook(
8375 )
8476 keep_indexs = keep_indexs .sort ().values
8577
86- inputs_embeds = inputs_embeds [:, keep_indexs ]
8778 if position_ids is not None :
88- position_ids = position_ids [:, keep_indexs , :]
79+ args [ 1 ] = position_ids [:, keep_indexs , :]
8980 if attention_mask is not None :
90- attention_mask = attention_mask [:, keep_indexs ]
91-
92- return (
93- input_ids ,
94- position_ids ,
95- attention_mask ,
96- past_key_values ,
97- inputs_embeds ,
98- labels ,
99- )
81+ args [2 ] = attention_mask [:, keep_indexs ]
82+ args [4 ] = inputs_embeds [:, keep_indexs ]
83+
84+ return tuple (args )
10085
10186
10287@TOKEN_REDUCTION_REGISTRY .register ('DivPrune' )
@@ -107,43 +92,34 @@ def __init__(self, config, model, blocks):
10792 self .register_reduction_modules ()
10893
10994 def add_sparse_config (self ):
110- self .special_config ['image_token_length' ] = self .model .pruning_config [
111- 'image_token_length'
112- ]
113-
11495 self .pruning_paras = self .special_config
11596
11697 def register_reduction_modules (self ):
11798
118- def input_hook_llava (fn , pruning_paras ):
99+ def input_hook_llava (fn , pruning_paras , llava_next ):
119100 @wraps (fn )
120101 def wrapper (self , * args , ** kwargs ):
121- if len (args ) == 0 :
122- return fn (* args , ** kwargs )
123- input_args = args [0 ]
124- if hasattr (input_args [0 ], 'shape' ) and input_args [0 ].shape [0 ] == 1 :
102+ if args [0 ].shape [1 ] == 1 :
125103 return fn (* args , ** kwargs )
126-
127- input_ids = args [0 ]
128- attention_mask = args [2 ]
129- token_indices = input_ids [0 ][attention_mask [0 ]] == IMAGE_TOKEN_INDEX
130- pruning_paras ['image_token_start_index' ] = torch .where (token_indices )[
131- 0
132- ][0 ].item ()
133-
134- outputs = fn (* args , ** kwargs )
135-
136- return divprune_post_hook (* outputs , pruning_paras = pruning_paras )
137-
104+ outs = fn (* args , ** kwargs )
105+
106+ if llava_next :
107+ message = (
108+ 'To obtain the vision_token_length for LLaVA-1.6, you should append '
109+ '`image_features[0].shape[0]` to the return value of the function '
110+ '`prepare_inputs_labels_for_multimodal`, and modify the related code.'
111+ )
112+ assert len (outs ) == 7 , message
113+ pruning_paras ['vision_token_length' ] = outs [- 1 ]
114+ return divprune_post_hook (* outs , pruning_paras = pruning_paras )
138115 return wrapper
139116
140117 if self .model .__class__ .__name__ == 'Llava' :
141- from llava .constants import IMAGE_TOKEN_INDEX
142118
143- hook_fn = input_hook_llava (
144- self .model .vlm_model .prepare_inputs_labels_for_multimodal ,
145- self .pruning_paras ,
146- )
147119 self .model .vlm_model .prepare_inputs_labels_for_multimodal = MethodType (
148- hook_fn , self .model .vlm_model
120+ input_hook_llava (
121+ self .model .vlm_model .prepare_inputs_labels_for_multimodal ,
122+ self .pruning_paras ,
123+ llava_next = self .special_config ['vision_token_length' ] is None
124+ ), self .model .vlm_model
149125 )
0 commit comments