diff --git a/configs/sparsification/methods/DART/dart.yml b/configs/sparsification/methods/DART/dart.yml
index 81c8b08c..426f256a 100644
--- a/configs/sparsification/methods/DART/dart.yml
+++ b/configs/sparsification/methods/DART/dart.yml
@@ -16,9 +16,8 @@ sparse:
     method: TokenReduction
     special:
         method: DART
-        pruning_loc: 2
+        pruning_loc: 5
         reduction_ratio: 0.778
-        max_num_trunction: 128
         pivot_image_token: 4
         pivot_text_token : 4
 save:
diff --git a/configs/sparsification/methods/FastV/fastv.yml b/configs/sparsification/methods/FastV/fastv.yml
index b8968ea7..4253eb00 100644
--- a/configs/sparsification/methods/FastV/fastv.yml
+++ b/configs/sparsification/methods/FastV/fastv.yml
@@ -17,7 +17,7 @@ sparse:
     special:
         method: FastV
         pruning_loc: 3
-        rate: 0.778
+        rate: 0.778 # prune_rate
 save:
     save_trans: False
     save_fake: False
diff --git a/configs/sparsification/methods/VisionZip/visionzip.yml b/configs/sparsification/methods/VisionZip/visionzip.yml
index ff639f4a..76d1fec8 100644
--- a/configs/sparsification/methods/VisionZip/visionzip.yml
+++ b/configs/sparsification/methods/VisionZip/visionzip.yml
@@ -16,7 +16,7 @@ sparse:
     vision:
         method: TokenReduction
     special:
-        method: VisionZip
+        method: VisionZip # retain
         dominant: 191 # visual_tokens = dominan_tokens + 1(cls_token)
         contextual: 30
 save:
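
For reference, a rough sketch of the token budget these DART settings imply, assuming a 576-token LLaVA-1.5 image grid (an assumption for illustration, not part of this diff). The formula mirrors the new TOKEN_TOPK computation in get_retained_image_token further down; the retained budget lands near the removed max_num_trunction: 128, which appears to be why that key could be dropped.

# Illustrative only: DART budget for a hypothetical 576-token image (24 x 24 patches).
image_token_length = 576
reduction_ratio = 0.778
pivot_image_token = 4
pivot_text_token = 4

retained_budget = image_token_length * (1 - reduction_ratio)                # ~127.9 tokens
token_topk = int(retained_budget / (pivot_image_token + pivot_text_token))  # 15 per pivot token
print(round(retained_budget, 1), token_topk)  # 127.9 15
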
diff --git a/llmc/compression/token_reduction/dart.py b/llmc/compression/token_reduction/dart.py
index 68c7a416..f237282d 100644
--- a/llmc/compression/token_reduction/dart.py
+++ b/llmc/compression/token_reduction/dart.py
@@ -1,7 +1,5 @@
 import functools
 import math
-from functools import wraps
-from types import MethodType
 
 import torch
 
@@ -19,95 +17,43 @@ def __init__(self, config, model, blocks):
         self.register_reduction_modules()
 
     def add_sparse_config(self):
-        self.pruning_loc = self.special_config['pruning_loc']
-        self.special_config['image_token_length'] = \
-            self.model.pruning_config['image_token_length']
-        self.special_config['IMAGE_TOKEN_INDEX'] = \
-            self.model.pruning_config['IMAGE_TOKEN_INDEX']
         self.pruning_paras = self.special_config
 
     def register_reduction_modules(self):
 
-        def input_hook_llava(fn, pruning_paras):
-            @wraps(fn)
-            def wrapper(self, *args, **kwargs):
-                if len(args) == 0:
-                    return fn(*args, **kwargs)
-                input_args = args[0]
-                if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
-                    return fn(*args, **kwargs)
-
-                input_ids = args[0]
-                attention_mask = args[2]
-                token_indices = (
-                    input_ids[0][attention_mask[0]] == pruning_paras['IMAGE_TOKEN_INDEX']
-                )
-                pruning_paras['image_token_start_index'] = torch.where(token_indices)[0][0].item()
+        @prefill_wrapper
+        def vtoken_length_hook(module, input_args, pruning_paras):
-
-                outputs = fn(*args, **kwargs)
-                return outputs
-            return wrapper
+            input_ids = input_args[0]
+            token_indices = torch.where(
+                input_ids[0] == pruning_paras['vision_token_index']
+            )[0]
+            pruning_paras['vision_token_length'] = token_indices.shape[0]
 
-        def get_seq_len_hook(module, args, kwargs, pruning_paras):
-            if kwargs['input_ids'] is not None:
-                pruning_paras['seq_len'] = kwargs['input_ids'].shape[1]
-            elif kwargs['inputs_embeds'] is not None:
-                pruning_paras['seq_len'] = kwargs['inputs_embeds'].shape[1]
-            else:
-                raise ValueError('You have to specify either input_ids or inputs_embeds')
+            return input_args
 
+        @prefill_wrapper
         def get_any_states_hook(module, args, kwargs, layer_outs, pruning_paras, layer_idx):
-            from transformers.models.llama.modeling_llama import (
-                apply_rotary_pos_emb, repeat_kv)
-            if len(kwargs['position_ids'][0]) == 1:
-                return layer_outs
-            hidden_states = kwargs['hidden_states']
-            position_embeddings = kwargs['position_embeddings']
-            position_ids = kwargs['position_ids']
-            past_key_value = layer_outs[2]
-
-            bsz, q_len, _ = hidden_states.size()
-            query_states = module.q_proj(hidden_states)
-            key_states = module.k_proj(hidden_states)
-            value_states = module.v_proj(hidden_states)
-            query_states = query_states.view(
-                bsz, q_len, module.num_heads, module.head_dim
-            ).transpose(1, 2)
-            key_states = key_states.view(
-                bsz, q_len, module.num_key_value_heads, module.head_dim
-            ).transpose(1, 2)
-            value_states = value_states.view(
-                bsz, q_len, module.num_key_value_heads, module.head_dim
-            ).transpose(1, 2)
-
-            if position_embeddings is None:
-                cos, sin = module.rotary_emb(value_states, position_ids)
-            else:
-                cos, sin = position_embeddings
-            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-            if past_key_value is not None:
-                key_states = past_key_value.key_cache[layer_idx]
-                value_states = past_key_value.value_cache[layer_idx]
-            key_states = repeat_kv(key_states, module.num_key_value_groups)
-            value_states = repeat_kv(value_states, module.num_key_value_groups)
-
-            pruning_paras['any_states'] = (query_states, key_states, value_states)
+            past_key_value = kwargs['past_key_value']
+            if past_key_value is None:
+                raise ValueError('DART needs past_key_value but got None.')
+            pruning_paras['any_states'] = past_key_value.key_cache[layer_idx]
             return layer_outs
 
         @prefill_wrapper
         def pruning_hook(module, args, kwargs, pruning_paras, normlayer):
-            image_token_start_index = pruning_paras['image_token_start_index']
-            image_token_length = pruning_paras['image_token_length']
-            any_states = pruning_paras['any_states'][-2]
-            seq_length = pruning_paras['seq_len']
+            image_token_start_index = pruning_paras['vision_token_start_index']
+            image_token_length = pruning_paras['vision_token_length']
+            any_states = pruning_paras['any_states']
 
             hidden_states = args[0]
             attention_mask = kwargs['attention_mask']
+            seq_length = hidden_states.shape[1]
             device = hidden_states.device
             last_layer_state = normlayer(hidden_states)
 
@@ -140,27 +86,20 @@ def pruning_hook(module, args, kwargs, pruning_paras, normlayer):
             kwargs['position_ids'].resize_as_(position_ids).copy_(position_ids.clone())
 
             position_embeddings = kwargs['position_embeddings']
-            new_pe0 = position_embeddings[0][:, keep_indexs, :].clone()
-            new_pe1 = position_embeddings[1][:, keep_indexs, :].clone()
+            index_dim = 1 if position_embeddings[0].dim() == 3 else 2
+            new_pe0 = position_embeddings[0].index_select(index_dim, keep_indexs).clone()
+            new_pe1 = position_embeddings[1].index_select(index_dim, keep_indexs).clone()
             position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
             position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1)
 
             return (hidden_states,), kwargs
 
-        hook_fn = input_hook_llava(
-            self.model.vlm_model.prepare_inputs_labels_for_multimodal,
-            self.pruning_paras
-        )
-        self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
-            hook_fn, self.model.vlm_model
-        )
-
-        self.model.model.model.register_forward_pre_hook(
-            functools.partial(get_seq_len_hook, pruning_paras=self.pruning_paras),
-            with_kwargs=True
-        )
+        if self.special_config['vision_token_length'] is None:
+            self.model.embed_tokens.register_forward_pre_hook(
+                functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
+            )
 
-        self.blocks[self.pruning_loc - 1].self_attn.register_forward_hook(
+        self.blocks[self.pruning_loc - 1].register_forward_hook(
             functools.partial(
                 get_any_states_hook,
                 pruning_paras=self.pruning_paras,
@@ -173,24 +112,21 @@ def pruning_hook(module, args, kwargs, pruning_paras, normlayer):
             functools.partial(
                 pruning_hook,
                 pruning_paras=self.pruning_paras,
-                normlayer=self.model.model.model.norm
+                normlayer=self.model.language_model.norm
             ),
             with_kwargs=True
         )
 
 
 def get_retained_image_token(pruning_paras, last_layer_state, any_states):
-    image_token_start_index = pruning_paras['image_token_start_index']
-    image_token_length = pruning_paras['image_token_length']
-    MAX_NUM_TRUNCTION = pruning_paras['max_num_trunction']
+    image_token_start_index = pruning_paras['vision_token_start_index']
+    image_token_length = pruning_paras['vision_token_length']
     pivot_image_token = pruning_paras['pivot_image_token']
     pivot_text_token = pruning_paras['pivot_text_token']
     reduction_ratio = pruning_paras['reduction_ratio']
-    TOKEN_TOPK = math.ceil(
-        (
-            MAX_NUM_TRUNCTION if MAX_NUM_TRUNCTION is not None
-            else (image_token_length * (1 - reduction_ratio))
-        ) // (pivot_image_token + pivot_text_token))
+    TOKEN_TOPK = int(
+        image_token_length * (1 - reduction_ratio) / (pivot_image_token + pivot_text_token)
+    )
     device = last_layer_state.device
 
     any_states = any_states.permute(0, 2, 1, 3)
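
The switch above from hard-coded `[:, keep_indexs, :]` slicing to `index_select(index_dim, ...)` presumably exists because the rotary cache layout differs across models: LLaMA-style `(cos, sin)` tensors are `(bsz, seq_len, head_dim)`, while Qwen2.5-VL's multimodal rope carries an extra leading section axis, which moves the sequence dimension to dim 2. A standalone sketch of that dimension handling (the shapes below are assumptions for illustration, not taken from this diff):

import torch

keep_indexs = torch.tensor([0, 2, 5])

cos_llama = torch.randn(1, 10, 128)     # (bsz, seq_len, head_dim), LLaMA-style rope
cos_mrope = torch.randn(3, 1, 10, 128)  # (sections, bsz, seq_len, head_dim), mrope-style

for cos in (cos_llama, cos_mrope):
    index_dim = 1 if cos.dim() == 3 else 2          # sequence axis differs between layouts
    kept = cos.index_select(index_dim, keep_indexs)  # keep only the surviving positions
    print(kept.shape)  # torch.Size([1, 3, 128]) then torch.Size([3, 1, 3, 128])
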
diff --git a/llmc/compression/token_reduction/fastv.py b/llmc/compression/token_reduction/fastv.py
index 8a699845..48080da5 100644
--- a/llmc/compression/token_reduction/fastv.py
+++ b/llmc/compression/token_reduction/fastv.py
@@ -1,6 +1,4 @@
 import functools
-from functools import wraps
-from types import MethodType
 
 import torch
 
@@ -18,46 +16,22 @@ def __init__(self, config, model, blocks):
         self.register_reduction_modules()
 
     def add_sparse_config(self):
-        self.pruning_loc = self.special_config['pruning_loc']
-        self.special_config['image_token_length'] = \
-            self.model.pruning_config['image_token_length']
-        self.special_config['IMAGE_TOKEN_INDEX'] = \
-            self.model.pruning_config['IMAGE_TOKEN_INDEX']
-        self.special_config['attn_scores'] = None
         self.pruning_paras = self.special_config
 
     def register_reduction_modules(self):
 
         @prefill_wrapper
-        def input_hook(module, input_args, pruning_paras):
+        def vtoken_length_hook(module, input_args, pruning_paras):
             input_ids = input_args[0]
-            image_token_idxs = (input_ids[0] ==
-                                pruning_paras['vision_token_index']).nonzero(as_tuple=True)[0]
-            pruning_paras['image_token_start_index'] = image_token_idxs[0].item()
-
+            token_indices = torch.where(
+                input_ids[0] == pruning_paras['vision_token_index']
+            )[0]
+            pruning_paras['vision_token_length'] = token_indices.shape[0]
             return input_args
 
-        def input_hook_llava(fn, pruning_paras):
-            @wraps(fn)
-            def wrapper(self, *args, **kwargs):
-                if len(args) == 0:
-                    return fn(*args, **kwargs)
-                input_args = args[0]
-                if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
-                    return fn(*args, **kwargs)
-
-                input_ids = args[0]
-                attention_mask = args[2]
-                token_indices = \
-                    input_ids[0][attention_mask[0]] == pruning_paras['IMAGE_TOKEN_INDEX']
-                pruning_paras['image_token_start_index'] = torch.where(token_indices)[0][0].item()
-
-                outputs = fn(*args, **kwargs)
-                return outputs
-            return wrapper
-
+        @prefill_wrapper
         def update_output_attentions_hook(module, args, kwargs, pruning_paras):
             kwargs['output_attentions'] = True
             pruning_paras['attn_scores'] = module.__class__.forward(module, *args, **kwargs)[1]
@@ -68,8 +42,8 @@ def update_output_attentions_hook(module, args, kwargs, pruning_paras):
         def fastv_pruning_hook(module, args, kwargs, pruning_paras):
             rate = pruning_paras['rate']
-            image_token_start_index = pruning_paras['image_token_start_index']
-            image_token_length = pruning_paras['image_token_length']
+            image_token_start_index = pruning_paras['vision_token_start_index']
+            image_token_length = pruning_paras['vision_token_length']
 
             hidden_states = args[0]
             causal_mask = kwargs['attention_mask']
@@ -121,24 +95,17 @@ def fastv_pruning_hook(module, args, kwargs, pruning_paras):
             kwargs['position_ids'].resize_as_(position_ids).copy_(position_ids.clone())
 
             position_embeddings = kwargs['position_embeddings']
-            new_pe0 = position_embeddings[0][:, keep_indexs, :].clone()
-            new_pe1 = position_embeddings[1][:, keep_indexs, :].clone()
+            index_dim = 1 if position_embeddings[0].dim() == 3 else 2
+            new_pe0 = position_embeddings[0].index_select(index_dim, keep_indexs).clone()
+            new_pe1 = position_embeddings[1].index_select(index_dim, keep_indexs).clone()
             position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
             position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1)
 
             return (hidden_states,), kwargs
 
-        if self.model.__class__.__name__ == 'LlavaHf':
+        if self.special_config['vision_token_length'] is None:
             self.model.embed_tokens.register_forward_pre_hook(
-                functools.partial(input_hook, pruning_paras=self.pruning_paras)
-            )
-        elif self.model.__class__.__name__ == 'Llava':
-            hook_fn = input_hook_llava(
-                self.model.vlm_model.prepare_inputs_labels_for_multimodal,
-                self.pruning_paras
-            )
-            self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
-                hook_fn, self.model.vlm_model
+                functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
             )
 
         self.blocks[self.pruning_loc - 1].register_forward_pre_hook(
diff --git a/llmc/compression/token_reduction/token_reduction_module.py b/llmc/compression/token_reduction/token_reduction_module.py
index 619cf62a..37e4483a 100644
--- a/llmc/compression/token_reduction/token_reduction_module.py
+++ b/llmc/compression/token_reduction/token_reduction_module.py
@@ -23,12 +23,15 @@ def set_sparse_config(self):
                 'video_token_length'
             ]
         else:
-            self.special_config['vision_token_index'] = self.model.pruning_config[
-                'image_token_index'
-            ]
-            self.special_config['vision_token_length'] = self.model.pruning_config[
-                'image_token_length'
-            ]
+            self.special_config['vision_token_index'] = self.model.pruning_config.get(
+                'image_token_index', None
+            )
+            self.special_config['vision_token_start_index'] = self.model.pruning_config.get(
+                'vision_token_start_index', None
+            )
+            self.special_config['vision_token_length'] = self.model.pruning_config.get(
+                'image_token_length', None
+            )
 
     def register_reduction_modules(self):
         pass
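
Because `set_sparse_config` now falls back to `None` for models whose `pruning_config` does not declare `image_token_index`/`image_token_length`, DART and FastV only register `vtoken_length_hook` when the length is unknown, and then count it from the prefill `input_ids`. A minimal, self-contained sketch of that flow (the token id and inputs below are fabricated for illustration; the real hooks are also wrapped by `prefill_wrapper`):

import torch

pruning_paras = {
    'vision_token_index': 151655,  # illustrative image-token id, e.g. a Qwen-style <|image_pad|>
    'vision_token_length': None,   # unknown statically for dynamic-resolution models
}

def vtoken_length_hook(module, input_args, pruning_paras):
    # Counts the vision tokens in the prefill input_ids, as in dart.py / fastv.py above.
    input_ids = input_args[0]
    token_indices = torch.where(input_ids[0] == pruning_paras['vision_token_index'])[0]
    pruning_paras['vision_token_length'] = token_indices.shape[0]
    return input_args

if pruning_paras['vision_token_length'] is None:
    input_ids = torch.tensor([[1, 151655, 151655, 151655, 2]])
    vtoken_length_hook(None, (input_ids,), pruning_paras)

print(pruning_paras['vision_token_length'])  # 3
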
diff --git a/llmc/compression/token_reduction/utils.py b/llmc/compression/token_reduction/utils.py
index 727f9110..3a43c8e1 100755
--- a/llmc/compression/token_reduction/utils.py
+++ b/llmc/compression/token_reduction/utils.py
@@ -63,17 +63,12 @@ def make_tome_class(transformer_class):
     class VisionZipTransformer(transformer_class):
         """
         Modifications:
-        - Initialize r, token size, and token sources.
+        - Initialize r
         """
-
-        def forward(self, *args, **kwdargs) -> torch.Tensor:
+        def forward(self, *args, **kwargs) -> torch.Tensor:
             self._info['r'] = parse_r(len(self.vision_model.encoder.layers), self.r)
             # self._info["r"] = self.r
-
-            self._info['size'] = None
-            self._info['source'] = None
-
-            return super().forward(*args, **kwdargs)
+            return super().forward(*args, **kwargs)
 
     return VisionZipTransformer
 
@@ -93,7 +88,6 @@ def apply_info(model, dominant_num, contextual_num):
     for module in model.modules():
         if isinstance(module, CLIPEncoderLayer):
             module.self_attn.k_proj._info = model._info
-            module.self_attn.k_proj.metric = None
 
 
 def add_post_hook_to_get_2dPool(model, post_hook_fn, pruning_paras):
diff --git a/llmc/compression/token_reduction/visionzip.py b/llmc/compression/token_reduction/visionzip.py
index 5109b478..76a45e81 100755
--- a/llmc/compression/token_reduction/visionzip.py
+++ b/llmc/compression/token_reduction/visionzip.py
@@ -1,4 +1,7 @@
 import functools
+import math
+from functools import wraps
+from types import MethodType
 from typing import Any, List, Optional, Tuple, Union
 
 import torch
@@ -9,7 +12,7 @@
 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
 
 from .token_reduction_module import TokenReductionModule
-from .utils import apply_info
+from .utils import apply_info, prefill_wrapper
 
 
 def visionzip_forward(
@@ -231,55 +234,48 @@ def visionzip_forward(
     )
 
 
-def CLIP_EncoderLayer_forward(
+def Qwen2_5_VLVisionAttention_forward(
     self,
     hidden_states: torch.Tensor,
-    attention_mask: torch.Tensor,
-    causal_attention_mask: torch.Tensor,
-    output_attentions: Optional[bool] = False,
-) -> Tuple[torch.FloatTensor]:
-    # docformatter: off
-    """
-    Args:
-        hidden_states (`torch.FloatTensor`): input to the layer
-            `(batch, seq_len, embed_dim)`
-        attention_mask (`torch.FloatTensor`): attention mask of size
-            `(batch, 1, tgt_len, src_len)`
-            `(config.encoder_attention_heads,)`.
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers.
-            See `attentions` under
-            returned tensors for more detail.
-    """
-    # docformatter: on
-    residual = hidden_states
-
-    hidden_states = self.layer_norm1(hidden_states)
-
-    hidden_states, attn_weights = self.self_attn(
-        hidden_states=hidden_states,
-        attention_mask=attention_mask,
-        causal_attention_mask=causal_attention_mask,
-        output_attentions=output_attentions,
+    pruning_paras,
+    cu_seqlens: torch.Tensor,
+    rotary_pos_emb: Optional[torch.Tensor] = None,
+    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+) -> torch.Tensor:
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import \
+        apply_rotary_pos_emb_vision
+    head_dim = self.qkv.in_features // self.num_heads
+    seq_length = hidden_states.shape[0]
+    q, k, v = self.qkv(hidden_states).reshape(
+        seq_length, 3, self.num_heads, -1
+    ).permute(1, 0, 2, 3).unbind(0)
+    if position_embeddings is None:
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        cos = emb.cos()
+        sin = emb.sin()
+    else:
+        cos, sin = position_embeddings
+    q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+
+    attention_mask = torch.full(
+        [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
     )
-    metric = self.self_attn.k_proj.metric
-
-    hidden_states = residual + hidden_states
-
-    r = self.self_attn.k_proj._info['r'].pop(0)
-    if r > 0:
-        self.metric = metric
-    residual = hidden_states
-    hidden_states = self.layer_norm2(hidden_states)
-    hidden_states = self.mlp(hidden_states)
-    hidden_states = residual + hidden_states
-
-    outputs = (hidden_states,)
-
-    if output_attentions:
-        outputs += (attn_weights,)
-
-    return outputs
+    for i in range(1, len(cu_seqlens)):
+        attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = 0
+
+    q = q.transpose(0, 1)
+    k = k.transpose(0, 1)
+    v = v.transpose(0, 1)
+    attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(head_dim)
+    attn_weights = attn_weights + attention_mask
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
+    attn_output = torch.matmul(attn_weights, v)
+    attn_output = attn_output.transpose(0, 1)
+    attn_output = attn_output.reshape(seq_length, -1)
+    attn_output = self.proj(attn_output)
+    pruning_paras['attn_logits'] = attn_weights
+    pruning_paras['attn_key'] = k
+    return attn_output
 
 
 @TOKEN_REDUCTION_REGISTRY.register('VisionZip')
@@ -291,8 +287,10 @@ def __init__(self, config, model, blocks):
 
     def add_sparse_config(self):
         special_config = self.config.get('special', {})
-        self.dominant = special_config.get('dominant', 192)
-        self.contextual = special_config.get('contextual', 30)
+        self.dominant = special_config['dominant']
+        self.contextual = special_config['contextual']
+
+        self.pruning_paras = special_config
 
     def register_reduction_modules(self):
 
@@ -425,27 +423,170 @@ def update_output_attentions_hook(module, args, kwargs):
         elif self.model.__class__.__name__ == 'Llava':
             vision_tower = self.model.vlm_model.model.vision_tower.vision_tower
 
-        apply_info(
-            vision_tower,
-            dominant_num=self.dominant,
-            contextual_num=self.contextual,
-        )
+        if self.model.__class__.__name__ in ('LlavaHf', 'Llava'):
+            apply_info(
+                vision_tower,
+                dominant_num=self.dominant,
+                contextual_num=self.contextual,
+            )
 
         if self.model.__class__.__name__ == 'LlavaHf':
            self.model.vlm_model.__class__.forward = visionzip_forward
-        elif self.model.__class__.__name__ == 'Llava':
-            from transformers.models.clip.modeling_clip import CLIPEncoderLayer
-            CLIPEncoderLayer.forward = CLIP_EncoderLayer_forward
 
+        if self.model.__class__.__name__ in ('LlavaHf', 'Llava'):
+
+            vision_tower.register_forward_pre_hook(
+                update_output_attentions_hook, with_kwargs=True
+            )
-        vision_tower.register_forward_pre_hook(
-            update_output_attentions_hook, with_kwargs=True
-        )
+            r = vision_tower.r
+            for idx, block in enumerate(self.blocks):
+                if r[idx]:
+                    block.self_attn.k_proj.num_heads = block.self_attn.num_heads
+                    block.self_attn.k_proj.head_dim = block.self_attn.head_dim
+                    block.self_attn.k_proj.register_forward_hook(store_key_hook)
+
+            vision_tower.register_forward_hook(visionzip_hook)
+
+        def get_metric(fn, pruning_paras):
+            @wraps(fn)
+            def wrapper(self, *args, **kwargs):
+                return fn(self, *args, pruning_paras=pruning_paras, **kwargs)
+            return wrapper
+
+        def merger_hook(module, inputs, kwargs, layer_outs, pruning_paras):
+            with torch.no_grad():
+                attn_mean = pruning_paras['attn_logits'].mean(dim=0)
+                attn_key = pruning_paras['attn_key']
+
+                window_index, _ = module.get_window_index(kwargs['grid_thw'])
+                reverse_indices = torch.argsort(window_index)
+
+                attn_mean = attn_mean.sum(dim=0)
+                attn_mean = attn_mean.view(attn_mean.shape[0] // 4, -1).mean(dim=-1)
+                attn_mean = attn_mean[reverse_indices]
+
+                attn_key = attn_key.view(
+                    attn_key.shape[0], attn_key.shape[1] // 4,
+                    4, attn_key.shape[-1]
+                ).mean(dim=2)
+                attn_key = attn_key[:, reverse_indices, :].mean(dim=0).unsqueeze(0)
+
+                pruning_paras['attn_logits'] = attn_mean
+                pruning_paras['attn_key'] = attn_key
+            return layer_outs
+
+        @prefill_wrapper
+        def get_input_ids_hook(module, input_args, pruning_paras):
+            pruning_paras['input_ids'] = input_args[0]
+            return input_args
+
+        def prune_qwenv25vl_hook(module, args, kwargs, pruning_paras):
+            if kwargs['position_ids'].shape[-1] == 1:
+                return args, kwargs
+            attn_logits = pruning_paras['attn_logits']
+            attn_key = pruning_paras['attn_key']
+            inputs_embeds = kwargs['inputs_embeds']
+            position_ids = kwargs['position_ids']
+            attention_mask = kwargs['attention_mask']
+
+            dominant_num = int(self.dominant * attn_logits.size(0))
+            contextual_num = max(int(self.contextual * attn_logits.size(0)), 1)
+            topk_values, topk_indices = torch.topk(attn_logits, dominant_num)
+
+            mask = torch.zeros_like(attn_logits, dtype=torch.bool)
+            mask[topk_indices] = True
+            contextual_mask = ~mask
+            metric_filtered = attn_key[:, contextual_mask]
+            metric_normalized = metric_filtered / metric_filtered.norm(dim=-1, keepdim=True)
+            del attn_key, metric_filtered
-        r = vision_tower.r
-        for idx, block in enumerate(self.blocks):
-            if r[idx]:
-                block.self_attn.k_proj.num_heads = block.self_attn.num_heads
-                block.self_attn.k_proj.head_dim = block.self_attn.head_dim
-                block.self_attn.k_proj.register_forward_hook(store_key_hook)
+            # Contextual Visual Tokens
+            step = max(1, metric_normalized.shape[1] // contextual_num)
+            target_indices = torch.arange(
+                0, metric_normalized.shape[1], step,
+                device=metric_normalized.device
+            )[:contextual_num]
+            target_tokens = metric_normalized[:, target_indices, :]
-        vision_tower.register_forward_hook(visionzip_hook)
+            tokens_to_merge = metric_normalized[
+                :,
+                ~torch.isin(
+                    torch.arange(
+                        metric_normalized.shape[1],
+                        device=metric_normalized.device
+                    ), target_indices
+                ),
+                :
+            ]
+            similarity = torch.bmm(tokens_to_merge, target_tokens.transpose(1, 2))
+            assign_one_hot = torch.zeros(
+                tokens_to_merge.shape[0],
+                tokens_to_merge.shape[1],
+                contextual_num,
+                dtype=attn_logits.dtype,
+                device=metric_normalized.device
+            )
+            assign_one_hot.scatter_(2, similarity.argmax(dim=2).unsqueeze(-1), 1)
+            counts = assign_one_hot.sum(dim=1).clamp(min=1).unsqueeze(-1)
+
+            select_mask = torch.zeros_like(attn_logits, dtype=torch.bool)
+            select_mask[topk_indices] = True
+
+            false_pos = (~select_mask).nonzero(as_tuple=True)[0]
+
+            select_mask[false_pos[target_indices]] = True
+
+            img_mask = (pruning_paras['input_ids'] == pruning_paras['vision_token_index'])[0]
+            st_idx = torch.nonzero(img_mask, as_tuple=True)[0]
+
+            if st_idx.numel() > 0:
+                first, last = st_idx[0].item(), st_idx[-1].item()
+                img_mask[first: last + 1] = ~select_mask
+                img_mask = ~img_mask
+                contexual_input_idx = false_pos[target_indices] + first
+
+                hidden_states_filtered = inputs_embeds[:, first: last + 1][:, contextual_mask]
+                hidden_to_merge = hidden_states_filtered[
+                    :,
+                    ~torch.isin(
+                        torch.arange(
+                            hidden_states_filtered.shape[1],
+                            device=hidden_states_filtered.device
+                        ), target_indices
+                    ),
+                    :
+                ]
+                aggregated_hidden = torch.bmm(assign_one_hot.transpose(1, 2), hidden_to_merge) / counts
+                target_hidden = hidden_states_filtered[:, target_indices, :]
+
+                contextual_tokens = target_hidden + aggregated_hidden
+
+                kwargs['position_ids'] = position_ids[:, :, img_mask]
+                kwargs['attention_mask'] = attention_mask[:, img_mask]
+                inputs_embeds[:, contexual_input_idx] = contextual_tokens
+                kwargs['inputs_embeds'] = inputs_embeds[:, img_mask]
+                del contextual_tokens, hidden_states_filtered, hidden_to_merge, aggregated_hidden
+                torch.cuda.empty_cache()
+            return args, kwargs
+
+        if self.model.__class__.__name__ == 'Qwen2_5VL':
+            self.blocks[-1].attn.forward = MethodType(
+                get_metric(Qwen2_5_VLVisionAttention_forward, self.pruning_paras),
+                self.blocks[-1].attn
+            )
+            self.model.vision_model.register_forward_hook(
+                functools.partial(
+                    merger_hook,
+                    pruning_paras=self.pruning_paras,
+                ),
+                with_kwargs=True
+            )
+            self.model.embed_tokens.register_forward_pre_hook(
+                functools.partial(get_input_ids_hook, pruning_paras=self.pruning_paras)
+            )
+            self.model.language_model.register_forward_pre_hook(
+                functools.partial(
+                    prune_qwenv25vl_hook,
+                    pruning_paras=self.pruning_paras,
+                ),
+                with_kwargs=True
+            )
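
Note that `dominant` and `contextual` are interpreted differently on the two paths above: the CLIP/LLaVA branch hands them to `apply_info` as absolute token counts (191 and 30 in visionzip.yml), whereas `prune_qwenv25vl_hook` multiplies them by the number of visual tokens, so the Qwen2.5-VL path expects fractions. A rough illustration of the resulting budget (the token count and fractions below are made-up values, not defaults from this diff):

# Assumed Qwen2.5-VL-style config values, expressed as fractions of the visual tokens.
dominant, contextual = 0.3, 0.05
num_visual_tokens = 1320  # illustrative; depends on image resolution after the 2x2 merge

dominant_num = int(dominant * num_visual_tokens)                # 396 dominant tokens kept
contextual_num = max(int(contextual * num_visual_tokens), 1)    # 66 merged contextual tokens
print(dominant_num, contextual_num)  # 396 66
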
diff --git a/llmc/models/__init__.py b/llmc/models/__init__.py
index bd62d534..48586cf2 100755
--- a/llmc/models/__init__.py
+++ b/llmc/models/__init__.py
@@ -22,6 +22,7 @@
 from .phi3 import Phi3
 from .qwen import Qwen
 from .qwen2 import Qwen2
+from .qwen2_5vl import Qwen2_5VL
 from .qwen2audio import Qwen2Audio
 from .qwen2moe import Qwen2Moe
 from .qwen2vl import Qwen2VL
diff --git a/llmc/models/llava.py b/llmc/models/llava.py
index 8b74812a..a66c4ad2 100644
--- a/llmc/models/llava.py
+++ b/llmc/models/llava.py
@@ -62,6 +62,7 @@ def build_model(self):
         self.mm_model = self.vlm_model
         logger.info(f'self.vlm_model : {self.vlm_model}')
         self.vision_model = self.vlm_model.get_vision_tower()
+        self.language_model = self.vlm_model.model
         self.vision_projector = self.vlm_model.model.mm_projector
         # Llava merges the language model with the vision projector and vision model
         self.model = self.vlm_model
@@ -71,8 +72,9 @@ def build_model(self):
             'image_token_length': self.vlm_model_config.image_seq_length,
             'select_layer': self.vlm_model_config.vision_feature_layer,
             'select_feature': self.vlm_model_config.vision_feature_select_strategy,
-            'image_token_index': self.vlm_model_config.image_token_index,
+            'image_token_index': IMAGE_TOKEN_INDEX,
             'IMAGE_TOKEN_INDEX': IMAGE_TOKEN_INDEX,  # for llava
+            'vision_token_start_index': 35,
         }
         self.processor = None
         self.first_turn_question = True
diff --git a/llmc/models/llava_hf.py b/llmc/models/llava_hf.py
index 8e428739..6a794a3b 100644
--- a/llmc/models/llava_hf.py
+++ b/llmc/models/llava_hf.py
@@ -47,6 +47,7 @@ def build_model(self):
             'select_layer': self.vlm_model_config.vision_feature_layer,
             'select_feature': self.vlm_model_config.vision_feature_select_strategy,
             'image_token_index': self.vlm_model_config.image_token_index,
+            'vision_token_start_index': 35,
         }
 
         self.processor = AutoProcessor.from_pretrained(self.model_path)
diff --git a/llmc/models/qwen2_5vl.py b/llmc/models/qwen2_5vl.py
new file mode 100755
index 00000000..16bf45e8
--- /dev/null
+++ b/llmc/models/qwen2_5vl.py
@@ -0,0 +1,236 @@
+
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+from accelerate import Accelerator, DistributedType
+from loguru import logger
+from transformers import AutoConfig, AutoProcessor, AutoTokenizer
+
+try:
+    from transformers import Qwen2_5_VLForConditionalGeneration
+except Exception:
+    logger.warning(
+        'Can not import Qwen2_5_VLForConditionalGeneration. '
+        'If you need it, please upgrade transformers.'
+    )
+
+try:
+    from qwen_vl_utils import process_vision_info
+except Exception:
+    logger.warning(
+        'Can not import qwen_vl_utils. '
+        'If you need it, please pip install qwen-vl-utils'
+    )
+
+from llmc.utils.registry_factory import MODEL_REGISTRY
+
+from .qwen2vl import Qwen2VL
+
+
+@MODEL_REGISTRY
+class Qwen2_5VL(Qwen2VL):
+    def __init__(self, config, device_map=None, use_cache=False):
+        super().__init__(config, device_map, use_cache)
+
+    def build_model(self):
+        self.eval_name = 'Qwen2_5VLEval'
+        self.vlm_model_config = AutoConfig.from_pretrained(
+            self.model_path, trust_remote_code=True
+        )
+        if not self.use_cache:
+            if hasattr(self.vlm_model_config, 'use_cache'):
+                self.vlm_model_config.use_cache = False
+        logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
+
+        self.vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            self.model_path,
+            config=self.vlm_model_config,
+            trust_remote_code=True,
+            torch_dtype=self.torch_dtype,
+            low_cpu_mem_usage=True,
+        )
+        self.mm_model = self.vlm_model
+        logger.info(f'self.vlm_model : {self.vlm_model}')
+
+        self.vision_model = self.vlm_model.visual
+        self.language_model = self.vlm_model.model
+        self.vision_projector = self.vision_model.merger
+        self.model = self.vlm_model
+
+        self.model_config = self.vlm_model_config
+
+        self.min_pixels = 256 * 28 * 28
+        self.max_pixels = 1280 * 28 * 28
+        logger.warning(f'min_pixels is set to: {self.min_pixels}')
+        logger.warning(f'max_pixels is set to: {self.max_pixels}')
+        self.processor = AutoProcessor.from_pretrained(
+            self.model_path,
+            min_pixels=self.min_pixels,
+            max_pixels=self.max_pixels
+        )
+        self.pruning_config = {
+            'is_video_model': False,
+            'image_token_index': self.vlm_model_config.image_token_id,
+            'vision_end_token_id': self.vlm_model_config.vision_end_token_id,
+            'vision_start_token_id': self.vlm_model_config.vision_start_token_id,
+            'vision_token_start_index': 15
+        }
+
+    # todo: check
+    def get_subsets_in_block(self, block):
+        if self.get_modality() == 'language':
+            return super().get_subsets_in_block(block)
+        elif self.get_modality() == 'vision':
+            return [
+                {
+                    'layers': {
+                        'attn.qkv': block.attn.qkv,
+                    },
+                    'prev_op': [block.norm1],
+                    'input': ['attn.qkv'],
+                    'inspect': block.attn,
+                    'has_kwargs': True,
+                },
+                {
+                    'layers': {'attn.proj': block.attn.proj},
+                    'prev_op': [block.attn.qkv],
+                    'input': ['attn.proj'],
+                    'inspect': block.attn.proj,
+                    'has_kwargs': False,
+                },
+                {
+                    'layers': {
+                        'mlp.gate_proj': block.mlp.gate_proj,
+                        'mlp.up_proj': block.mlp.up_proj,
+                    },
+                    'prev_op': [block.norm2],
+                    'input': ['mlp.gate_proj'],
+                    'inspect': block.mlp,
+                    'has_kwargs': False,
+                    'is_mlp': True,
+                },
+                {
+                    'layers': {'mlp.down_proj': block.mlp.down_proj},
+                    'prev_op': [block.mlp.up_proj],
+                    'input': ['mlp.down_proj'],
+                    'inspect': block.mlp.down_proj,
+                    'has_kwargs': False,
+                    'is_mlp': True,
+                },
+            ]
+        else:
+            raise Exception(f'Qwen2_5VL does not support {self.get_modality()} modality.')
+
+
+try:
+    from lmms_eval.api.model import lmms
+    from lmms_eval.models.qwen2_5_vl import Qwen2_5_VL
+
+    @MODEL_REGISTRY
+    class Qwen2_5VLEval(Qwen2_5_VL):
+        def __init__(
+            self,
+            llmc_model,
+            pretrained: str = 'Qwen/Qwen2.5-VL-3B-Instruct',
+            device: Optional[str] = 'cuda',
+            device_map: Optional[str] = 'auto',
+            batch_size: Optional[Union[int, str]] = 1,
+            use_cache=True,
+            attn_implementation: Optional[str] = None,
+            min_pixels: int = 256 * 28 * 28,
+            max_pixels: int = 1605632,
+            max_num_frames: int = 32,
+            use_custom_video_loader: Optional[bool] = False,
+            fps: Optional[float] = None,
+            max_image_size: Optional[int] = None,
+            system_prompt: Optional[str] = 'You are a helpful assistant.',
+            interleave_visuals: Optional[bool] = False,
+            reasoning_prompt: Optional[str] = None,
+            **kwargs,
+        ) -> None:
+            lmms.__init__(self)
+            # Do not use kwargs for now
+            assert kwargs == {}, f'Unexpected kwargs: {kwargs}'
+
+            # Validate attention implementation
+            valid_attn_implementations = [None, 'flash_attention_2', 'sdpa', 'eager']
+            if attn_implementation not in valid_attn_implementations:
+                raise ValueError(
+                    f'attn_implementation must be one of {valid_attn_implementations}, \
+                        got {attn_implementation}'
+                )
+
+            self.use_custom_video_loader = use_custom_video_loader
+            self.fps = fps
+            # if self.fps and not self.use_custom_video_loader:
+            #     raise ValueError("FPS is only applicable if use_custom_video_loader is True")
+            self.max_image_size = max_image_size
+            if self.max_image_size and not self.use_custom_video_loader:
+                raise ValueError(
+                    'max_image_size is only applicable if use_custom_video_loader is True'
+                )
+
+            accelerator = Accelerator()
+            if accelerator.num_processes > 1:
+                self._device = torch.device(f'cuda:{accelerator.local_process_index}')
+                self.device_map = f'cuda:{accelerator.local_process_index}'
+            else:
+                self._device = torch.device(device)
+                self.device_map = device_map if device_map else device
+
+            # Prepare model loading arguments
+            model_kwargs = {
+                'torch_dtype': 'auto',
+                'device_map': self.device_map,
+            }
+
+            # Add attention implementation if specified
+            if attn_implementation is not None:
+                model_kwargs['attn_implementation'] = attn_implementation
+
+            self._model = llmc_model.eval().cuda()
+            self.max_pixels = max_pixels
+            self.min_pixels = min_pixels
+            self.max_num_frames = max_num_frames
+
+            if reasoning_prompt:
+                self.reasoning_prompt = reasoning_prompt.replace('\\n', '\n')
+            else:
+                self.reasoning_prompt = None
+            self.processor = AutoProcessor.from_pretrained(
+                pretrained,
+                max_pixels=max_pixels,
+                min_pixels=min_pixels
+            )
+            self._tokenizer = AutoTokenizer.from_pretrained(pretrained)
+            self.system_prompt = system_prompt
+            self.interleave_visuals = interleave_visuals
+
+            self._config = self.model.config
+            self._max_length = kwargs.get('max_length', 2048)
+            self.batch_size_per_gpu = int(batch_size)
+            self.use_cache = use_cache
+
+            if accelerator.num_processes > 1:
+                assert accelerator.distributed_type in [
+                    DistributedType.FSDP,
+                    DistributedType.MULTI_GPU,
+                ], 'Unsupported distributed type provided. Only DDP and FSDP are supported.'
+                if accelerator.distributed_type == DistributedType.FSDP:
+                    self._model = accelerator.prepare(self.model)
+                else:
+                    self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
+                self.accelerator = accelerator
+                if self.accelerator.is_local_main_process:
+                    logger.info(f'Using {accelerator.num_processes} devices with data parallelism')
+                self._rank = self.accelerator.local_process_index
+                self._world_size = self.accelerator.num_processes
+            else:
+                self._rank = 0
+                self._world_size = 1
+except Exception:
+    logger.warning(
+        'Can not import lmms_eval. '
+        'If you need it, please install lmms_eval.'
+    )
diff --git a/llmc/utils/__init__.py b/llmc/utils/__init__.py
index 7f8a3841..aa5fefa1 100755
--- a/llmc/utils/__init__.py
+++ b/llmc/utils/__init__.py
@@ -4,3 +4,4 @@
 from .utils import (check_config, copy_files, deploy_all_modality,
                     get_modality, mkdirs,
                     print_important_package_version, seed_all)
+from .visualizer import visualize_kept_patches
diff --git a/llmc/utils/visualizer.py b/llmc/utils/visualizer.py
new file mode 100644
index 00000000..c2d9ebe0
--- /dev/null
+++ b/llmc/utils/visualizer.py
@@ -0,0 +1,199 @@
+import numpy as np
+import torch
+from loguru import logger
+from PIL import Image, ImageDraw
+
+try:
+    import matplotlib.pyplot as plt
+except Exception:
+    logger.warning(
+        'Can not import matplotlib. '
+        'If you need it, please install matplotlib.'
+    )
+
+
+def save_image(image_tensor, mean, std, save_path):
+    img = image_tensor.cpu().numpy().transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
+    img = img * std + mean
+    img = np.clip(img * 255, 0, 255).astype(np.uint8)
+    Image.fromarray(img).save(save_path)
+
+
+def visualize_kept_patches(
+    image, keep_idx, mean, std,
+    patch_size=14, save_path=None, darken_ratio=0.3
+):
+    assert image.ndim == 3 and image.shape[0] == 3, \
+        f'Expected image of shape [3, H, W], got {image.shape}'
+    # save_image(image,mean,std,save_path)
+
+    _, H, W = image.shape  # 3 336 336
+    device = image.device
+    num_patches_h = H // patch_size  # 24
+    num_patches_w = W // patch_size  # 24
+    total_patches = num_patches_h * num_patches_w
+
+    patch_mask = torch.zeros(total_patches, dtype=torch.bool, device=device)
+    patch_mask[keep_idx] = True
+    patch_mask = patch_mask.view(num_patches_h, num_patches_w)
+
+    mask = patch_mask.repeat_interleave(patch_size, dim=0).repeat_interleave(patch_size, dim=1)
+    mask = mask.unsqueeze(0)  # shape [1, H, W]
+
+    # Darken image
+    masked_image = image * (mask + (~mask) * darken_ratio)
+
+    save_image(masked_image, mean, std, save_path)
+
+
+def grid_show(to_shows, cols):
+    rows = (len(to_shows) - 1) // cols + 1
+    it = iter(to_shows)
+    fig, axs = plt.subplots(rows, cols, figsize=(rows * 8.5, cols * 2))
+    for i in range(rows):
+        for j in range(cols):
+            try:
+                image, title = next(it)
+            except StopIteration:
+                image = np.zeros_like(to_shows[0][0])
+                title = 'pad'
+            axs[i, j].imshow(image)
+            axs[i, j].set_title(title)
+            axs[i, j].set_yticks([])
+            axs[i, j].set_xticks([])
+    plt.show()
+
+
+# def visualize_head(att_map):
+#     ax = plt.gca()
+#     # Plot the heatmap
+#     im = ax.imshow(att_map)
+#     # Create colorbar
+#     cbar = ax.figure.colorbar(im, ax=ax)
+#     plt.show()
+
+
+def visualize_heads(att_map, cols):
+    to_shows = []
+    att_map = att_map.squeeze()
+    for i in range(att_map.shape[0]):
+        to_shows.append((att_map[i], f'Head {i}'))
+    average_att_map = att_map.mean(axis=0)
+    to_shows.append((average_att_map, 'Head Average'))
+    grid_show(to_shows, cols=cols)
+
+
+def gray2rgb(image):
+    return np.repeat(image[..., np.newaxis], 3, 2)
+
+
+def cls_padding(image, mask, cls_weight, grid_size):
+    if not isinstance(grid_size, tuple):
+        grid_size = (grid_size, grid_size)
+
+    image = np.array(image)
+
+    H, W = image.shape[:2]
+    delta_H = int(H / grid_size[0])
+    delta_W = int(W / grid_size[1])
+
+    padding_w = delta_W
+    padding_h = H
+    padding = np.ones_like(image) * 255
+    padding = padding[:padding_h, :padding_w]
+
+    padded_image = np.hstack((padding, image))
+    padded_image = Image.fromarray(padded_image)
+    draw = ImageDraw.Draw(padded_image)
+    draw.text((int(delta_W / 4), int(delta_H / 4)), 'CLS', fill=(0, 0, 0))
+
+    mask = mask / max(np.max(mask), cls_weight)
+    cls_weight = cls_weight / max(np.max(mask), cls_weight)
+
+    if len(padding.shape) == 3:
+        padding = padding[:, :, 0]
+        padding[:, :] = np.min(mask)
+    mask_to_pad = np.ones((1, 1)) * cls_weight
+    mask_to_pad = Image.fromarray(mask_to_pad)
+    mask_to_pad = mask_to_pad.resize((delta_W, delta_H))
+    mask_to_pad = np.array(mask_to_pad)
+
+    padding[:delta_H, :delta_W] = mask_to_pad
+    padded_mask = np.hstack((padding, mask))
+    padded_mask = padded_mask
+
+    meta_mask = np.zeros((padded_mask.shape[0], padded_mask.shape[1], 4))
+    meta_mask[delta_H:, 0:delta_W, :] = 1
+
+    return padded_image, padded_mask, meta_mask
+
+
+def visualize_grid_to_grid_with_cls(att_map, grid_index, image, grid_size=14, alpha=0.6):
+    if not isinstance(grid_size, tuple):
+        grid_size = (grid_size, grid_size)
+
+    attention_map = att_map[grid_index]
+    cls_weight = attention_map[0]
+
+    mask = attention_map[1:].reshape(grid_size[0], grid_size[1])
+    mask = Image.fromarray(mask).resize((image.size))
+
+    padded_image, padded_mask, meta_mask = cls_padding(image, mask, cls_weight, grid_size)
+
+    if grid_index != 0:  # adjust grid_index since we pad our image
+        grid_index = grid_index + (grid_index - 1) // grid_size[1]
+
+    grid_image = highlight_grid(padded_image, [grid_index], (grid_size[0], grid_size[1] + 1))
+
+    fig, ax = plt.subplots(1, 2, figsize=(10, 7))
+    fig.tight_layout()
+
+    ax[0].imshow(grid_image)
+    ax[0].axis('off')
+
+    ax[1].imshow(grid_image)
+    ax[1].imshow(padded_mask, alpha=alpha, cmap='rainbow')
+    ax[1].imshow(meta_mask)
+    ax[1].axis('off')
+
+
+def visualize_grid_to_grid(att_map, grid_index, image, grid_size=14, alpha=0.6):
+    if not isinstance(grid_size, tuple):
+        grid_size = (grid_size, grid_size)
+
+    H, W = att_map.shape
+    # with_cls_token = False
+
+    grid_image = highlight_grid(image, [grid_index], grid_size)
+
+    mask = att_map[grid_index].reshape(grid_size[0], grid_size[1])
+    mask = Image.fromarray(mask).resize((image.size))
+
+    fig, ax = plt.subplots(1, 2, figsize=(10, 7))
+    fig.tight_layout()
+
+    ax[0].imshow(grid_image)
+    ax[0].axis('off')
+
+    ax[1].imshow(grid_image)
+    ax[1].imshow(mask / np.max(mask), alpha=alpha, cmap='rainbow')
+    ax[1].axis('off')
+    plt.show()
+
+
+def highlight_grid(image, grid_indexes, grid_size=14):
+    if not isinstance(grid_size, tuple):
+        grid_size = (grid_size, grid_size)
+
+    W, H = image.size
+    h = H / grid_size[0]
+    w = W / grid_size[1]
+    image = image.copy()
+    for grid_index in grid_indexes:
+        x, y = np.unravel_index(grid_index, (grid_size[0], grid_size[1]))
+        a = ImageDraw.ImageDraw(image)
+        a.rectangle(
+            [(y * w, x * h), (y * w + w, x * h + h)],
+            fill=None, outline='red', width=2
+        )
+    return image
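
For completeness, a small usage sketch for the exported `visualize_kept_patches`. The normalization constants below are the standard OpenAI CLIP mean/std and the kept indices are random; both are assumptions for illustration rather than values taken from this diff.

import numpy as np
import torch

from llmc.utils import visualize_kept_patches

# Assumed CLIP preprocessing statistics; match them to your vision tower's preprocessing.
mean = np.array([0.48145466, 0.4578275, 0.40821073])
std = np.array([0.26862954, 0.26130258, 0.27577711])

image = torch.rand(3, 336, 336)            # a normalized 336x336 input, [3, H, W]
keep_idx = torch.randperm(24 * 24)[:128]   # e.g. 128 of 576 patches survive pruning
visualize_kept_patches(image, keep_idx, mean, std,
                       patch_size=14, save_path='kept_patches.png', darken_ratio=0.3)
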