support multi turn questions #420

Changes from all commits
```diff
@@ -3,6 +3,7 @@
 from types import MethodType

 import torch
 from loguru import logger

 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
```
```diff
@@ -62,22 +63,37 @@ def input_hook(module, input_args, pruning_paras):
 @prefill_wrapper
 def random_pruning_hook(module, args, kwargs, pruning_paras):

     logger.info(' ========random_pruning_hook======== ')

     rate = pruning_paras['rate']
     image_token_start_index = pruning_paras['image_token_start_index']
     image_token_length = pruning_paras['image_token_length']

     hidden_states = args[0]
     causal_mask = kwargs['attention_mask']

     logger.info(f'before hidden_states : {hidden_states.shape}')

     device = hidden_states.device
-    vision_indexes = torch.arange(
-        image_token_start_index,
-        image_token_start_index + image_token_length,
-        device=device,
-    )
-    num_keep = round(image_token_length * (1 - rate))
-    rand_idx = torch.randperm(image_token_length, device=device)[:num_keep]
-    vision_indexes = vision_indexes[rand_idx]
+    if self.model.first_turn_question:
+        logger.info(' -----first_turn_question-----')
+        vision_indexes = torch.arange(
+            image_token_start_index,
+            image_token_start_index + image_token_length,
+            device=device,
+        )
+        num_keep = round(image_token_length * (1 - rate))
+        rand_idx = torch.randperm(image_token_length, device=device)[:num_keep]
+        vision_indexes = vision_indexes[rand_idx]
+
+        # save vision_indexes to module
+        module.register_buffer('vision_indexes', vision_indexes)
+    else:
+        logger.info(' -----not first_turn_question-----')
+        # load vision_indexes from module (prompt cache)
+        vision_indexes = module.vision_indexes

     # keep index
     keep_indexs = torch.cat(
         (
```
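The core trick of the hunk above: on the first turn the hook randomly samples which vision-token positions survive pruning and stashes them in a buffer on the module; follow-up turns of the same conversation reuse that buffer, so the cached prompt and every later turn agree on which image tokens exist. A minimal, self-contained sketch of the same caching pattern (the helper name and the `nn.Identity` carrier are illustrative, not the PR's actual code):

```python
import torch
import torch.nn as nn


def keep_vision_indices(module, start, length, rate, device, first_turn):
    """Sample the vision-token indices to keep; cache them on the module."""
    if first_turn:
        idx = torch.arange(start, start + length, device=device)
        num_keep = round(length * (1 - rate))  # fraction of image tokens kept
        idx = idx[torch.randperm(length, device=device)[:num_keep]]
        module.register_buffer('vision_indexes', idx)  # persists across turns
    return module.vision_indexes


# Example: prune 75% of 576 image tokens that start at position 35.
layer = nn.Identity()
first = keep_vision_indices(layer, 35, 576, 0.75, 'cpu', first_turn=True)
later = keep_vision_indices(layer, 35, 576, 0.75, 'cpu', first_turn=False)
assert torch.equal(first, later)  # every turn prunes the same tokens
```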
```diff
@@ -115,6 +131,7 @@ def random_pruning_hook(module, args, kwargs, pruning_paras):
     position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
     position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1)

+    logger.info(f'after hidden_states : {hidden_states.shape}')
     return (hidden_states,), kwargs

 if self.model.__class__.__name__ == 'LlavaHf':
```
```diff
@@ -0,0 +1,36 @@
+import glob
+import json
+import os
+
+import torch
+from human_eval.data import stream_jsonl, write_jsonl
+from human_eval.evaluation import evaluate_functional_correctness
+from loguru import logger
+from tqdm import tqdm
+
+from .eval_base import BaseEval
+
+
+class CustomGenerateJustInfer:
+
+    def __init__(self, model, config):
+        self.model = model
+        self.config = config
+        self.eval_cfg = config.eval
+
+    @torch.no_grad()
+    def eval(self, model, eval_pos=None):
+
+        logger.info('start inference')
+
+        with open(os.path.join(self.eval_cfg.path, 'samples.json'), 'r') as f:
+            questions_list = json.load(f)
+
+        custom_samples_ans = self.model.eval_custom_samples_just_infer(
+            questions_list,
+            self.eval_cfg
+        )
+
+        with open(os.path.join('custom_samples_ans.json'), 'w') as f:
+            json.dump(custom_samples_ans, f, indent=4)
+
+        torch.cuda.empty_cache()
+        return 'custom gen done.'
```

Review comment on the output path: the output filename `custom_samples_ans.json` is wrapped in a single-argument `os.path.join`, so the join is a no-op and the answers land in whatever directory the process was launched from; the attached suggestion pointed the write at a configured save path instead.
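`eval` reads `samples.json` from `eval_cfg.path` and hands the list to the model's `eval_custom_samples_just_infer`, which indexes each entry's `image` and `question` lists and appends an `answer` list (see the method later in this diff). Inferred from that usage, the file plausibly looks like the following; the image names and questions are invented placeholders, not from the PR:

```python
# Hypothetical samples.json contents, inferred from how the code consumes it.
samples = [
    {
        'image': ['cat.jpg'],  # resolved against <eval_cfg.path>/images/
        'question': [
            'What animal is in the picture?',  # turn 1: <image> token prepended
            'What color is it?',               # turn 2+: token history is reused
        ],
        # 'answer' is filled in by eval_custom_samples_just_infer
    },
]
```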
```diff
@@ -1,3 +1,4 @@
+import os
 import types
 from datetime import timedelta
 from typing import Optional, Union
```

```diff
@@ -9,6 +10,7 @@
 from lmms_eval.models.llava import Llava as LLaVA
 from loguru import logger
 from packaging import version
+from PIL import Image
 from transformers import AutoConfig, AutoTokenizer

 from llmc.utils.registry_factory import MODEL_REGISTRY
```
```diff
@@ -17,8 +19,11 @@
 try:
     from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
-                                 DEFAULT_IMAGE_PATCH_TOKEN, IMAGE_TOKEN_INDEX)
-    from llava.mm_utils import get_model_name_from_path
+                                 DEFAULT_IMAGE_PATCH_TOKEN,
+                                 DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
+    from llava.conversation import SeparatorStyle, conv_templates
+    from llava.mm_utils import (get_model_name_from_path, process_images,
+                                tokenizer_image_token)
     from llava.model.builder import load_pretrained_model
     from llava.model.language_model.llava_llama import LlavaConfig
 except Exception as e:
```
```diff
@@ -45,7 +50,7 @@ def build_model(self):
     self.vlm_model_config.use_cache = True
     logger.info(f'self.vlm_model_config : {self.vlm_model_config}')

-    self.tokenizer, self.vlm_model, image_processor, context_len = load_pretrained_model(
+    self.tokenizer, self.vlm_model, self.image_processor, context_len = load_pretrained_model(
         self.model_path,
         None,
         get_model_name_from_path(self.model_path),
```
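Before this hunk, the processor returned by `load_pretrained_model` went into a local variable that was discarded, so the new eval method could not reach it; binding it to `self.image_processor` is what lets `process_images` be called at eval time. A toy illustration of the difference (stub strings stand in for the real objects):

```python
# Toy illustration: an attribute outlives build(), a dropped local does not.
class Loader:
    def build(self):
        tokenizer, model, image_processor = 'tok', 'net', 'proc'  # stand-ins
        self.image_processor = image_processor  # was a dropped local before the PR

    def preprocess(self):
        return self.image_processor  # reachable from eval-time helpers now


loader = Loader()
loader.build()
assert loader.preprocess() == 'proc'
```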
```diff
@@ -137,6 +142,96 @@ def get_subsets_in_block(self, block):
         else:
             raise Exception(f'Llava do not support {self.get_modality()} modality.')

+    def eval_custom_samples_just_infer(
+        self,
+        img_qas,
+        eval_cfg
+    ):  # noqa
+
+        custom_samples_ans = img_qas.copy()
+
+        self.vlm_model.cuda()
+
+        def load_image(image_file):
+            image = Image.open(image_file).convert('RGB')
+            return image
+
+        def load_images(image_files):
+            out = []
+            for image_file in image_files:
+                image = load_image(image_file)
+                out.append(image)
+            return out
+
+        self.first_turn_question = True
+
+        for data_idx, questions in enumerate(img_qas):
+            self.first_turn_question = True
+
+            custom_samples_ans[data_idx]['answer'] = []
+
+            image_files = questions['image']
+            image_files = [os.path.join(eval_cfg.path, 'images', image_file) for image_file in image_files]  # noqa
+            images = load_images(image_files)
+            image_sizes = [x.size for x in images]
+            images_tensor = process_images(
+                images,
+                self.image_processor,
+                self.vlm_model.config
+            ).to(self.vlm_model.device, dtype=torch.float16)
+
+            input_ids_old = None
+
+            for question_idx, question in enumerate(questions['question']):
+
+                conv_mode = 'llava_v1'
+                conv = conv_templates[conv_mode].copy()
+                if question_idx > 0:
+                    conv.system = ''
+                    qs = question
+                    self.first_turn_question = False
+                else:
+                    qs = DEFAULT_IMAGE_TOKEN + '\n' + question
+                conv.append_message(conv.roles[0], qs)
+                conv.append_message(conv.roles[1], None)
+                prompt = conv.get_prompt()
+
+                input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()  # noqa
+                # print(f"input_ids 1: {input_ids}, {input_ids.shape}")
+                if input_ids_old is not None:
+                    input_ids = torch.cat((input_ids_old, input_ids), dim=1)
+                # print(f"input_ids 2: {input_ids}, {input_ids.shape}")
+
+                with torch.inference_mode():
+                    output_ids = self.vlm_model.generate(
+                        input_ids,
+                        attention_mask=input_ids.new_ones(input_ids.shape, dtype=torch.bool),
+                        images=images_tensor,
+                        image_sizes=image_sizes,
+                        do_sample=False,
+                        top_p=None,
+                        num_beams=1,
+                        max_new_tokens=eval_cfg.max_new_tokens,
+                        use_cache=True,
+                    )
+
+                # print(f"output_ids: {output_ids}, {output_ids.shape}")
+
+                outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+
+                print('--------------------------------')
+                print(f'data_idx: {data_idx}')
+                print(f'question_idx: {question_idx}')
+                print(f'question: {question}')
+                print(f'outputs: {outputs}')
+                print('--------------------------------')
+
+                custom_samples_ans[data_idx]['answer'].append(outputs[0])
+
+                input_ids_old = torch.cat((input_ids, output_ids), dim=1)
+
+        return custom_samples_ans
+

 if version.parse(torch.__version__) >= version.parse('2.1.2'):
     best_fit_attn_implementation = 'sdpa'
```
Review comment: Consider removing this `logger.info` call, as it seems to be for debugging purposes. Leaving it in could clutter the logs.
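For reference, the multi-turn bookkeeping in `eval_custom_samples_just_infer` reduces to a simple loop: tokenize only the new user turn, prepend the accumulated history, generate, then fold the prompt plus the model's output back into the history. A stubbed, runnable sketch of that loop (the fake tokenizer and generator are placeholders, not the LLaVA APIs):

```python
import torch


def fake_tokenize(text):
    # Placeholder for tokenizer_image_token: one id per character.
    return torch.tensor([[ord(ch) % 100 for ch in text]])


def fake_generate(input_ids):
    # Placeholder for vlm_model.generate: returns three dummy answer tokens.
    return torch.tensor([[1, 2, 3]])


history = None
for question in ['What is shown?', 'What color is it?']:
    new_ids = fake_tokenize(question)
    input_ids = new_ids if history is None else torch.cat((history, new_ids), dim=1)
    output_ids = fake_generate(input_ids)
    # Prompt + answer become the context for the next turn, mirroring
    # input_ids_old = torch.cat((input_ids, output_ids), dim=1) in the PR.
    history = torch.cat((input_ids, output_ids), dim=1)
```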