Skip to content

Commit 108bf01

Browse files
authored
update fastervlm (#398)
1 parent 10c7ac3 commit 108bf01

File tree

1 file changed

+13
-18
lines changed

1 file changed

+13
-18
lines changed

llmc/compression/token_reduction/fastervlm.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -62,41 +62,36 @@ def pruning_hook(module, args, kwargs, pruning_paras):
6262
image_features = args[0]
6363
image_attentions = pruning_paras['image_attentions']
6464

65-
# image_attentions = image_attentions.max(dim=1)[0] # (B, N) = (1, 576)
66-
image_attentions = image_attentions.mean(dim=1) # (B, N) = (1, 576)
67-
68-
B, N = image_features.shape[:2]
65+
B, N, C = image_features.shape
6966
visual_token_num = self.visual_token_num # T
7067

71-
# prune visual tokens by random scores
72-
# token_weights = torch.rand(B, N, device=image_features.device) # (B, N)
73-
# token_indices = torch.topk(token_weights, k=visual_token_num, dim=1)[1] # (B, T)
74-
# token_indices = torch.sort(token_indices, dim=1)[0] # (B, T)
75-
7668
# prune visual tokens by attention scores
69+
image_attentions = image_attentions.mean(dim=1) # (B, N)
7770
token_indices = torch.topk(image_attentions, k=visual_token_num, dim=1)[1] # (B, T)
78-
token_indices = torch.sort(token_indices, dim=1)[0] # (B, T)
7971

8072
# generate index mask
81-
index_mask = torch.zeros(B, N, dtype=torch.bool, device=image_features.device) # (B, N)
82-
index_mask.scatter_(1, token_indices, True) # (B, N)
73+
index_masks = torch.zeros(
74+
B, N,
75+
dtype=torch.bool,
76+
device=image_features.device
77+
) # (B, N)
78+
index_masks.scatter_(1, token_indices, True) # (B, N)
8379

84-
pruning_paras['index_mask'] = index_mask
85-
pruning_paras['image_attentions'] = image_attentions
80+
pruning_paras['index_masks'] = index_masks
8681

8782
return (image_features,), kwargs
8883

8984
def get_image_mask_hook(module, args, kwargs, pruning_paras):
90-
pruning_paras['image_mask'] = (
85+
pruning_paras['image_masks'] = (
9186
kwargs['input_ids'] == pruning_paras['image_token_index']
9287
) # (B, len)
9388

9489
def prepare_inputs_for_llm_hook(module, args, kwargs, pruning_paras):
9590

9691
# Only batch size 1 is currently supported.
9792
inputs_embeds = kwargs['inputs_embeds']
98-
image_mask = pruning_paras['image_mask'][0]
99-
index_mask = pruning_paras['index_mask'][0]
93+
image_mask = pruning_paras['image_masks'][0]
94+
index_mask = pruning_paras['index_masks'][0]
10095

10196
B, L = inputs_embeds.shape[:2]
10297
device = inputs_embeds.device
@@ -123,7 +118,7 @@ def prepare_inputs_for_llm_hook(module, args, kwargs, pruning_paras):
123118
def prepare_inputs_hook(module, inputs, outputs, pruning_paras):
124119

125120
image_features = outputs
126-
index_masks = pruning_paras['index_mask']
121+
index_masks = pruning_paras['index_masks']
127122
# image_attentions = pruning_paras['image_attentions']
128123
new_image_features = []
129124
for image_feature, index_mask in zip(image_features, index_masks):

0 commit comments

Comments
 (0)