diff --git a/MMLongBench-Doc/README.md b/MMLongBench-Doc/README.md new file mode 100644 index 000000000..46e9b4be7 --- /dev/null +++ b/MMLongBench-Doc/README.md @@ -0,0 +1,113 @@ +

+

MMLONGBENCH-DOC: Benchmarking Long-context Document Understanding with Visualizations

+

+ Yubo Ma + · + Yuhang Zang + · + Liangyu Chen + · + Meiqi Chen + · + Yizhu Jiao + · + Xinze Li + · + Xinyuan Lu + · + Ziyu Liu + · + Yan Ma + · + Xiaoyi Dong + · + Pan Zhang + · + Liangming Pan + · + Yu-Gang Jiang + · + Jiaqi Wang + · + Yixin Cao + · + Aixin Sun +

+ + 📖 Paper | 🏠 Homepage | 🤗 Hugging Face +
+

+

+The automatic understanding of lengthy documents (Long-context Document Understanding; DU) is a long-standing task with urgent, practical demand. Although many LVLMs now claim long-context DU capabilities (and show promising cases), a unified, quantitative evaluation of existing models has been lacking due to the absence of a suitable benchmark.
+To bridge this gap, we construct MMLongBench-Doc, which comprises 135 documents and 1,091 questions (each accompanied by a short, deterministic reference answer and detailed meta information). The documents have an average of 47.5 pages and 21,214 tokens, cover 7 diverse domains, and are PDF-formatted with rich layouts and multi-modal components. The questions are either curated from existing datasets or newly annotated by expert-level annotators. Towards a comprehensive evaluation, the questions cover different evidence sources (text, table, chart, image, etc.) and different locations (page indices) within the documents. Notably, 33.0% of the questions are cross-page questions necessitating comprehension and reasoning over evidence across multiple pages, and 22.5% of the questions are designed to be unanswerable, reducing shortcuts in the benchmark and probing LVLMs' hallucinations. +

+ + Logo + +
+
+## 📢 News
+- 🚀 [07/2024] We have further refined and updated the questions in MMLongBench-Doc!
+- 🚀 [07/2024] We have integrated MMLongBench-Doc into the evaluation toolkit [VLMEvalKit](https://github.com/open-compass/VLMEvalKit), providing a highly convenient testing solution!
+- 🚀 [06/2024] We uploaded MMLongBench-Doc to Hugging Face.
+
+## 💡 Highlights
+- 🔥 **Multi-modality**: All selected documents are PDF-formatted with rich layouts and multi-modal components, including text, tables, charts and images. We carefully annotate questions from this multi-modal evidence.
+- 🔥 **Long-context**: Each document has an average of 47.5 pages and 21,214 tokens. Additionally, 33.0% of the questions are cross-page questions, which necessitate collecting and reasoning over information from multiple pages.
+- 🔥 **Challenging**: Experiments on 14 LVLMs demonstrate that long-context document understanding greatly challenges current models. Even the best-performing LVLM, GPT-4o, achieves an overall F1 score of only 44.9%.
+
+## Dataset
+We save our benchmark, including both questions and documents, in `./data`.
+* The questions are provided in JSON format and contain the following attributes:
+```
+    {
+        "doc_id": "Independents-Report.pdf",
+        "doc_type": "Research report / Introduction",
+        "question": "What's the percentage of people who are democrats and voted in the last election compared to the entire population in 2018?",
+        "answer": "18.29%",
+        "evidence_pages": "[3, 5]",
+        "evidence_sources": "['Pure-text (Plain-text)']",
+        "answer_format": "Float"
+    }
+```
+* The documents are saved in `./data/documents` as PDF files.
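A minimal sketch of slicing the question set by these attributes. The records below are illustrative stand-ins for the benchmark JSON (only the first mirrors the example entry above); note that `evidence_pages` is stored as a stringified list and must be parsed before use:

```python
import json

# Illustrative records mirroring the attribute schema above; the second
# record is a made-up stand-in for an unanswerable question.
samples = json.loads("""
[
  {"doc_id": "Independents-Report.pdf", "answer": "18.29%",
   "evidence_pages": "[3, 5]", "answer_format": "Float"},
  {"doc_id": "Some-Other-Doc.pdf", "answer": "Not answerable",
   "evidence_pages": "[]", "answer_format": "None"}
]
""")

# "evidence_pages" is serialized as a string, so parse it first.
for sample in samples:
    sample["evidence_pages"] = json.loads(sample["evidence_pages"])

# Cross-page questions cite more than one evidence page; unanswerable
# questions carry the literal answer "Not answerable".
cross_page = [s for s in samples if len(s["evidence_pages"]) > 1]
unanswerable = [s for s in samples if s["answer"] == "Not answerable"]
print(len(cross_page), len(unanswerable))  # 1 1
```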
+
+You can also download this dataset with the following snippet (make sure you have installed Hugging Face Datasets):
+```
+from datasets import load_dataset
+samples = load_dataset("yubo2333/MMLongBench-Doc/data")["train"]
+```
+
+## 🛠️ Usage
+### Environment
+```
+python 3.9
+torch 2.1.2+cu121
+```
+You can install the other dependencies with `pip install -r requirements.txt`.
+
+
+### Quick Use
+```
+MODEL_NAME=[gpt-4o|gpt-4-turbo|gemini-1.5-pro-latest|internvl|4khd|minicpm_llama3] bash run.sh
+```
+Note that
+* `OPENAI_API_KEY` should be set no matter which model you are evaluating, because we adopt a three-stage evaluation protocol as detailed in Section 4.1 of [our paper](https://arxiv.org/abs/2407.01523). The conversion from a long-form response to a short-form prediction requires the involvement of GPT-4o.
+* We now support various popular open-source and closed-source LVLMs, including **GPT-4o**, **GPT-4V**, **Gemini-Pro-1.5**, **InternLM-Xcomposer2-4KHD**, **Intern-VL-Chat-v1.5** and **MiniCPM-Llama3-V2.5**. More LVLMs will be supported in the near future (we are cleaning the related code).
+
+## ✒️Citation
+```
+@misc{ma2024mmlongbenchdocbenchmarkinglongcontextdocument,
+      title={MMLongBench-Doc: Benchmarking Long-context Document Understanding with Visualizations},
+      author={Yubo Ma and Yuhang Zang and Liangyu Chen and Meiqi Chen and Yizhu Jiao and Xinze Li and Xinyuan Lu and Ziyu Liu and Yan Ma and Xiaoyi Dong and Pan Zhang and Liangming Pan and Yu-Gang Jiang and Jiaqi Wang and Yixin Cao and Aixin Sun},
+      year={2024},
+      eprint={2407.01523},
+      archivePrefix={arXiv},
+      primaryClass={cs.CV},
+      url={https://arxiv.org/abs/2407.01523},
+}
+```
+
+## 📄 License
+![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg) ![Data License](https://img.shields.io/badge/Data%20License-CC%20By%20NC%204.0-red.svg) **Usage and License Notices**: The data and code are intended and licensed for research use only.
License: Attribution-NonCommercial 4.0 International. Usage should also abide by the policy of OpenAI: https://openai.com/policies/terms-of-use
diff --git a/MMLongBench-Doc/asset/top_figure.png b/MMLongBench-Doc/asset/top_figure.png
new file mode 100644
index 000000000..c214a9405
Binary files /dev/null and b/MMLongBench-Doc/asset/top_figure.png differ
diff --git a/MMLongBench-Doc/eval/__init__.py b/MMLongBench-Doc/eval/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/MMLongBench-Doc/eval/eval_score.py b/MMLongBench-Doc/eval/eval_score.py
new file mode 100644
index 000000000..747ee2134
--- /dev/null
+++ b/MMLongBench-Doc/eval/eval_score.py
@@ -0,0 +1,260 @@
+import re
+
+from collections import defaultdict
+from math import isclose
+
+
+def levenshtein_distance(s1, s2):
+    if len(s1) > len(s2):
+        s1, s2 = s2, s1
+
+    distances = range(len(s1) + 1)
+    for i2, c2 in enumerate(s2):
+        distances_ = [i2 + 1]
+        for i1, c1 in enumerate(s1):
+            if c1 == c2:
+                distances_.append(distances[i1])
+            else:
+                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
+        distances = distances_
+    return distances[-1]
+
+
+def anls_compute(groundtruth, prediction, threshold=0.5):
+    dist = levenshtein_distance(groundtruth, prediction)
+    length = max(len(groundtruth.upper()), len(prediction.upper()))
+    value = 0.0 if length == 0 else float(dist) / float(length)
+    anls = 1.0 - value
+    if anls <= threshold:
+        anls = 0.0
+    return anls
+
+
+def is_float_equal(
+    reference, prediction, include_percentage: bool = False, is_close: bool = False
+) -> bool:
+    def get_precision(gt_ans: float) -> int:
+        precision = 3
+        if "."
in str(gt_ans):
+            precision = len(str(gt_ans).split(".")[-1])
+        return precision
+
+    reference = float(str(reference).strip().rstrip("%").strip())
+    try:
+        prediction = float(str(prediction).strip().rstrip("%").strip())
+    except ValueError:
+        return False
+
+    if include_percentage:
+        gt_result = [reference / 100, reference, reference * 100]
+    else:
+        gt_result = [reference]
+    for item in gt_result:
+        try:
+            if is_close:
+                if isclose(item, prediction, rel_tol=0.01):
+                    return True
+            precision = max(min(get_precision(prediction), get_precision(item)), 2)
+            if round(prediction, precision) == round(item, precision):
+                return True
+        except Exception:
+            continue
+    return False
+
+
+def get_clean_string(s):
+    s = str(s).lower().strip()
+    # strip unit suffixes; the result must be assigned, and str.rstrip removes
+    # a character set rather than a suffix, so use removesuffix instead
+    for suffix in ("miles", "mile", "million"):
+        if s.endswith(suffix):
+            s = s.removesuffix(suffix).strip()
+    # remove parenthesis
+    s = re.sub(r"\s*\([^)]*\)", "", s).strip()
+    # remove quotes
+    s = re.sub(r"^['\"]|['\"]$", "", s).strip()
+    s = s.strip().lstrip("$").strip()
+    s = s.strip().rstrip("%").strip()
+    return s
+
+
+def is_exact_match(s):
+    flag = False
+    # Website
+    if "https://" in s:
+        flag = True
+    # code file
+    if s.endswith(".py") or s.endswith("ipynb"):
+        flag = True
+    if s.startswith("page"):
+        flag = True
+    # telephone number
+    if re.fullmatch(r"\b\d+(-\d+|\s\d+)?\b", s):
+        flag = True
+    # time
+    if "a.m." in s or "p.m."
in s:
+        flag = True
+    # YYYY-MM-DD
+    if re.fullmatch(r"\b\d{4}[-\s]\d{2}[-\s]\d{2}\b", s):
+        flag = True
+    # YYYY-MM
+    if re.fullmatch(r"\b\d{4}[-\s]\d{2}\b", s):
+        flag = True
+    # Email address
+    if re.fullmatch(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", s):
+        flag = True
+    return flag
+
+
+def isfloat(num):
+    try:
+        float(num)
+        return True
+    except ValueError:
+        return False
+
+
+def eval_score(gt, pred, answer_type):
+    if answer_type == "Int":
+        try:
+            gt, pred = int(gt), int(float(pred))
+        except (ValueError, TypeError):
+            pred = ""
+        score = gt == pred
+    elif answer_type == "Float":
+        try:
+            gt = float(get_clean_string(str(gt)))
+            pred = float(get_clean_string(str(pred)))
+        except (ValueError, TypeError):
+            pred = ""
+        score = is_float_equal(gt, pred, include_percentage=True, is_close=True)
+    elif answer_type in ["Str", "None"]:
+        gt = get_clean_string(gt)
+        pred = get_clean_string(pred)
+        if is_exact_match(gt):
+            score = gt == pred
+        else:
+            score = anls_compute(gt, pred)
+    else:
+        # "List" answers are stored as stringified Python lists
+        if isinstance(gt, str) and gt.startswith("["):
+            gt = eval(gt)
+        if not isinstance(gt, list):
+            gt = [gt]
+        if isinstance(pred, str) and pred.startswith("["):
+            pred = eval(pred)
+        if not isinstance(pred, list):
+            pred = [pred]
+        if len(gt) != len(pred):
+            score = 0.0
+        else:
+            gt = sorted([get_clean_string(a) for a in gt])
+            pred = sorted([get_clean_string(a) for a in pred])
+            if isfloat(gt[0]) or is_exact_match(gt[0]):
+                score = "-".join(gt) == "-".join(pred)
+            else:
+                score = min(
+                    [anls_compute(gt_v, pred_v) for gt_v, pred_v in zip(gt, pred, strict=False)]
+                )
+
+    return float(score)
+
+
+def eval_acc_and_f1(samples):
+    evaluated_samples = [sample for sample in samples if "score" in sample]
+    if not evaluated_samples:
+        return 0.0, 0.0
+
+    acc = sum([sample["score"] for sample in evaluated_samples]) / len(evaluated_samples)
+    try:
+        recall = sum(
+            [
+                sample["score"]
+                for sample in evaluated_samples
+                if sample["answer"] != "Not answerable"
+            ]
+        ) / len([sample for
sample in evaluated_samples if sample["answer"] != "Not answerable"])
+        precision = sum(
+            [
+                sample["score"]
+                for sample in evaluated_samples
+                if sample["answer"] != "Not answerable"
+            ]
+        ) / len([sample for sample in evaluated_samples if sample["pred"] != "Not answerable"])
+        f1 = 2 * recall * precision / (recall + precision) if (recall + precision) > 0.0 else 0.0
+    except (ZeroDivisionError, KeyError):
+        # no (predicted) answerable samples yet
+        f1 = 0.0
+
+    return acc, f1
+
+
+def show_results(samples, show_path=None):
+    for sample in samples:
+        sample["evidence_pages"] = eval(sample["evidence_pages"])
+        sample["evidence_sources"] = eval(sample["evidence_sources"])
+
+    with open(show_path, "w") as f:
+        acc, f1 = eval_acc_and_f1(samples)
+        f.write(f"Overall Acc: {acc} | Question Number: {len(samples)}\n")
+        f.write(f"Overall F1-score: {f1} | Question Number: {len(samples)}\n")
+        f.write("-----------------------\n")
+
+        #####################
+        acc_single_page, _ = eval_acc_and_f1(
+            [sample for sample in samples if len(sample["evidence_pages"]) == 1]
+        )
+        acc_multi_page, _ = eval_acc_and_f1(
+            [
+                sample
+                for sample in samples
+                if len(sample["evidence_pages"]) != 1 and sample["answer"] != "Not answerable"
+            ]
+        )
+        acc_neg, _ = eval_acc_and_f1(
+            [sample for sample in samples if sample["answer"] == "Not answerable"]
+        )
+
+        f.write(
+            "Single-page | Accuracy: {} | Question Number: {}\n".format(
+                acc_single_page,
+                len([sample for sample in samples if len(sample["evidence_pages"]) == 1]),
+            )
+        )
+        f.write(
+            "Cross-page | Accuracy: {} | Question Number: {}\n".format(
+                acc_multi_page,
+                len(
+                    [
+                        sample
+                        for sample in samples
+                        if len(sample["evidence_pages"]) != 1
+                        and sample["answer"] != "Not answerable"
+                    ]
+                ),
+            )
+        )
+        f.write(
+            "Unanswerable | Accuracy: {} | Question Number: {}\n".format(
+                acc_neg, len([sample for sample in samples if sample["answer"] == "Not answerable"])
+            )
+        )
+        f.write("-----------------------\n")
+
+        #####################
+        source_sample_dict, document_type_dict = defaultdict(list),
defaultdict(list) + for sample in samples: + for answer_source in sample["evidence_sources"]: + source_sample_dict[answer_source].append(sample) + document_type_dict[sample["doc_type"]].append(sample) + for type, sub_samples in source_sample_dict.items(): + f.write( + f"Evidence Sources: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n" + ) + + f.write("-----------------------\n") + for type, sub_samples in document_type_dict.items(): + f.write( + f"Document Type: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n" + ) diff --git a/MMLongBench-Doc/eval/extract_answer.py b/MMLongBench-Doc/eval/extract_answer.py new file mode 100644 index 000000000..65b56ccb2 --- /dev/null +++ b/MMLongBench-Doc/eval/extract_answer.py @@ -0,0 +1,30 @@ +import os + +import openai + + +client = openai.Client( + api_key=os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), +) + + +def extract_answer(question, output, prompt, model_name="gpt-4o"): + response = client.chat.completions.create( + model=model_name, + messages=[ + { + "role": "user", + "content": prompt, + }, + {"role": "assistant", "content": f"\n\nQuestion:{question}\nAnalysis:{output}\n"}, + ], + temperature=0.0, + max_tokens=256, + top_p=1, + frequency_penalty=0, + presence_penalty=0, + ) + response = response.choices[0].message.content + + return response diff --git a/MMLongBench-Doc/eval/prompt_for_answer_extraction.md b/MMLongBench-Doc/eval/prompt_for_answer_extraction.md new file mode 100644 index 000000000..a309c0935 --- /dev/null +++ b/MMLongBench-Doc/eval/prompt_for_answer_extraction.md @@ -0,0 +1,35 @@ +Given the question and analysis, you are tasked to extract answers with required formats from the free-form analysis. +- Your extracted answers should be one of the following formats: (1) Integer, (2) Float, (3) String and (4) List. 
If you find from the analysis that the question cannot be answered from the given documents, type "Not answerable". Exception: if the analysis only tells you that it cannot read/understand the images or documents, type "Fail to answer".
+- Please make your response as concise as possible. Also note that your response should be formatted as below:
+```
+Extracted answer: [answer]
+Answer format: [answer format]
+```
+
+Please read the following examples, then extract the answer from the model response and type it at the end of the prompt.
+
+---
+Question: List the primary questions asked about the services in this report.
+Analysis: The primary questions asked about the services in the report for The Limes Residential Home are:\n\n1. Is the service safe?\n2. Is the service effective?\n3. Is the service caring?\n4. Is the service responsive?\n5. Is the service well-led?
+Extracted answer: ['Is the service safe?', 'Is the service effective?', 'Is the service caring?', 'Is the service responsive?', 'Is the service well-led?']
+Answer format: List
+
+---
+Question: How many regulations of the HSCA 2008 are breached in all according to this report?
+Analysis: According to the report, the provider breached 10 Health and Social Care Act 2008 (Regulated Activities) Regulations in total. Here are the specifics:\n\n1. Regulation 13: Safeguarding service users from abuse and improper treatment\n2. Regulation 12: Safe care and treatment\n3. Regulation 18: Staffing\n4. Regulation 11: Need for consent\n5. Regulation 10: Dignity and respect\n6. Regulation 9: Person-centred care\n7. Regulation 17: Good governance\n8. Regulation 18 (CQC Registration Regulations 2009): Notification of other incidents\n9. Regulation 18: Failure to maintain an accurate and up-to-date care plan\n10. 
Regulation 11: Failure to implement the Mental Capacity Act 2005 code of practice effectively\n\nThese breaches involve issues concerning staffing, safeguarding, medicines management, dignity and respect, consent, care planning, governance, and failure to notify the CQC of incidents.
+Extracted answer: 10
+Answer format: Integer
+
+---
+Question: According to the survey, what is the percentage of Chinese who are paying more or about the same attention to politics after Trump's election?
+Analysis: The survey provided does not specify the percentage of Chinese individuals specifically who are paying more or about the same attention to politics after Trump's election. The report focuses primarily on American demographics and does not include specific details about the Chinese population in relation to this question. If you need information about a different demographic or a summary of the findings from the American demographic, I can certainly help with that!
+Extracted answer: Not answerable
+Answer format: String
+
+---
+Question: How many quotations from male respondents over 50 years old are included in this report?
+Analysis: The image you've provided appears to be a screenshot of a document with multiple charts. However, the text is too small and blurry to read accurately. If you can provide a clearer image or more context, I might be able to help you with your question.
+Extracted answer: Fail to answer +Answer format: String + +--- diff --git a/MMLongBench-Doc/models/__init__.py b/MMLongBench-Doc/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/MMLongBench-Doc/models/internlm_xc2_4khd.py b/MMLongBench-Doc/models/internlm_xc2_4khd.py new file mode 100644 index 000000000..97042563b --- /dev/null +++ b/MMLongBench-Doc/models/internlm_xc2_4khd.py @@ -0,0 +1,125 @@ +import torch +import torch.nn.functional as F + +from transformers import AutoModel, AutoTokenizer + + +torch.set_grad_enabled(False) + + +try: + from transformers.generation.streamers import BaseStreamer +except: # noqa # pylint: disable=bare-except + BaseStreamer = None + + +def chat( + model, + tokenizer, + query: str, + image: None, + hd_num: int = 25, + history: list[tuple[str, str]] = [], + streamer: BaseStreamer | None = None, + max_new_tokens: int = 1024, + temperature: float = 1.0, + top_p: float = 0.8, + repetition_penalty: float = 1.005, + meta_instruction: str = "You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).\n" + "- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). 
It is designed to be helpful, honest, and harmless.\n" + "- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language chosen by the user such as English and 中文.\n" + "- InternLM-XComposer (浦语·灵笔) is capable of comprehending and articulating responses effectively based on the provided image.", + **kwargs, +): + if image is None: + inputs = model.build_inputs(tokenizer, query, history, meta_instruction) + im_mask = torch.zeros(inputs["input_ids"].shape[:2]).cuda().bool() + else: + if type(image) == str: + with torch.cuda.amp.autocast(): + image = model.encode_img(image, hd_num=hd_num) + inputs, im_mask = model.interleav_wrap_chat( + tokenizer, query, image, history, meta_instruction + ) + elif type(image) == list: + image_list = [] + with torch.cuda.amp.autocast(): + for image_path in image: + tmp = model.encode_img(image_path, hd_num=hd_num) + image_list.append(tmp) + if len(image_list) > 1 and image_list[-1].shape[1] != image_list[-2].shape[1]: + image_list[-1] = F.interpolate( + image_list[-1].unsqueeze(1), size=image_list[-2].shape[1:], mode="bilinear" + ).squeeze(1) + image = torch.cat(image_list, dim=0) + with torch.cuda.amp.autocast(): + inputs, im_mask = model.interleav_wrap_chat( + tokenizer, query, image, history, meta_instruction + ) + else: + raise NotImplementedError + inputs = {k: v.to(model.device) for k, v in inputs.items() if torch.is_tensor(v)} + # also add end-of-assistant token in eos token id to avoid unnecessary generation + eos_token_id = [ + tokenizer.eos_token_id, + tokenizer.convert_tokens_to_ids(["[UNUSED_TOKEN_145]"])[0], + ] + # print(inputs['inputs_embeds'].shape[1]) + with torch.cuda.amp.autocast(): + outputs = model.generate( + **inputs, + streamer=streamer, + max_new_tokens=max_new_tokens, + do_sample=False if temperature == 0.0 else True, + temperature=temperature, + top_p=top_p, + eos_token_id=eos_token_id, + repetition_penalty=repetition_penalty, + im_mask=im_mask, + **kwargs, + ) + if image is None: + 
outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :] + else: + outputs = outputs[0].cpu().tolist() + response = tokenizer.decode(outputs, skip_special_tokens=True) + response = response.split("[UNUSED_TOKEN_145]")[0] + history = history + [(query, response)] + return response, history + + +def init_model(cache_path): + model_path = ( + cache_path + if (cache_path is not None and cache_path != "None") + else "internlm/internlm-xcomposer2-4khd-7b" + ) + model = AutoModel.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + trust_remote_code=True, + device_map="auto", + ).eval() + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + model.tokenizer = tokenizer + return model + + +def get_response_concat(model, question, image_path_list, max_new_tokens=1024, temperature=1.0): + query = " " * len(image_path_list) + question + try: + response, _ = chat( + model, + model.tokenizer, + query=query, + image=image_path_list, + max_new_tokens=max_new_tokens, + hd_num=16, + temperature=temperature, + ) + except Exception as e: + print(e) + response = "Failed" + return response diff --git a/MMLongBench-Doc/models/internvl_chat.py b/MMLongBench-Doc/models/internvl_chat.py new file mode 100644 index 000000000..cee142a5e --- /dev/null +++ b/MMLongBench-Doc/models/internvl_chat.py @@ -0,0 +1,139 @@ +import torch +import torchvision.transforms as T + +from PIL import Image +from torchvision.transforms.functional import InterpolationMode +from transformers import AutoModel, AutoTokenizer + + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + + +def build_transform(input_size): + MEAN, STD = IMAGENET_MEAN, IMAGENET_STD + transform = T.Compose( + [ + T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD), + ] + ) + return transform + + 
+def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float("inf") + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + +def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num + ) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size + ) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + 
processed_images.append(thumbnail_img) + return processed_images + + +def load_image(image_file, input_size=448, max_num=6): + image = Image.open(image_file).convert("RGB") + transform = build_transform(input_size=input_size) + images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) + pixel_values = [transform(image) for image in images] + pixel_values = torch.stack(pixel_values) + return pixel_values + + +def init_model(cache_path): + import os + + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + + model_path = ( + cache_path + if (cache_path is not None and cache_path != "None") + else "OpenGVLab/InternVL-Chat-V1-5" + ) + model = AutoModel.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + trust_remote_code=True, + device_map="auto", + ).eval() + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + model.tokenizer = tokenizer + return model + + +def get_response_concat( + model, question, image_path_list, max_new_tokens=1024, temperature=1.0, max_num=6 +): + generation_config = dict( + num_beams=1, + max_new_tokens=max_new_tokens, + do_sample=False if temperature == 0.0 else True, + temperature=temperature, + ) + pixel_values_list = [ + load_image(image_path, max_num=max_num).to(torch.bfloat16).cuda() + for image_path in image_path_list + ] + pixel_values = torch.cat(pixel_values_list, dim=0) + response, _ = model.chat( + model.tokenizer, + pixel_values, + question, + generation_config, + history=None, + return_history=True, + ) + return response diff --git a/MMLongBench-Doc/models/minicpm_llama3.py b/MMLongBench-Doc/models/minicpm_llama3.py new file mode 100644 index 000000000..4dcb56182 --- /dev/null +++ b/MMLongBench-Doc/models/minicpm_llama3.py @@ -0,0 +1,56 @@ +import torch + +from PIL import Image +from transformers import AutoModel, AutoTokenizer + + +def init_model(cache_path): + model_path = ( + cache_path + if (cache_path is not None and cache_path != 
"None") + else "openbmb/MiniCPM-Llama3-V-2_5" + ) + model = AutoModel.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + trust_remote_code=True, + device_map="auto", + ).eval() + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + model.tokenizer = tokenizer + return model + + +def get_response_concat(model, question, image_path_list, max_new_tokens=1024, temperature=1.0): + msgs = [] + system_prompt = "Answer in detail." + if system_prompt: + msgs.append(dict(type="text", value=system_prompt)) + if isinstance(image_path_list, list): + msgs.extend([dict(type="image", value=p) for p in image_path_list]) + else: + msgs = [dict(type="image", value=image_path_list)] + msgs.append(dict(type="text", value=question)) + + content = [] + for x in msgs: + if x["type"] == "text": + content.append(x["value"]) + elif x["type"] == "image": + image = Image.open(x["value"]).convert("RGB") + content.append(image) + msgs = [{"role": "user", "content": content}] + + with torch.cuda.amp.autocast(): + res = model.chat( + msgs=msgs, + context=None, + image=None, + max_new_tokens=max_new_tokens, + temperature=temperature, + do_sample=False if temperature == 0.0 else True, + tokenizer=model.tokenizer, + ) + return res diff --git a/MMLongBench-Doc/run_api.py b/MMLongBench-Doc/run_api.py new file mode 100644 index 000000000..c9e49c787 --- /dev/null +++ b/MMLongBench-Doc/run_api.py @@ -0,0 +1,242 @@ +import os +import time +import uuid + +from datetime import datetime + +from dotenv import load_dotenv + +from memos.configs.mem_cube import GeneralMemCubeConfig +from memos.configs.mem_os import MOSConfig +from memos.mem_cube.general import GeneralMemCube +from memos.mem_os.main import MOS + + +load_dotenv() +import argparse +import json +import re + +from eval.eval_score import eval_acc_and_f1, eval_score, show_results +from eval.extract_answer import extract_answer +from tqdm import tqdm + + +cached_image_list = dict() + + +# 
1. Create MOS Config and set openai config +print(f"🚀 [{datetime.now().strftime('%H:%M:%S')}] Starting to create MOS configuration...") +start_time = time.time() + +user_name = str(uuid.uuid4()) +print(user_name) + +# 1.1 Set openai config +openapi_config = { + "model_name_or_path": "gpt-4o-mini", + "temperature": 0.8, + "max_tokens": 1024, + "top_p": 0.9, + "top_k": 50, + "remove_think_prefix": True, + "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), +} +# 1.2 Set neo4j config +neo4j_uri = os.getenv("NEO4J_URI", "bolt://47.117.41.207:7687") + +# 1.3 Create MOS Config +config = { + "user_id": user_name, + "chat_model": { + "backend": "openai", + "config": openapi_config, + }, + "mem_reader": { + "backend": "simple_struct", + "config": { + "llm": { + "backend": "openai", + "config": openapi_config, + }, + "embedder": { + "backend": "ollama", + "config": { + "model_name_or_path": "nomic-embed-text:latest", + }, + }, + "chunker": { + "backend": "sentence", + "config": { + "tokenizer_or_token_counter": "gpt2", + "chunk_size": 512, + "chunk_overlap": 128, + "min_sentences_per_chunk": 1, + }, + }, + }, + }, + "max_turns_window": 20, + "top_k": 5, + "enable_textual_memory": True, + "enable_activation_memory": False, + "enable_parametric_memory": False, +} + +mos_config = MOSConfig(**config) +# you can set PRO_MODE to True to enable CoT enhancement mos_config.PRO_MODE = True +mos = MOS(mos_config) + +print( + f"✅ [{datetime.now().strftime('%H:%M:%S')}] MOS configuration created successfully, time elapsed: {time.time() - start_time:.2f}s\n" +) + +# 2. 
Initialize memory cube +print(f"🚀 [{datetime.now().strftime('%H:%M:%S')}] Starting to initialize MemCube configuration...") +start_time = time.time() + +config = GeneralMemCubeConfig.model_validate( + { + "user_id": user_name, + "cube_id": f"{user_name}", + "text_mem": { + "backend": "tree_text", + "config": { + "extractor_llm": { + "backend": "openai", + "config": openapi_config, + }, + "dispatcher_llm": { + "backend": "openai", + "config": openapi_config, + }, + "graph_db": { + "backend": "neo4j", + "config": { + "uri": neo4j_uri, + "user": "neo4j", + "password": "12345678", + "db_name": f"db{user_name.replace('-', '')}", + "auto_create": True, + }, + }, + "embedder": { + "backend": "ollama", + "config": { + "model_name_or_path": "nomic-embed-text:latest", + }, + }, + }, + }, + "act_mem": {}, + "para_mem": {}, + }, +) + +print( + f"✅ [{datetime.now().strftime('%H:%M:%S')}] MemCube configuration initialization completed, time elapsed: {time.time() - start_time:.2f}s\n" +) + +# 3. Initialize the MemCube with the configuration +print(f"🚀 [{datetime.now().strftime('%H:%M:%S')}] Starting to create MemCube instance...") +start_time = time.time() + +mem_cube = GeneralMemCube(config) +try: + mem_cube.dump(f"/tmp/{user_name}/") + print( + f"✅ [{datetime.now().strftime('%H:%M:%S')}] MemCube created and saved successfully, time elapsed: {time.time() - start_time:.2f}s\n" + ) +except Exception as e: + print( + f"❌ [{datetime.now().strftime('%H:%M:%S')}] MemCube save failed: {e}, time elapsed: {time.time() - start_time:.2f}s\n" + ) + +# 4. 
Register the MemCube
+print(f"🚀 [{datetime.now().strftime('%H:%M:%S')}] Starting to register MemCube...")
+start_time = time.time()
+
+mos.register_mem_cube(f"/tmp/{user_name}", mem_cube_id=user_name)
+
+print(
+    f"✅ [{datetime.now().strftime('%H:%M:%S')}] MemCube registration completed, time elapsed: {time.time() - start_time:.2f}s\n"
+)
+
+mos.add(doc_path="examples/data")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_path", type=str, default="MMLongBench-Doc/data/samples.json")
+    parser.add_argument("--document_path", type=str, default="MMLongBench-Doc/data/documents")
+    parser.add_argument("--model_name", type=str, default="gpt-4o")
+    parser.add_argument("--max_pages", type=int, default=120)
+    parser.add_argument("--resolution", type=int, default=144)
+    parser.add_argument("--max_try", type=int, default=10)
+    parser.add_argument("--max_tokens", type=int, default=1024)
+    parser.add_argument("--temperature", type=float, default=0.0)
+    parser.add_argument(
+        "--extractor_prompt_path",
+        type=str,
+        default="MMLongBench-Doc/eval/prompt_for_answer_extraction.md",
+    )
+    args = parser.parse_args()
+
+    args.output_path = "MMLongBench-Doc/res.json"
+
+    with open(args.extractor_prompt_path) as f:
+        prompt = f.read()
+    if os.path.exists(args.output_path):
+        with open(args.output_path) as f:
+            samples = json.load(f)
+    else:
+        with open(args.input_path) as f:
+            samples = json.load(f)
+
+    for sample in tqdm(samples):
+        if sample["evidence_sources"] != "['Pure-text (Plain-text)']":
+            continue
+
+        messages = sample["question"]
+
+        try_cnt = 0
+        is_success = False
+        while True:
+            try:
+                mos.clear_messages()
+                response = mos.chat(messages)
+                is_success = True
+            except Exception:
+                try_cnt += 1
+                response = "Failed"
+            if is_success or try_cnt > args.max_try:
+                break
+
+        sample["response"] = response
+        extracted_res = extract_answer(sample["question"], response, prompt)
+        sample["extracted_res"] = extracted_res
+        print("llm res:", extracted_res)
+        try:
+            pred_ans = extracted_res.split("Answer format:")[0].split("Extracted answer:")[1].strip()
+            score = eval_score(sample["answer"], pred_ans, sample["answer_format"])
+        except (IndexError, ValueError):
+            pred_ans = "Failed to extract"
+            score = 0.0
+        sample["pred"] = pred_ans
+        sample["score"] = score
+
+        acc, f1 = eval_acc_and_f1(samples)
+        print("--------------------------------------")
+        print("Question: {}".format(sample["question"]))
+        print("Response: {}".format(sample["response"]))
+        print(
+            "Gt: {}\tPred: {}\tScore: {}".format(sample["answer"], sample["pred"], sample["score"])
+        )
+        print(f"Avg acc: {acc}")
+        print(f"Avg f1: {f1}")
+
+        with open(args.output_path, "w") as f:
+            json.dump(samples, f)
+
+    show_results(samples, show_path=re.sub(r"\.json$", ".txt", args.output_path))