@@ -55,6 +55,10 @@ def update_attentions_hook(m, x, outs, pruning_paras):
 
         def pruning_hook(module, args, kwargs, pruning_paras):
 
+            # for llavahf bs 1
+            if 'image_attentions' not in pruning_paras:
+                pruning_paras['image_attentions'] = pruning_paras['image_attentions_list'][0]
+
             image_features = args[0]
             image_attentions = pruning_paras['image_attentions']
 
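For readers skimming the first hunk, here is a minimal, self-contained sketch of the batch-size-1 fallback the added lines implement for the LlavaHf path. The tensor shape and the idea that an earlier vision-side hook fills `image_attentions_list` with one attention map per sample are assumptions for illustration, not taken from this patch.

import torch

# Toy stand-in for the per-sample attention map an earlier hook is assumed
# to have collected; the shape is purely illustrative.
pruning_paras = {'image_attentions_list': [torch.rand(1, 576)]}

# Same fallback as the added lines: with batch size 1 only the list is
# populated, so 'image_attentions' is taken from its single entry.
if 'image_attentions' not in pruning_paras:
    pruning_paras['image_attentions'] = pruning_paras['image_attentions_list'][0]

print(pruning_paras['image_attentions'].shape)  # torch.Size([1, 576])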
@@ -105,12 +109,8 @@ def prepare_inputs_for_llm_hook(module, args, kwargs, pruning_paras):
             keep_indexs = torch.cat([non_visual_indexs, keep_visual_indexs]).sort().values
 
             new_inputs_embeds = kwargs['inputs_embeds'][:, keep_indexs, :]
-
-            new_attention_mask = torch.ones(
-                new_inputs_embeds.shape[:2],
-                dtype=kwargs['attention_mask'].dtype, device=device
-            )
-            new_position_ids = torch.arange(new_inputs_embeds.shape[1], device=device).unsqueeze(0)
+            new_attention_mask = kwargs['attention_mask'][:, keep_indexs]
+            new_position_ids = kwargs['position_ids'][:, keep_indexs]
             new_cache_position = kwargs['cache_position'][keep_indexs]
 
             kwargs['inputs_embeds'] = new_inputs_embeds
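As a hedged illustration of why the second hunk slices the caller-provided tensors instead of rebuilding them: the incoming attention_mask may contain padding zeros and position_ids need not be a fresh arange, so reconstructing them with torch.ones/torch.arange would silently discard that information. All values below are invented for the example.

import torch

# Invented example: a left-padded sequence of length 8.
attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1, 1, 1]])
position_ids = torch.tensor([[0, 0, 0, 1, 2, 3, 4, 5]])
keep_indexs = torch.tensor([0, 1, 2, 3, 6, 7])  # pretend two visual tokens were pruned

# Patched behaviour: index the existing tensors, preserving padding and positions.
new_attention_mask = attention_mask[:, keep_indexs]  # tensor([[0, 0, 1, 1, 1, 1]])
new_position_ids = position_ids[:, keep_indexs]      # tensor([[0, 0, 0, 1, 4, 5]])

# Previous behaviour rebuilt both from scratch, losing the padding zeros
# and the original position values.
rebuilt_mask = torch.ones((1, keep_indexs.numel()), dtype=attention_mask.dtype)
rebuilt_position_ids = torch.arange(keep_indexs.numel()).unsqueeze(0)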
@@ -173,11 +173,8 @@ def prepare_inputs_hook(module, inputs, outputs, pruning_paras):
                 functools.partial(get_image_mask_hook, pruning_paras=self.pruning_paras),
                 with_kwargs=True
             )
-
-            self.model.model.register_forward_pre_hook(
-                functools.partial(
-                    prepare_inputs_for_llm_hook, pruning_paras=self.pruning_paras
-                ),
+            self.model.model.model.register_forward_pre_hook(
+                functools.partial(prepare_inputs_for_llm_hook, pruning_paras=self.pruning_paras),
                 with_kwargs=True
             )
         elif self.model.__class__.__name__ == 'Llava':
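The third hunk registers the pre-hook one attribute deeper (self.model.model.model); what that attribute chain resolves to depends on this codebase's model wrapper and is not spelled out in the patch. For reference, a minimal sketch of the standard PyTorch register_forward_pre_hook(..., with_kwargs=True) contract the hunk relies on; the TinyLM module and its 'scale' keyword are invented for illustration.

import functools
import torch
import torch.nn as nn

def prepare_inputs_for_llm_hook(module, args, kwargs, pruning_paras):
    # A with_kwargs pre-hook can edit kwargs and return (args, kwargs)
    # to change what the wrapped forward() receives.
    kwargs['scale'] = pruning_paras['scale']
    return args, kwargs

class TinyLM(nn.Module):
    def forward(self, x, scale=1.0):
        return x * scale

lm = TinyLM()
lm.register_forward_pre_hook(
    functools.partial(prepare_inputs_for_llm_hook, pruning_paras={'scale': 2.0}),
    with_kwargs=True,
)
print(lm(torch.ones(2)))  # tensor([2., 2.])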