19 changes: 8 additions & 11 deletions llmc/compression/token_reduction/fastervlm.py
@@ -55,6 +55,10 @@ def update_attentions_hook(m, x, outs, pruning_paras):

def pruning_hook(module, args, kwargs, pruning_paras):

# for llavahf bs 1
if 'image_attentions' not in pruning_paras:
    pruning_paras['image_attentions'] = pruning_paras['image_attentions_list'][0]
Comment on lines +59 to +60


Severity: high

This logic introduces a potential IndexError: accessing pruning_paras['image_attentions_list'][0] raises if the list is empty, which can happen with text-only inputs where no image attentions are generated. A more robust implementation should guard against the empty list, for example:

if 'image_attentions' not in pruning_paras and pruning_paras.get('image_attentions_list'):
    pruning_paras['image_attentions'] = pruning_paras['image_attentions_list'][-1]
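
An equivalent guard written as an early return (a sketch only, assuming the same pruning_paras layout as this file; the [0] index mirrors the PR's original choice) keeps the hook a no-op for text-only batches:

if not pruning_paras.get('image_attentions_list'):
    return  # no image attentions were recorded (text-only input); skip pruning
pruning_paras.setdefault('image_attentions', pruning_paras['image_attentions_list'][0])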


image_features = args[0]
image_attentions = pruning_paras['image_attentions']

@@ -105,12 +109,8 @@ def prepare_inputs_for_llm_hook(module, args, kwargs, pruning_paras):
keep_indexs = torch.cat([non_visual_indexs, keep_visual_indexs]).sort().values

new_inputs_embeds = kwargs['inputs_embeds'][:, keep_indexs, :]

new_attention_mask = torch.ones(
    new_inputs_embeds.shape[:2],
    dtype=kwargs['attention_mask'].dtype, device=device
)
new_position_ids = torch.arange(new_inputs_embeds.shape[1], device=device).unsqueeze(0)
new_attention_mask = kwargs['attention_mask'][:, keep_indexs]


Severity: medium

Consider indexing the original attention_mask and position_ids tensors with keep_indexs instead of rebuilding them, so the model still attends to the correct tokens and positions after the visual tokens have been pruned:

new_attention_mask = kwargs['attention_mask'][:, keep_indexs]
new_position_ids = kwargs['position_ids'][:, keep_indexs]
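
For illustration only (a standalone sketch, not code from the PR): indexing the original tensors preserves whatever the model already prepared, e.g. padding zeros in the mask, which a freshly built torch.ones mask would silently mark as attendable.

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])        # last two tokens are padding
keep_indexs = torch.tensor([0, 2, 4])                    # tokens kept after pruning
print(attention_mask[:, keep_indexs])                    # tensor([[1, 1, 0]]) -- padding preserved
print(torch.ones((1, 3), dtype=attention_mask.dtype))    # tensor([[1, 1, 1]]) -- padding lost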

new_position_ids = kwargs['position_ids'][:, keep_indexs]
new_cache_position = kwargs['cache_position'][keep_indexs]

kwargs['inputs_embeds'] = new_inputs_embeds
@@ -173,11 +173,8 @@ def prepare_inputs_hook(module, inputs, outputs, pruning_paras):
functools.partial(get_image_mask_hook, pruning_paras=self.pruning_paras),
with_kwargs=True
)

self.model.model.register_forward_pre_hook(
    functools.partial(
        prepare_inputs_for_llm_hook, pruning_paras=self.pruning_paras
    ),
self.model.model.model.register_forward_pre_hook(


Severity: medium

Consider simplifying the nested attribute access self.model.model.model to improve readability and maintainability. Chained attribute access is fragile and tightly couples this code to the specific internal structure of the LlavaHf wrapper and the Hugging Face Llama model.
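
One way to localize the structural assumption (a hedged sketch, not code from the PR; the getattr fallback is an assumption about the wrapper layout):

# Name the hook target once instead of repeating self.model.model.model inline.
language_model = getattr(self.model.model, 'model', self.model.model)
language_model.register_forward_pre_hook(
    functools.partial(prepare_inputs_for_llm_hook, pruning_paras=self.pruning_paras),
    with_kwargs=True
)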

    functools.partial(prepare_inputs_for_llm_hook, pruning_paras=self.pruning_paras),
    with_kwargs=True
)
elif self.model.__class__.__name__ == 'Llava':