diff --git a/llmc/compression/token_reduction/fastervlm.py b/llmc/compression/token_reduction/fastervlm.py
index 635595b5..3a4ff166 100644
--- a/llmc/compression/token_reduction/fastervlm.py
+++ b/llmc/compression/token_reduction/fastervlm.py
@@ -55,6 +55,10 @@ def update_attentions_hook(m, x, outs, pruning_paras):
 
 
 def pruning_hook(module, args, kwargs, pruning_paras):
+    # for llavahf bs 1
+    if 'image_attentions' not in pruning_paras:
+        pruning_paras['image_attentions'] = pruning_paras['image_attentions_list'][0]
+
     image_features = args[0]
     image_attentions = pruning_paras['image_attentions']
 
@@ -105,12 +109,8 @@ def prepare_inputs_for_llm_hook(module, args, kwargs, pruning_paras):
     keep_indexs = torch.cat([non_visual_indexs, keep_visual_indexs]).sort().values
 
     new_inputs_embeds = kwargs['inputs_embeds'][:, keep_indexs, :]
-
-    new_attention_mask = torch.ones(
-        new_inputs_embeds.shape[:2],
-        dtype=kwargs['attention_mask'].dtype, device=device
-    )
-    new_position_ids = torch.arange(new_inputs_embeds.shape[1], device=device).unsqueeze(0)
+    new_attention_mask = kwargs['attention_mask'][:, keep_indexs]
+    new_position_ids = kwargs['position_ids'][:, keep_indexs]
     new_cache_position = kwargs['cache_position'][keep_indexs]
 
     kwargs['inputs_embeds'] = new_inputs_embeds
@@ -173,11 +173,8 @@ def prepare_inputs_hook(module, inputs, outputs, pruning_paras):
                 functools.partial(get_image_mask_hook, pruning_paras=self.pruning_paras),
                 with_kwargs=True
             )
-
-            self.model.model.register_forward_pre_hook(
-                functools.partial(
-                    prepare_inputs_for_llm_hook, pruning_paras=self.pruning_paras
-                ),
+            self.model.model.model.register_forward_pre_hook(
+                functools.partial(prepare_inputs_for_llm_hook, pruning_paras=self.pruning_paras),
                 with_kwargs=True
             )
         elif self.model.__class__.__name__ == 'Llava':