@@ -12,7 +12,8 @@
 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

 from .token_reduction_module import TokenReductionModule
-from .utils import apply_info, prefill_wrapper
+from .utils import (apply_info, prefill_wrapper,
+                    prepare_inputs_labels_for_multimodal_with_index_masks)


 def visionzip_forward(
@@ -286,15 +287,19 @@ def __init__(self, config, model, blocks):
         self.register_reduction_modules()

     def add_sparse_config(self):
-        special_config = self.config.get('special', {})
-        self.dominant = special_config['dominant']
-        self.contextual = special_config['contextual']
+        self.dominant = self.special_config['dominant']
+        self.contextual = self.special_config['contextual']

-        self.pruning_paras = special_config
+        self.pruning_paras = self.special_config
+        prune_only = self.special_config.get('prune_only', False)
+        merge_only = self.special_config.get('merge_only', False)
+        assert not (prune_only and merge_only), 'prune_only and merge_only cannot both be True'
+        self.pruning_paras['prune_only'] = prune_only
+        self.pruning_paras['merge_only'] = merge_only

     def register_reduction_modules(self):

-        def visionzip_hook(m, images, image_forward_outs):
+        def visionzip_hook(m, images, image_forward_outs, pruning_paras, llava_next):
             attn_weights = image_forward_outs.attentions[-2]
             hidden_states = image_forward_outs.hidden_states[-2]
             metric = self.blocks[-2].self_attn.k_proj.metric
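The new `prune_only` / `merge_only` flags are mutually exclusive and ride along in `pruning_paras` next to the existing VisionZip budgets. A minimal sketch of a `special` section that would exercise the new code path (key names come from the diff; the values are purely illustrative):

```python
# Hypothetical 'special' config; only keys read by add_sparse_config() are shown.
special_config = {
    'dominant': 54,       # budget of dominant tokens kept via CLS attention
    'contextual': 10,     # budget of contextual tokens kept via merging
    'prune_only': True,   # keep dominant tokens only, skip the merge stage
    'merge_only': False,  # cannot be True together with prune_only
}
assert not (special_config['prune_only'] and special_config['merge_only'])
```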
@@ -306,17 +311,22 @@ def visionzip_hook(m, images, image_forward_outs):
             cls_attention = attn_weights[:, :, cls_idx, cls_idx + 1:]
             cls_attention_sum = cls_attention.sum(dim=1)
             topk_indices = cls_attention_sum.topk(dominant_num, dim=1).indices + 1
-            all_indices = torch.cat(
-                [
-                    torch.zeros(
-                        (hidden_states.shape[0], 1),
-                        dtype=topk_indices.dtype,
-                        device=topk_indices.device,
-                    ),
-                    topk_indices,
-                ],
-                dim=1,
-            )
+            if pruning_paras['merge_only']:
+                all_indices = torch.zeros(
+                    (hidden_states.shape[0], 1),
+                    dtype=topk_indices.dtype, device=topk_indices.device
+                )
+                dominant_num = 0
+            else:
+                all_indices = torch.cat(
+                    [
+                        torch.zeros(
+                            (hidden_states.shape[0], 1),
+                            dtype=topk_indices.dtype, device=topk_indices.device,
+                        ),
+                        topk_indices,
+                    ], dim=1,
+                )

             mask = torch.ones_like(
                 hidden_states[:, :, 0], dtype=torch.bool, device=metric.device
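For orientation, here is a self-contained sketch of the dominant-token selection this hook performs: sum the CLS row of the attention map over heads, take the top-k patch tokens, and prepend index 0 so the CLS token itself always survives. The shapes and the random attention tensor are invented; the `+ 1` mirrors the hook's re-offset past the skipped CLS column:

```python
import torch

bsz, heads, seq = 2, 8, 577            # illustrative: CLS + 576 patch tokens
dominant_num = 54                      # 'dominant' budget from the config
attn = torch.rand(bsz, heads, seq, seq).softmax(dim=-1)

cls_attention = attn[:, :, 0, 1:]      # CLS row, patch columns only
cls_attention_sum = cls_attention.sum(dim=1)                    # (bsz, seq - 1)
topk = cls_attention_sum.topk(dominant_num, dim=1).indices + 1  # +1: skip CLS column
cls_col = torch.zeros(bsz, 1, dtype=topk.dtype)
all_indices = torch.cat([cls_col, topk], dim=1)                 # CLS always kept
print(all_indices.shape)               # torch.Size([2, 55])
```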
@@ -355,6 +365,15 @@ def visionzip_hook(m, images, image_forward_outs):
             target_indices = torch.arange(
                 0, metric_normalized.shape[1], step, device=metric_normalized.device
             )[:contextual_num]
+
+            # keep_idxs
+            index_masks = ~mask
+            if not pruning_paras['prune_only']:
+                pruned_indices = mask.nonzero(as_tuple=False)[:, 1].view(hidden_states.shape[0], -1)
+                target_index = pruned_indices[:, target_indices]
+                index_masks.scatter_(1, target_index, True)
+            pruning_paras['index_masks'] = index_masks[:, 1:]
+
             target_tokens = metric_normalized[:, target_indices, :]

             tokens_to_merge = metric_normalized[
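The `index_masks` bookkeeping records, per original token position, whether that token survives: dominant positions are already `True` in `~mask`, `scatter_` flips the strided contextual targets back to `True`, and the leading CLS slot is dropped with `[:, 1:]`. A reduced sketch with invented sizes:

```python
import torch

bsz, seq = 1, 9                        # CLS + 8 patch tokens, invented sizes
mask = torch.ones(bsz, seq, dtype=torch.bool)
kept = torch.tensor([[0, 3, 7]])       # CLS + two "dominant" indices
mask.scatter_(1, kept, False)          # mask: True = candidate for merging

index_masks = ~mask                    # dominant (and CLS) positions start True
pruned = mask.nonzero(as_tuple=False)[:, 1].view(bsz, -1)   # merge candidates
target_indices = torch.tensor([0, 3])  # strided picks among the candidates
index_masks.scatter_(1, pruned[:, target_indices], True)    # contextual targets kept
print(index_masks[:, 1:])              # per-patch keep mask, CLS slot dropped
```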
@@ -401,9 +420,15 @@ def visionzip_hook(m, images, image_forward_outs):
             ).to(images[0].dtype)

             res = list(image_forward_outs.hidden_states)
-            res[-2] = hidden_states_save.contiguous()
+            if not llava_next:
+                if pruning_paras['prune_only']:
+                    res[-2] = dominant_tokens.contiguous().to(images[0].dtype)
+                else:
+                    res[-2] = hidden_states_save.contiguous()
             image_forward_outs.hidden_states = tuple(res)

+            return image_forward_outs
+
         def store_key_hook(m, x, outputs):
             bsz = x[0].shape[0]
             raw_outputs = (
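The newly added `return image_forward_outs` matters because a PyTorch forward hook that returns a non-None value replaces the module's output, which is how the hook hands the pruned hidden states back to the caller. A toy demonstration of that semantic:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)

def halve_output(module, inputs, output):
    # A forward hook's non-None return value replaces the module output.
    return output * 0.5

layer.register_forward_hook(halve_output)
x = torch.randn(1, 4)
assert torch.allclose(layer(x), nn.functional.linear(x, layer.weight, layer.bias) * 0.5)
```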
@@ -418,10 +443,13 @@ def update_output_attentions_hook(module, args, kwargs):
             kwargs['output_attentions'] = True
             return args, kwargs

+        def update_index_masks_hook(module, inps, outs, pruning_paras):
+            module.index_masks = pruning_paras['index_masks']
+
         if self.model.__class__.__name__ == 'LlavaHf':
             vision_tower = self.model.vlm_model.vision_tower
         elif self.model.__class__.__name__ == 'Llava':
-            vision_tower = self.model.vlm_model.model.vision_tower.vision_tower
+            vision_tower = self.model.vision_model.vision_tower

         if self.model.__class__.__name__ in ('LlavaHf', 'Llava'):
             apply_info(
@@ -444,7 +472,25 @@ def update_output_attentions_hook(module, args, kwargs):
                 block.self_attn.k_proj.head_dim = block.self_attn.head_dim
                 block.self_attn.k_proj.register_forward_hook(store_key_hook)

-        vision_tower.register_forward_hook(visionzip_hook)
+        vision_tower.register_forward_hook(
+            functools.partial(
+                visionzip_hook,
+                pruning_paras=self.pruning_paras,
+                llava_next=self.special_config['vision_token_length'] is None
+            )
+        )
+
+        # llava_next
+        if self.special_config['vision_token_length'] is None:
+
+            self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
+                prepare_inputs_labels_for_multimodal_with_index_masks,
+                self.model.vlm_model
+            )
+
+            self.model.vision_model.register_forward_hook(
+                functools.partial(update_index_masks_hook, pruning_paras=self.pruning_paras),
+            )

         def get_metric(fn, pruning_paras):
             @wraps(fn)
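Two runtime-patching idioms carry this last hunk: `functools.partial` bakes extra keyword arguments into a hook whose call signature PyTorch fixes, and `types.MethodType` rebinds a standalone function as a bound method on an existing object (as done for `prepare_inputs_labels_for_multimodal`). A minimal sketch of both; the hook and method names here are invented:

```python
import functools
from types import MethodType

import torch.nn as nn

def logging_hook(module, inputs, output, tag):
    # 'tag' is baked in via functools.partial; PyTorch supplies only the first three.
    print(tag, type(module).__name__)

def patched_describe(self):
    # Replacement method; 'self' binds to the instance passed to MethodType.
    return f'patched {type(self).__name__}'

layer = nn.Identity()
layer.register_forward_hook(functools.partial(logging_hook, tag='[visionzip]'))
layer.describe = MethodType(patched_describe, layer)
print(layer.describe())  # -> 'patched Identity'
layer(0)                 # hook prints: [visionzip] Identity
```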