@@ -62,41 +62,36 @@ def pruning_hook(module, args, kwargs, pruning_paras):
6262 image_features = args [0 ]
6363 image_attentions = pruning_paras ['image_attentions' ]
6464
65- # image_attentions = image_attentions.max(dim=1)[0] # (B, N) = (1, 576)
66- image_attentions = image_attentions .mean (dim = 1 ) # (B, N) = (1, 576)
67-
68- B , N = image_features .shape [:2 ]
65+ B , N , C = image_features .shape
6966 visual_token_num = self .visual_token_num # T
7067
71- # prune visual tokens by random scores
72- # token_weights = torch.rand(B, N, device=image_features.device) # (B, N)
73- # token_indices = torch.topk(token_weights, k=visual_token_num, dim=1)[1] # (B, T)
74- # token_indices = torch.sort(token_indices, dim=1)[0] # (B, T)
75-
7668 # prune visual tokens by attention scores
69+ image_attentions = image_attentions .mean (dim = 1 ) # (B, N)
7770 token_indices = torch .topk (image_attentions , k = visual_token_num , dim = 1 )[1 ] # (B, T)
78- token_indices = torch .sort (token_indices , dim = 1 )[0 ] # (B, T)
7971
8072 # generate index mask
81- index_mask = torch .zeros (B , N , dtype = torch .bool , device = image_features .device ) # (B, N)
82- index_mask .scatter_ (1 , token_indices , True ) # (B, N)
73+ index_masks = torch .zeros (
74+ B , N ,
75+ dtype = torch .bool ,
76+ device = image_features .device
77+ ) # (B, N)
78+ index_masks .scatter_ (1 , token_indices , True ) # (B, N)
8379
84- pruning_paras ['index_mask' ] = index_mask
85- pruning_paras ['image_attentions' ] = image_attentions
80+ pruning_paras ['index_masks' ] = index_masks
8681
8782 return (image_features ,), kwargs
8883
8984 def get_image_mask_hook (module , args , kwargs , pruning_paras ):
90- pruning_paras ['image_mask ' ] = (
85+ pruning_paras ['image_masks ' ] = (
9186 kwargs ['input_ids' ] == pruning_paras ['image_token_index' ]
9287 ) # (B, len)
9388
9489 def prepare_inputs_for_llm_hook (module , args , kwargs , pruning_paras ):
9590
9691 # Only batch size 1 is currently supported.
9792 inputs_embeds = kwargs ['inputs_embeds' ]
98- image_mask = pruning_paras ['image_mask ' ][0 ]
99- index_mask = pruning_paras ['index_mask ' ][0 ]
93+ image_mask = pruning_paras ['image_masks ' ][0 ]
94+ index_mask = pruning_paras ['index_masks ' ][0 ]
10095
10196 B , L = inputs_embeds .shape [:2 ]
10297 device = inputs_embeds .device
@@ -123,7 +118,7 @@ def prepare_inputs_for_llm_hook(module, args, kwargs, pruning_paras):
123118 def prepare_inputs_hook (module , inputs , outputs , pruning_paras ):
124119
125120 image_features = outputs
126- index_masks = pruning_paras ['index_mask ' ]
121+ index_masks = pruning_paras ['index_masks ' ]
127122 # image_attentions = pruning_paras['image_attentions']
128123 new_image_features = []
129124 for image_feature , index_mask in zip (image_features , index_masks ):
0 commit comments