Skip to content

Commit 10c7ac3

Browse files
authored
fix bug: fastervlm for llavahf (#397)
1 parent 56afa2d commit 10c7ac3

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

llmc/compression/token_reduction/fastervlm.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ def update_attentions_hook(m, x, outs, pruning_paras):
5555

5656
def pruning_hook(module, args, kwargs, pruning_paras):
5757

58+
# for llavahf bs 1
59+
if 'image_attentions' not in pruning_paras:
60+
pruning_paras['image_attentions'] = pruning_paras['image_attentions_list'][0]
61+
5862
image_features = args[0]
5963
image_attentions = pruning_paras['image_attentions']
6064

@@ -105,12 +109,8 @@ def prepare_inputs_for_llm_hook(module, args, kwargs, pruning_paras):
105109
keep_indexs = torch.cat([non_visual_indexs, keep_visual_indexs]).sort().values
106110

107111
new_inputs_embeds = kwargs['inputs_embeds'][:, keep_indexs, :]
108-
109-
new_attention_mask = torch.ones(
110-
new_inputs_embeds.shape[:2],
111-
dtype=kwargs['attention_mask'].dtype, device=device
112-
)
113-
new_position_ids = torch.arange(new_inputs_embeds.shape[1], device=device).unsqueeze(0)
112+
new_attention_mask = kwargs['attention_mask'][:, keep_indexs]
113+
new_position_ids = kwargs['position_ids'][:, keep_indexs]
114114
new_cache_position = kwargs['cache_position'][keep_indexs]
115115

116116
kwargs['inputs_embeds'] = new_inputs_embeds
@@ -173,11 +173,8 @@ def prepare_inputs_hook(module, inputs, outputs, pruning_paras):
173173
functools.partial(get_image_mask_hook, pruning_paras=self.pruning_paras),
174174
with_kwargs=True
175175
)
176-
177-
self.model.model.register_forward_pre_hook(
178-
functools.partial(
179-
prepare_inputs_for_llm_hook, pruning_paras=self.pruning_paras
180-
),
176+
self.model.model.model.register_forward_pre_hook(
177+
functools.partial(prepare_inputs_for_llm_hook, pruning_paras=self.pruning_paras),
181178
with_kwargs=True
182179
)
183180
elif self.model.__class__.__name__ == 'Llava':

0 commit comments

Comments
 (0)