llmc/compression/token_reduction/fastervlm.py (31 changes: 13 additions & 18 deletions)

@@ -62,41 +62,36 @@ def pruning_hook(module, args, kwargs, pruning_paras):
             image_features = args[0]
             image_attentions = pruning_paras['image_attentions']
 
-            # image_attentions = image_attentions.max(dim=1)[0]  # (B, N) = (1, 576)
-            image_attentions = image_attentions.mean(dim=1)  # (B, N) = (1, 576)
-
-            B, N = image_features.shape[:2]
+            B, N, C = image_features.shape
Review comment (high):

Changing image_features.shape[:2] to image_features.shape implies that image_features is now strictly expected to be a 3-dimensional tensor. If image_features could sometimes be 2-dimensional, this change would lead to a ValueError during unpacking. Please confirm that image_features will always have 3 dimensions (B, N, C) at this point in the execution.
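If 2-D inputs are possible at this point, a defensive guard along these lines (a hypothetical sketch, not part of the PR) would fail fast with an actionable message instead of an opaque unpacking error:

    # Hypothetical guard; the PR unpacks directly.
    if image_features.dim() != 3:
        raise ValueError(
            'expected image_features of shape (B, N, C), '
            f'got {tuple(image_features.shape)}'
        )
    B, N, C = image_features.shape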

             visual_token_num = self.visual_token_num  # T
 
-            # prune visual tokens by random scores
-            # token_weights = torch.rand(B, N, device=image_features.device)  # (B, N)
-            # token_indices = torch.topk(token_weights, k=visual_token_num, dim=1)[1]  # (B, T)
-            # token_indices = torch.sort(token_indices, dim=1)[0]  # (B, T)
-
             # prune visual tokens by attention scores
+            image_attentions = image_attentions.mean(dim=1)  # (B, N)
             token_indices = torch.topk(image_attentions, k=visual_token_num, dim=1)[1]  # (B, T)
-            token_indices = torch.sort(token_indices, dim=1)[0]  # (B, T)

Review comment (medium):

The removal of torch.sort(token_indices, dim=1)[0] means that token_indices will no longer be sorted. While the scatter_ operation itself does not strictly require sorted indices, if any downstream logic implicitly relied on these indices being in a specific order (e.g., for consistency or debugging), this change could introduce subtle behavioral differences. Please ensure that the order of token_indices does not matter for any subsequent operations.
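For the mask itself, the order is provably irrelevant: torch.topk returns indices ordered by descending score rather than by position, and scattering either ordering into a boolean mask yields the same result. A minimal check (illustrative values only):

    import torch

    scores = torch.tensor([[0.7, 0.1, 0.9, 0.3]])    # (B=1, N=4)
    idx = torch.topk(scores, k=2, dim=1)[1]          # tensor([[2, 0]]): ordered by score, not position
    sorted_idx = torch.sort(idx, dim=1)[0]           # tensor([[0, 2]])

    mask_a = torch.zeros(1, 4, dtype=torch.bool).scatter_(1, idx, True)
    mask_b = torch.zeros(1, 4, dtype=torch.bool).scatter_(1, sorted_idx, True)
    assert torch.equal(mask_a, mask_b)               # identical masks either way

Only a consumer reading token_indices directly would observe the changed order.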

             # generate index mask
-            index_mask = torch.zeros(B, N, dtype=torch.bool, device=image_features.device)  # (B, N)
-            index_mask.scatter_(1, token_indices, True)  # (B, N)
+            index_masks = torch.zeros(
+                B, N,
+                dtype=torch.bool,
+                device=image_features.device
+            )  # (B, N)
+            index_masks.scatter_(1, token_indices, True)  # (B, N)
 
-            pruning_paras['index_mask'] = index_mask
-            pruning_paras['image_attentions'] = image_attentions
+            pruning_paras['index_masks'] = index_masks

Review comment on lines +80 to 81 (medium):

The processed image_attentions (after mean(dim=1)) was previously stored back into pruning_paras['image_attentions']. This line has been removed. If any other part of the system or subsequent hooks rely on pruning_paras['image_attentions'] containing the mean-reduced attention, this could lead to incorrect behavior. If this processed value is not used elsewhere, consider adding a comment to clarify its scope.
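If the reduced tensor is still needed elsewhere, one option (a sketch, with a hypothetical key name) is to publish it under its own key instead of overwriting the raw attentions:

    # Hypothetical: keep raw attentions untouched and expose the
    # head-averaged scores separately for later hooks.
    pruning_paras['image_attentions_mean'] = image_attentions  # (B, N) after mean(dim=1)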

             return (image_features,), kwargs
 
         def get_image_mask_hook(module, args, kwargs, pruning_paras):
-            pruning_paras['image_mask'] = (
+            pruning_paras['image_masks'] = (
                 kwargs['input_ids'] == pruning_paras['image_token_index']
             )  # (B, len)
 
         def prepare_inputs_for_llm_hook(module, args, kwargs, pruning_paras):
 
             # Only batch size 1 is currently supported.
             inputs_embeds = kwargs['inputs_embeds']
-            image_mask = pruning_paras['image_mask'][0]
-            index_mask = pruning_paras['index_mask'][0]
+            image_mask = pruning_paras['image_masks'][0]
+            index_mask = pruning_paras['index_masks'][0]
 
             B, L = inputs_embeds.shape[:2]
             device = inputs_embeds.device
@@ -123,7 +118,7 @@ def prepare_inputs_for_llm_hook(module, args, kwargs, pruning_paras):
         def prepare_inputs_hook(module, inputs, outputs, pruning_paras):
 
             image_features = outputs
-            index_masks = pruning_paras['index_mask']
+            index_masks = pruning_paras['index_masks']
             # image_attentions = pruning_paras['image_attentions']
             new_image_features = []
             for image_feature, index_mask in zip(image_features, index_masks):
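For reference, a self-contained sketch of the pruning pattern these hooks implement, with shapes taken from the diff comments and the head count and token budget assumed for illustration:

    import torch

    B, N, C, T = 1, 576, 1024, 144                     # T = visual_token_num (assumed budget)
    image_features = torch.randn(B, N, C)              # visual tokens from the vision tower
    image_attentions = torch.rand(B, 8, N)             # per-head attention over N tokens (8 heads assumed)

    scores = image_attentions.mean(dim=1)              # (B, N): average over heads
    token_indices = torch.topk(scores, k=T, dim=1)[1]  # (B, T): keep the T highest-scoring tokens
    index_masks = torch.zeros(B, N, dtype=torch.bool).scatter_(1, token_indices, True)

    pruned = [feat[mask] for feat, mask in zip(image_features, index_masks)]
    print(pruned[0].shape)                             # torch.Size([144, 1024])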