diff --git a/llmc/compression/token_reduction/fastervlm.py b/llmc/compression/token_reduction/fastervlm.py
index 3a4ff166..691815be 100644
--- a/llmc/compression/token_reduction/fastervlm.py
+++ b/llmc/compression/token_reduction/fastervlm.py
@@ -62,32 +62,27 @@ def pruning_hook(module, args, kwargs, pruning_paras):
     image_features = args[0]
     image_attentions = pruning_paras['image_attentions']
 
-    # image_attentions = image_attentions.max(dim=1)[0]  # (B, N) = (1, 576)
-    image_attentions = image_attentions.mean(dim=1)  # (B, N) = (1, 576)
-
-    B, N = image_features.shape[:2]
+    B, N, C = image_features.shape
     visual_token_num = self.visual_token_num  # T
 
-    # prune visual tokens by random scores
-    # token_weights = torch.rand(B, N, device=image_features.device)  # (B, N)
-    # token_indices = torch.topk(token_weights, k=visual_token_num, dim=1)[1]  # (B, T)
-    # token_indices = torch.sort(token_indices, dim=1)[0]  # (B, T)
-    # prune visual tokens by attention scores
+    image_attentions = image_attentions.mean(dim=1)  # (B, N)
     token_indices = torch.topk(image_attentions, k=visual_token_num, dim=1)[1]  # (B, T)
-    token_indices = torch.sort(token_indices, dim=1)[0]  # (B, T)
 
     # generate index mask
-    index_mask = torch.zeros(B, N, dtype=torch.bool, device=image_features.device)  # (B, N)
-    index_mask.scatter_(1, token_indices, True)  # (B, N)
+    index_masks = torch.zeros(
+        B, N,
+        dtype=torch.bool,
+        device=image_features.device
+    )  # (B, N)
+    index_masks.scatter_(1, token_indices, True)  # (B, N)
 
-    pruning_paras['index_mask'] = index_mask
-    pruning_paras['image_attentions'] = image_attentions
+    pruning_paras['index_masks'] = index_masks
 
     return (image_features,), kwargs
 
 
 def get_image_mask_hook(module, args, kwargs, pruning_paras):
-    pruning_paras['image_mask'] = (
+    pruning_paras['image_masks'] = (
         kwargs['input_ids'] == pruning_paras['image_token_index']
     )  # (B, len)
 
@@ -95,8 +90,8 @@
 def prepare_inputs_for_llm_hook(module, args, kwargs, pruning_paras):
     # Only batch size 1 is currently supported.
     inputs_embeds = kwargs['inputs_embeds']
-    image_mask = pruning_paras['image_mask'][0]
-    index_mask = pruning_paras['index_mask'][0]
+    image_mask = pruning_paras['image_masks'][0]
+    index_mask = pruning_paras['index_masks'][0]
 
     B, L = inputs_embeds.shape[:2]
     device = inputs_embeds.device
@@ -123,7 +118,7 @@
 
 def prepare_inputs_hook(module, inputs, outputs, pruning_paras):
     image_features = outputs
-    index_masks = pruning_paras['index_mask']
+    index_masks = pruning_paras['index_masks']
     # image_attentions = pruning_paras['image_attentions']
     new_image_features = []
     for image_feature, index_mask in zip(image_features, index_masks):
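
Note for reviewers: a minimal standalone sketch of the selection logic this patch installs in pruning_hook, useful for sanity-checking shapes. The helper name build_index_masks and the concrete sizes (16 heads, T=144) are illustrative only, not part of the patch; shapes follow the diff's comments (attentions (B, H, N), features (B, N, C), N=576).

    import torch

    def build_index_masks(image_attentions, visual_token_num):
        # Average attention over heads to score each visual token: (B, H, N) -> (B, N)
        scores = image_attentions.mean(dim=1)
        # Indices of the top-T scoring tokens per image: (B, T)
        token_indices = torch.topk(scores, k=visual_token_num, dim=1)[1]
        # Boolean keep-mask built via scatter_, as in the patch: (B, N)
        index_masks = torch.zeros_like(scores, dtype=torch.bool)
        index_masks.scatter_(1, token_indices, True)
        return index_masks

    feats = torch.rand(1, 576, 1024)      # (B, N, C); 576 tokens as in the comments
    attn = torch.rand(1, 16, 576)         # (B, H, N); 16 heads is an arbitrary example
    masks = build_index_masks(attn, 144)  # keep T=144 tokens (arbitrary example)
    pruned = feats[0][masks[0]]           # (T, C), analogous to the per-image loop
    assert pruned.shape == (144, 1024)    # exactly T tokens survive per image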