diff --git a/llmc/compression/token_reduction/fastv.py b/llmc/compression/token_reduction/fastv.py
index 2fff71c8..b1c56d29 100644
--- a/llmc/compression/token_reduction/fastv.py
+++ b/llmc/compression/token_reduction/fastv.py
@@ -48,25 +48,24 @@ def hook_prepare_inputs_labels_for_multimodal(
                 past_key_values,
                 labels,
                 images,
-                image_sizes
+                modalities=['image'],
+                image_sizes=None,
             ):
                 if 'image_token_start_index' not in pruning_paras:
                     token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX
                     pruning_paras['image_token_start_index'] = torch.where(token_indices)[0].item()
 
                 return self._original_prepare_inputs_labels_for_multimodal(
                     input_ids, position_ids, attention_mask,
-                    past_key_values, labels, images, image_sizes
+                    past_key_values, labels, images, modalities, image_sizes
                 )
 
             return hook_prepare_inputs_labels_for_multimodal
 
-        def update_output_attentions_hook(module, args, kwargs):
+        def update_output_attentions_hook(module, args, kwargs, pruning_paras):
             kwargs['output_attentions'] = True
+            pruning_paras['attn_scores'] = module.__class__.forward(module, *args, **kwargs)[1]
+            kwargs['output_attentions'] = False
             return args, kwargs
 
-        def store_attention_hook(m, x, layer_outputs, pruning_paras):
-            layer_attention = layer_outputs[1]
-            pruning_paras['attn_scores'] = layer_attention
-
         @prefill_wrapper
         def fastv_pruning_hook(module, args, kwargs, pruning_paras):
@@ -76,7 +75,6 @@ def fastv_pruning_hook(module, args, kwargs, pruning_paras):
 
             hidden_states = args[0]
             causal_mask = kwargs['attention_mask']
-            cache_position = kwargs['cache_position']
             device = hidden_states.device
 
             # last_layer_attention = layer_outputs[1]
@@ -106,37 +104,26 @@ def fastv_pruning_hook(module, args, kwargs, pruning_paras):
             # sort index
             keep_indexs = keep_indexs.sort().values
-            # update seq length
-            new_seq_length = keep_indexs.shape[0]
             # filter hidden states &
             hidden_states = hidden_states[:, keep_indexs, :]
             # update position ids
             position_ids = keep_indexs.unsqueeze(0)
             # update attention mask
-            causal_mask = _update_causal_mask(
-                causal_mask, None, hidden_states, 0
-            ) if causal_mask is not None else None
-            kwargs['attention_mask'] = causal_mask
-            kwargs['cache_position'] = cache_position[:new_seq_length]
-            kwargs['position_ids'] = position_ids
-            kwargs['position_embeddings'] = None
-            pruning_paras['attention_mask'] = causal_mask
-            pruning_paras['cache_position'] = cache_position[:new_seq_length]
-            pruning_paras['position_ids'] = position_ids
-            pruning_paras['position_embeddings'] = None
+            if causal_mask is not None:
+                causal_mask = causal_mask[:, :, :hidden_states.shape[1], :hidden_states.shape[1]]
+                kwargs['attention_mask'].resize_as_(causal_mask).copy_(causal_mask.clone())
+            kwargs['cache_position'].resize_as_(position_ids.squeeze(0)).copy_(
+                position_ids.squeeze(0).clone())
+            kwargs['position_ids'].resize_as_(position_ids).copy_(position_ids.clone())
+
+            position_embeddings = kwargs['position_embeddings']
+            new_pe0 = position_embeddings[0][:, keep_indexs, :].clone()
+            new_pe1 = position_embeddings[1][:, keep_indexs, :].clone()
+            position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
+            position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1)
 
             return (hidden_states,), kwargs
 
-        @prefill_wrapper
-        def read_parameter_hook(module, args, kwargs, pruning_paras):
-            kwargs['attention_mask'] = pruning_paras['attention_mask']
-            kwargs['cache_position'] = pruning_paras['cache_position']
-            kwargs['position_ids'] = pruning_paras['position_ids']
-            kwargs['position_embeddings'] = pruning_paras['position_embeddings']
-
-            return args, kwargs
-
         if self.model.__class__.__name__ == 'LlavaHf':
             self.model.embed_tokens.register_forward_pre_hook(
                 functools.partial(input_hook, pruning_paras=self.pruning_paras)
@@ -151,21 +138,11 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):
             )
 
         self.blocks[self.pruning_loc - 1].register_forward_pre_hook(
-            update_output_attentions_hook,
+            functools.partial(update_output_attentions_hook, pruning_paras=self.pruning_paras),
            with_kwargs=True
         )
-        self.blocks[self.pruning_loc - 1].register_forward_hook(
-            functools.partial(store_attention_hook, pruning_paras=self.pruning_paras),
-        )
-
         self.blocks[self.pruning_loc].register_forward_pre_hook(
             functools.partial(fastv_pruning_hook, pruning_paras=self.pruning_paras),
             with_kwargs=True
         )
-
-        for idx in range(self.pruning_loc + 1, len(self.blocks)):
-            self.blocks[idx].register_forward_pre_hook(
-                functools.partial(read_parameter_hook, pruning_paras=self.pruning_paras),
-                with_kwargs=True
-            )
diff --git a/llmc/models/llava.py b/llmc/models/llava.py
index 818fd313..73a5a8ff 100644
--- a/llmc/models/llava.py
+++ b/llmc/models/llava.py
@@ -44,46 +44,13 @@ def build_model(self):
         self.llava_config.use_cache = True
         self.vlm_model_config.use_cache = True
         logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
+
         self.tokenizer, self.vlm_model, image_processor, context_len = load_pretrained_model(
             self.model_path,
             None,
             get_model_name_from_path(self.model_path),
-            load_8bit=False,
-            load_4bit=False,
-            device='cpu',
-            torch_dtype=self.torch_dtype,
-            config=self.llava_config,
-        )
-
-        # llava forward not support "cache_position"
-        ori_forward = self.vlm_model.forward
-
-        def safe_forward(*args, **kwargs):
-            kwargs.pop('cache_position', None)
-            return ori_forward(*args, **kwargs)
-        self.vlm_model.forward = safe_forward
-
-        # llava generate use "inputs" instead of "input_ids"
-        ori_generate = self.vlm_model.generate
-
-        def safe_generate(*args, **kwargs):
-            if 'input_ids' in kwargs:
-                kwargs['inputs'] = kwargs.pop('input_ids')
-            return ori_generate(*args, **kwargs)
-        self.vlm_model.generate = safe_generate
-
-        # "attention_mask" is passed via kwargs rather than as an explicit keyword argument.
-        ori_prepare_inputs_for_generation = self.vlm_model.prepare_inputs_for_generation
-
-        def safe_prepare_inputs_for_generation(
-                self, input_ids, past_key_values=None,
-                inputs_embeds=None, attention_mask=None, **kwargs):
-            if attention_mask is not None:
-                kwargs['attention_mask'] = attention_mask
-            return ori_prepare_inputs_for_generation(
-                input_ids, past_key_values, inputs_embeds, **kwargs)
-        self.vlm_model.prepare_inputs_for_generation = types.MethodType(
-            safe_prepare_inputs_for_generation, self.vlm_model
+            device_map='cpu',
+            attn_implementation='sdpa'
         )
 
         self.eval_name = 'LlavaEval'
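
Note on the hook rewiring in fastv.py: the patch folds the previous pre-hook/post-hook pair (update_output_attentions_hook plus store_attention_hook) into a single forward pre-hook that calls the layer's own forward once with output_attentions=True, stashes the attention map into the shared pruning_paras dict, and then restores the flag before the framework's normal forward runs. The snippet below is a minimal, standalone sketch of that capture pattern, assuming only PyTorch (>= 2.0 for with_kwargs); ToyBlock, capture_attention_pre_hook and shared_paras are illustrative names, not identifiers from the repository.

import functools

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Stand-in for a decoder layer that can optionally return attentions."""

    def __init__(self, dim=8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, output_attentions=False):
        # Toy self-attention: scores over the sequence, then a projection.
        attn = torch.softmax(hidden_states @ hidden_states.transpose(-1, -2), dim=-1)
        out = self.proj(attn @ hidden_states)
        return (out, attn) if output_attentions else (out,)


def capture_attention_pre_hook(module, args, kwargs, shared_paras):
    # Run the layer once with attentions enabled, record the scores, then
    # restore the flag so the framework's own forward call is unchanged.
    kwargs['output_attentions'] = True
    shared_paras['attn_scores'] = module.__class__.forward(module, *args, **kwargs)[1]
    kwargs['output_attentions'] = False
    return args, kwargs


shared_paras = {}
block = ToyBlock()
block.register_forward_pre_hook(
    functools.partial(capture_attention_pre_hook, shared_paras=shared_paras),
    with_kwargs=True,
)

hidden = torch.randn(1, 4, 8)
_ = block(hidden, output_attentions=False)
print(shared_paras['attn_scores'].shape)  # torch.Size([1, 4, 4])

As in the patched update_output_attentions_hook, the trade-off is that the hooked layer runs twice during prefill (once inside the hook, once normally) in exchange for dropping the separate post-hook and the per-block read_parameter_hook registrations.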