diff --git a/llmc/compression/token_reduction/random.py b/llmc/compression/token_reduction/random.py index 24069d49..9d71084f 100644 --- a/llmc/compression/token_reduction/random.py +++ b/llmc/compression/token_reduction/random.py @@ -3,6 +3,7 @@ from types import MethodType import torch +from loguru import logger from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY @@ -62,6 +63,8 @@ def input_hook(module, input_args, pruning_paras): @prefill_wrapper def random_pruning_hook(module, args, kwargs, pruning_paras): + logger.info(' ========random_pruning_hook======== ') + rate = pruning_paras['rate'] image_token_start_index = pruning_paras['image_token_start_index'] image_token_length = pruning_paras['image_token_length'] @@ -69,15 +72,28 @@ def random_pruning_hook(module, args, kwargs, pruning_paras): hidden_states = args[0] causal_mask = kwargs['attention_mask'] + logger.info(f'before hidden_states : {hidden_states.shape}') + device = hidden_states.device - vision_indexes = torch.arange( - image_token_start_index, - image_token_start_index + image_token_length, - device=device, - ) - num_keep = round(image_token_length * (1 - rate)) - rand_idx = torch.randperm(image_token_length, device=device)[:num_keep] - vision_indexes = vision_indexes[rand_idx] + + if self.model.first_turn_question: + logger.info(' -----first_turn_question-----') + vision_indexes = torch.arange( + image_token_start_index, + image_token_start_index + image_token_length, + device=device, + ) + num_keep = round(image_token_length * (1 - rate)) + rand_idx = torch.randperm(image_token_length, device=device)[:num_keep] + vision_indexes = vision_indexes[rand_idx] + + # save vision_indexes to module + module.register_buffer('vision_indexes', vision_indexes) + else: + logger.info(' -----not first_turn_question-----') + # load vision_indexes from module (prompt cache) + vision_indexes = module.vision_indexes + # keep index keep_indexs = torch.cat( ( @@ -115,6 +131,7 @@ def random_pruning_hook(module, args, kwargs, pruning_paras): position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0) position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1) + logger.info(f'after hidden_states : {hidden_states.shape}') return (hidden_states,), kwargs if self.model.__class__.__name__ == 'LlavaHf': diff --git a/llmc/eval/__init__.py b/llmc/eval/__init__.py index 069237b7..74ec5d15 100755 --- a/llmc/eval/__init__.py +++ b/llmc/eval/__init__.py @@ -1,6 +1,7 @@ from .eval_acc import AccuracyEval from .eval_code import HumanEval from .eval_custom_generate import CustomGenerate +from .eval_custom_generate_just_infer import CustomGenerateJustInfer from .eval_ppl import DecodePerplexityEval, PerplexityEval from .eval_token_consist import TokenConsistencyEval from .eval_video_generate import VideoGenerateEval diff --git a/llmc/eval/eval_custom_generate_just_infer.py b/llmc/eval/eval_custom_generate_just_infer.py new file mode 100644 index 00000000..658de5e2 --- /dev/null +++ b/llmc/eval/eval_custom_generate_just_infer.py @@ -0,0 +1,36 @@ +import glob +import json +import os + +import torch +from human_eval.data import stream_jsonl, write_jsonl +from human_eval.evaluation import evaluate_functional_correctness +from loguru import logger +from tqdm import tqdm + +from .eval_base import BaseEval + + +class CustomGenerateJustInfer: + def __init__(self, model, config): + self.model = model + self.config = config + self.eval_cfg = config.eval + + @torch.no_grad() + def eval(self, model, eval_pos=None): + logger.info('start inference') + + with 
open(os.path.join(self.eval_cfg.path, 'samples.json'), 'r') as f: + questions_list = json.load(f) + + custom_samples_ans = self.model.eval_custom_samples_just_infer( + questions_list, + self.eval_cfg + ) + + with open(os.path.join('custom_samples_ans.json'), 'w') as f: + json.dump(custom_samples_ans, f, indent=4) + + torch.cuda.empty_cache() + return 'custom gen done.' diff --git a/llmc/eval/utils.py b/llmc/eval/utils.py index ceafae5f..9e414a57 100755 --- a/llmc/eval/utils.py +++ b/llmc/eval/utils.py @@ -3,9 +3,9 @@ from loguru import logger -from llmc.eval import (AccuracyEval, CustomGenerate, DecodePerplexityEval, - HumanEval, PerplexityEval, TokenConsistencyEval, - VideoGenerateEval, VQAEval) +from llmc.eval import (AccuracyEval, CustomGenerate, CustomGenerateJustInfer, + DecodePerplexityEval, HumanEval, PerplexityEval, + TokenConsistencyEval, VideoGenerateEval, VQAEval) from llmc.utils import deploy_all_modality @@ -57,6 +57,8 @@ def get_eval_list(model, config): eval_class = HumanEval(model, config_for_eval) elif config_tmp.eval.type == 'generate_only': eval_class = CustomGenerate(model, config_for_eval) + elif config_tmp.eval.type == 'just_infer': + eval_class = CustomGenerateJustInfer(model, config_for_eval) elif config_tmp.eval.type == 'token_acc': eval_class = TokenConsistencyEval(model, config_for_eval) elif config_tmp.eval.type == 'ppl': diff --git a/llmc/models/llava.py b/llmc/models/llava.py index 1cecfc69..c1bfbce3 100644 --- a/llmc/models/llava.py +++ b/llmc/models/llava.py @@ -1,3 +1,4 @@ +import os import types from datetime import timedelta from typing import Optional, Union @@ -9,6 +10,7 @@ from lmms_eval.models.llava import Llava as LLaVA from loguru import logger from packaging import version +from PIL import Image from transformers import AutoConfig, AutoTokenizer from llmc.utils.registry_factory import MODEL_REGISTRY @@ -17,8 +19,11 @@ try: from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, - DEFAULT_IMAGE_PATCH_TOKEN, IMAGE_TOKEN_INDEX) - from llava.mm_utils import get_model_name_from_path + DEFAULT_IMAGE_PATCH_TOKEN, + DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX) + from llava.conversation import SeparatorStyle, conv_templates + from llava.mm_utils import (get_model_name_from_path, process_images, + tokenizer_image_token) from llava.model.builder import load_pretrained_model from llava.model.language_model.llava_llama import LlavaConfig except Exception as e: @@ -45,7 +50,7 @@ def build_model(self): self.vlm_model_config.use_cache = True logger.info(f'self.vlm_model_config : {self.vlm_model_config}') - self.tokenizer, self.vlm_model, image_processor, context_len = load_pretrained_model( + self.tokenizer, self.vlm_model, self.image_processor, context_len = load_pretrained_model( self.model_path, None, get_model_name_from_path(self.model_path), @@ -137,6 +142,96 @@ def get_subsets_in_block(self, block): else: raise Exception(f'Llava do not support {self.get_modality()} modality.') + def eval_custom_samples_just_infer( + self, + img_qas, + eval_cfg + ): # noqa + + custom_samples_ans = img_qas.copy() + + self.vlm_model.cuda() + + def load_image(image_file): + image = Image.open(image_file).convert('RGB') + return image + + def load_images(image_files): + out = [] + for image_file in image_files: + image = load_image(image_file) + out.append(image) + return out + + self.first_turn_question = True + + for data_idx, questions in enumerate(img_qas): + self.first_turn_question = True + + custom_samples_ans[data_idx]['answer'] = [] + + image_files = 
questions['image'] + image_files = [os.path.join(eval_cfg.path, 'images', image_file) for image_file in image_files] # noqa + images = load_images(image_files) + image_sizes = [x.size for x in images] + images_tensor = process_images( + images, + self.image_processor, + self.vlm_model.config + ).to(self.vlm_model.device, dtype=torch.float16) + + input_ids_old = None + + for question_idx, question in enumerate(questions['question']): + + conv_mode = 'llava_v1' + conv = conv_templates[conv_mode].copy() + if question_idx > 0: + conv.system = '' + qs = question + self.first_turn_question = False + else: + qs = DEFAULT_IMAGE_TOKEN + '\n' + question + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() # noqa + # print(f"input_ids 1: {input_ids}, {input_ids.shape}") + if input_ids_old is not None: + input_ids = torch.cat((input_ids_old, input_ids), dim=1) + # print(f"input_ids 2: {input_ids}, {input_ids.shape}") + + with torch.inference_mode(): + output_ids = self.vlm_model.generate( + input_ids, + attention_mask=input_ids.new_ones(input_ids.shape, dtype=torch.bool), + images=images_tensor, + image_sizes=image_sizes, + do_sample=False, + top_p=None, + num_beams=1, + max_new_tokens=eval_cfg.max_new_tokens, + use_cache=True, + ) + + # print(f"output_ids: {output_ids}, {output_ids.shape}") + + outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True) + + print('--------------------------------') + print(f'data_idx: {data_idx}') + print(f'question_idx: {question_idx}') + print(f'question: {question}') + print(f'outputs: {outputs}') + print('--------------------------------') + + custom_samples_ans[data_idx]['answer'].append(outputs[0]) + + input_ids_old = torch.cat((input_ids, output_ids), dim=1) + + return custom_samples_ans + if version.parse(torch.__version__) >= version.parse('2.1.2'): best_fit_attn_implementation = 'sdpa'
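
Context note on the random.py change (not part of the patch): on the first question of a multi-turn sample the hook draws a random subset of vision-token positions, keeping round(image_token_length * (1 - rate)) of them, and stores the selection on the decoder layer via register_buffer; follow-up questions reuse the stored indexes so the pruned prompt stays consistent with the context already generated for that sample. A self-contained sketch of that select-once/reuse pattern follows; the module, token range and rate are illustrative, not the actual llmc wiring.

import torch
import torch.nn as nn


def select_vision_indexes(module, start, length, rate, first_turn):
    """Pick (and cache) which vision-token positions to keep."""
    if first_turn:
        num_keep = round(length * (1 - rate))
        indexes = torch.arange(start, start + length)
        indexes = indexes[torch.randperm(length)[:num_keep]]
        # Stored as a buffer; re-registering the same name simply overwrites it
        # when the next sample's first turn comes around.
        module.register_buffer('vision_indexes', indexes)
    # Later turns reuse whatever the first turn selected.
    return module.vision_indexes


layer = nn.Identity()  # stand-in for the hooked decoder layer
first = select_vision_indexes(layer, start=35, length=576, rate=0.5, first_turn=True)
later = select_vision_indexes(layer, start=35, length=576, rate=0.5, first_turn=False)
assert torch.equal(first, later)  # follow-up turns see the same pruned positions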
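
Usage note (also not part of the patch): get_eval_list builds a CustomGenerateJustInfer when eval.type is 'just_infer', and, judging from the attribute accesses in this diff, the eval config needs at least path (a directory holding samples.json and an images/ subfolder) and max_new_tokens. Each samples.json entry carries an image list, resolved against <path>/images/, and a question list asked turn by turn; the collected answers are written to custom_samples_ans.json in the working directory. Below is a minimal sketch of the file layout this code appears to expect; the directory, file names and questions are invented for illustration.

import json
import os

eval_dir = 'custom_eval'  # hypothetical; the real value comes from eval.path in the config
os.makedirs(os.path.join(eval_dir, 'images'), exist_ok=True)

samples = [
    {
        # file names are resolved against <eval.path>/images/
        'image': ['example_0.jpg'],
        # asked in order; only the first turn is prefixed with DEFAULT_IMAGE_TOKEN
        'question': [
            'Describe the image.',
            'What is the main object doing?',
        ],
    },
]

with open(os.path.join(eval_dir, 'samples.json'), 'w') as f:
    json.dump(samples, f, indent=4)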