From dd34693fea55ba0279f39256be3a7e4238e8d574 Mon Sep 17 00:00:00 2001
From: stx <31013941@qq.com>
Date: Tue, 25 Nov 2025 16:10:54 +0800
Subject: [PATCH 1/4] feat: add MMLongBench-Doc, HotpotQA, xinyu evaluation
---
evaluation/scripts/hotpot/hotpot_eval.py | 224 ++++++++++++++
.../scripts/hotpot/hotpot_evaluate_v1.py | 151 ++++++++++
.../scripts/mmlongbench/eval/__init__.py | 0
.../scripts/mmlongbench/eval/eval_score.py | 246 +++++++++++++++
.../mmlongbench/eval/extract_answer.py | 33 +++
.../eval/prompt_for_answer_extraction.md | 35 +++
evaluation/scripts/mmlongbench/eval_docs.py | 265 +++++++++++++++++
evaluation/scripts/mmlongbench/import_docs.py | 88 ++++++
.../scripts/mmlongbench/models/__init__.py | 0
.../mmlongbench/models/minicpm_llama3.py | 56 ++++
.../scripts/mmlongbench/multimodal_test.py | 185 ++++++++++++
evaluation/scripts/xinyu/eval/__init__.py | 0
.../scripts/xinyu/eval/eval_score_llm.py | 279 ++++++++++++++++++
evaluation/scripts/xinyu/eval_docs.py | 228 ++++++++++++++
evaluation/scripts/xinyu/import_docs.py | 85 ++++++
15 files changed, 1875 insertions(+)
create mode 100644 evaluation/scripts/hotpot/hotpot_eval.py
create mode 100644 evaluation/scripts/hotpot/hotpot_evaluate_v1.py
create mode 100644 evaluation/scripts/mmlongbench/eval/__init__.py
create mode 100644 evaluation/scripts/mmlongbench/eval/eval_score.py
create mode 100644 evaluation/scripts/mmlongbench/eval/extract_answer.py
create mode 100644 evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md
create mode 100644 evaluation/scripts/mmlongbench/eval_docs.py
create mode 100644 evaluation/scripts/mmlongbench/import_docs.py
create mode 100644 evaluation/scripts/mmlongbench/models/__init__.py
create mode 100644 evaluation/scripts/mmlongbench/models/minicpm_llama3.py
create mode 100644 evaluation/scripts/mmlongbench/multimodal_test.py
create mode 100644 evaluation/scripts/xinyu/eval/__init__.py
create mode 100644 evaluation/scripts/xinyu/eval/eval_score_llm.py
create mode 100644 evaluation/scripts/xinyu/eval_docs.py
create mode 100644 evaluation/scripts/xinyu/import_docs.py
diff --git a/evaluation/scripts/hotpot/hotpot_eval.py b/evaluation/scripts/hotpot/hotpot_eval.py
new file mode 100644
index 000000000..05ff52349
--- /dev/null
+++ b/evaluation/scripts/hotpot/hotpot_eval.py
@@ -0,0 +1,224 @@
+import json
+import os
+import uuid
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+from dotenv import load_dotenv
+from tqdm import tqdm
+
+from memos.configs.mem_cube import GeneralMemCubeConfig
+from memos.configs.mem_os import MOSConfig
+from memos.mem_cube.general import GeneralMemCube
+from memos.mem_os.main import MOS
+
+
+load_dotenv()
+
+db_name = "stx-hotpot-001"
+
+
+user_name = str(uuid.uuid4())
+
+# 1.1 Set openai config
+openapi_config = {
+ "model_name_or_path": "gpt-4o-mini",
+ "temperature": 0.8,
+ "max_tokens": 1024,
+ "top_p": 0.9,
+ "top_k": 50,
+ "remove_think_prefix": True,
+ "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
+ "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
+}
+# 1.2 Set neo4j config
+neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
+
+# 1.3 Create MOS Config
+config = {
+ "user_id": user_name,
+ "chat_model": {
+ "backend": "openai",
+ "config": openapi_config,
+ },
+ "mem_reader": {
+ "backend": "simple_struct",
+ "config": {
+ "llm": {
+ "backend": "openai",
+ "config": openapi_config,
+ },
+ "embedder": {
+ "backend": "universal_api",
+ "config": {
+ "provider": "openai",
+ "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
+ "model_name_or_path": "text-embedding-3-large",
+ "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
+ },
+ },
+ "chunker": {
+ "backend": "sentence",
+ "config": {
+ "tokenizer_or_token_counter": "gpt2",
+ "chunk_size": 512,
+ "chunk_overlap": 128,
+ "min_sentences_per_chunk": 1,
+ },
+ },
+ },
+ },
+ "max_turns_window": 20,
+ "top_k": 5,
+ "enable_textual_memory": True,
+ "enable_activation_memory": False,
+ "enable_parametric_memory": False,
+}
+
+mos_config = MOSConfig(**config)
+# You can enable chain-of-thought (CoT) enhancement by setting
+# mos_config.PRO_MODE = True before constructing MOS.
+mos = MOS(mos_config)
+
+
+# Filter out embedding fields, keeping only necessary fields
+def filter_memory_data(memories_data):
+ filtered_data = {}
+ for key, value in memories_data.items():
+ if key == "text_mem":
+ filtered_data[key] = []
+ for mem_group in value:
+ # Check if it's the new data structure (list of TextualMemoryItem objects)
+ if "memories" in mem_group and isinstance(mem_group["memories"], list):
+ # New data structure: directly a list of TextualMemoryItem objects
+ filtered_memories = []
+ for memory_item in mem_group["memories"]:
+ # Create filtered dictionary
+ filtered_item = {
+ "id": memory_item.id,
+ "memory": memory_item.memory,
+ "metadata": {},
+ }
+ # Filter metadata, excluding embedding
+ if hasattr(memory_item, "metadata") and memory_item.metadata:
+ for attr_name in dir(memory_item.metadata):
+ if not attr_name.startswith("_") and attr_name != "embedding":
+ attr_value = getattr(memory_item.metadata, attr_name)
+ if not callable(attr_value):
+ filtered_item["metadata"][attr_name] = attr_value
+ filtered_memories.append(filtered_item)
+
+ filtered_group = {
+ "cube_id": mem_group.get("cube_id", ""),
+ "memories": filtered_memories,
+ }
+ filtered_data[key].append(filtered_group)
+ else:
+ # Old data structure: dictionary with nodes and edges
+ filtered_group = {
+ "memories": {"nodes": [], "edges": mem_group["memories"].get("edges", [])}
+ }
+ for node in mem_group["memories"].get("nodes", []):
+ filtered_node = {
+ "id": node.get("id"),
+ "memory": node.get("memory"),
+ "metadata": {
+ k: v
+ for k, v in node.get("metadata", {}).items()
+ if k != "embedding"
+ },
+ }
+ filtered_group["memories"]["nodes"].append(filtered_node)
+ filtered_data[key].append(filtered_group)
+ else:
+ filtered_data[key] = value
+ return filtered_data
+
+
+config = GeneralMemCubeConfig.model_validate(
+ {
+ "user_id": user_name,
+ "cube_id": f"{user_name}",
+ "text_mem": {
+ "backend": "tree_text",
+ "config": {
+ "extractor_llm": {
+ "backend": "openai",
+ "config": openapi_config,
+ },
+ "dispatcher_llm": {
+ "backend": "openai",
+ "config": openapi_config,
+ },
+ "graph_db": {
+ "backend": "neo4j",
+ "config": {
+ "uri": neo4j_uri,
+ "user": "neo4j",
+ "password": "iaarlichunyu",
+ "db_name": db_name,
+ "auto_create": True,
+ },
+ },
+ "embedder": {
+ "backend": "universal_api",
+ "config": {
+ "provider": "openai",
+ "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
+ "model_name_or_path": "text-embedding-3-large",
+ "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
+ },
+ },
+ "reorganize": True,
+ },
+ },
+ "act_mem": {},
+ "para_mem": {},
+ },
+)
+
+mem_cube = GeneralMemCube(config)
+
+# Dump the cube to a temp dir before registering it, mirroring the other
+# evaluation scripts in this patch (skipped if a dump is already present).
+temp_dir = f"/tmp/{user_name}"
+if not os.path.exists(temp_dir) or not os.listdir(temp_dir):
+    mem_cube.dump(temp_dir)
+mos.register_mem_cube(temp_dir, mem_cube_id=user_name)
+
+
+with open("evaluation/data/hotpot/hotpot_dev_distractor_v1.json") as f:
+ data = json.load(f)
+
+
+def build_context_text(context_list):
+ parts = []
+ for title, sentences in context_list:
+ text = " ".join(s.strip() for s in sentences if s.strip())
+ parts.append(f"{title}: {text}")
+ return "\n".join(parts)
+
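+# Rough shape of a HotpotQA distractor item as consumed below (an assumption
+# based on the public dev set; unused keys omitted):
+#   {"_id": "...", "question": "...",
+#    "context": [["Title A", ["Sentence 1.", "Sentence 2."]], ["Title B", [...]]]}
+# build_context_text(item["context"]) would then return
+# "Title A: Sentence 1. Sentence 2.\nTitle B: ...".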
+
+def build_and_ask(item):
+ qid = item["_id"]
+ question = item["question"]
+
+ for title, sentences in item["context"]:
+ text = " ".join(s.strip() for s in sentences if s.strip())
+ memory_content = f"{title}: {text}"
+ mos.add(memory_content=memory_content)
+
+ answer = mos.chat(question).strip()
+ return qid, answer
+
+
+pred_answers = {}
+
+with ThreadPoolExecutor(max_workers=5) as executor:
+ futures = {executor.submit(build_and_ask, item): item for item in data}
+ for future in tqdm(as_completed(futures), total=len(futures)):
+ try:
+ qid, answer = future.result()
+ pred_answers[qid] = answer
+ except Exception as e:
+ print(f"Error: {e}")
+
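+# Supporting facts are not predicted here, so "sp" stays empty and
+# hotpot_evaluate_v1.py will only report meaningful answer metrics
+# (em / f1 / prec / recall); the sp_* and joint_* numbers will stay at zero.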
+predictions = {"answer": pred_answers, "sp": []}
+
+with open("evaluation/data/hotpot/output/dev_distractor_pred.json", "w") as f:
+ json.dump(predictions, f, ensure_ascii=False, indent=2)
diff --git a/evaluation/scripts/hotpot/hotpot_evaluate_v1.py b/evaluation/scripts/hotpot/hotpot_evaluate_v1.py
new file mode 100644
index 000000000..d4d6e71e1
--- /dev/null
+++ b/evaluation/scripts/hotpot/hotpot_evaluate_v1.py
@@ -0,0 +1,151 @@
+import re
+import string
+import sys
+
+from collections import Counter
+
+import ujson as json
+
+
+def normalize_answer(s):
+ def remove_articles(text):
+ return re.sub(r"\b(a|an|the)\b", " ", text)
+
+ def white_space_fix(text):
+ return " ".join(text.split())
+
+ def remove_punc(text):
+ exclude = set(string.punctuation)
+ return "".join(ch for ch in text if ch not in exclude)
+
+ def lower(text):
+ return text.lower()
+
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def f1_score(prediction, ground_truth):
+ normalized_prediction = normalize_answer(prediction)
+ normalized_ground_truth = normalize_answer(ground_truth)
+
+ zero_metric = (0, 0, 0)
+
+ if (
+ normalized_prediction in ["yes", "no", "noanswer"]
+ and normalized_prediction != normalized_ground_truth
+ ):
+ return zero_metric
+ if (
+ normalized_ground_truth in ["yes", "no", "noanswer"]
+ and normalized_prediction != normalized_ground_truth
+ ):
+ return zero_metric
+
+ prediction_tokens = normalized_prediction.split()
+ ground_truth_tokens = normalized_ground_truth.split()
+ common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+ num_same = sum(common.values())
+ if num_same == 0:
+ return zero_metric
+ precision = 1.0 * num_same / len(prediction_tokens)
+ recall = 1.0 * num_same / len(ground_truth_tokens)
+ f1 = (2 * precision * recall) / (precision + recall)
+ return f1, precision, recall
+
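+# Worked example (illustrative): prediction "a yellow bird" vs. ground truth
+# "the yellow canary bird" normalize to "yellow bird" and "yellow canary bird";
+# the shared tokens are {"yellow", "bird"}, so precision = 2/2 = 1.0,
+# recall = 2/3, and F1 = 2 * 1.0 * (2/3) / (1 + 2/3) = 0.8.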
+
+def exact_match_score(prediction, ground_truth):
+ return normalize_answer(prediction) == normalize_answer(ground_truth)
+
+
+def update_answer(metrics, prediction, gold):
+ em = exact_match_score(prediction, gold)
+ f1, prec, recall = f1_score(prediction, gold)
+ metrics["em"] += float(em)
+ metrics["f1"] += f1
+ metrics["prec"] += prec
+ metrics["recall"] += recall
+ return em, prec, recall
+
+
+def update_sp(metrics, prediction, gold):
+ cur_sp_pred = set(map(tuple, prediction))
+ gold_sp_pred = set(map(tuple, gold))
+ tp, fp, fn = 0, 0, 0
+ for e in cur_sp_pred:
+ if e in gold_sp_pred:
+ tp += 1
+ else:
+ fp += 1
+ for e in gold_sp_pred:
+ if e not in cur_sp_pred:
+ fn += 1
+ prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0
+ recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0
+ f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
+ em = 1.0 if fp + fn == 0 else 0.0
+ metrics["sp_em"] += em
+ metrics["sp_f1"] += f1
+ metrics["sp_prec"] += prec
+ metrics["sp_recall"] += recall
+ return em, prec, recall
+
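+# Worked example (illustrative): predicted supporting facts
+# [["Title A", 0], ["Title B", 1]] against gold [["Title A", 0], ["Title C", 2]]
+# give tp=1, fp=1, fn=1, hence precision = recall = f1 = 0.5 and EM = 0.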
+
+def eval(prediction_file, gold_file):
+ with open(prediction_file) as f:
+ prediction = json.load(f)
+ with open(gold_file) as f:
+ gold = json.load(f)
+
+ metrics = {
+ "em": 0,
+ "f1": 0,
+ "prec": 0,
+ "recall": 0,
+ "sp_em": 0,
+ "sp_f1": 0,
+ "sp_prec": 0,
+ "sp_recall": 0,
+ "joint_em": 0,
+ "joint_f1": 0,
+ "joint_prec": 0,
+ "joint_recall": 0,
+ }
+ for dp in gold:
+ cur_id = dp["_id"]
+ can_eval_joint = True
+ if cur_id not in prediction["answer"]:
+ print(f"missing answer {cur_id}")
+ can_eval_joint = False
+ else:
+ em, prec, recall = update_answer(metrics, prediction["answer"][cur_id], dp["answer"])
+ if cur_id not in prediction["sp"]:
+ print(f"missing sp fact {cur_id}")
+ can_eval_joint = False
+ else:
+ sp_em, sp_prec, sp_recall = update_sp(
+ metrics, prediction["sp"][cur_id], dp["supporting_facts"]
+ )
+
+ if can_eval_joint:
+ joint_prec = prec * sp_prec
+ joint_recall = recall * sp_recall
+ if joint_prec + joint_recall > 0:
+ joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall)
+ else:
+ joint_f1 = 0.0
+ joint_em = em * sp_em
+
+ metrics["joint_em"] += joint_em
+ metrics["joint_f1"] += joint_f1
+ metrics["joint_prec"] += joint_prec
+ metrics["joint_recall"] += joint_recall
+
+ n = len(gold)
+ for k in metrics:
+ metrics[k] /= n
+
+ print(metrics)
+
+
+if __name__ == "__main__":
+ eval(sys.argv[1], sys.argv[2])
diff --git a/evaluation/scripts/mmlongbench/eval/__init__.py b/evaluation/scripts/mmlongbench/eval/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/evaluation/scripts/mmlongbench/eval/eval_score.py b/evaluation/scripts/mmlongbench/eval/eval_score.py
new file mode 100644
index 000000000..02ef6eb53
--- /dev/null
+++ b/evaluation/scripts/mmlongbench/eval/eval_score.py
@@ -0,0 +1,246 @@
+import re
+
+from collections import defaultdict
+from math import isclose
+
+
+def levenshtein_distance(s1, s2):
+ if len(s1) > len(s2):
+ s1, s2 = s2, s1
+
+ distances = range(len(s1) + 1)
+ for i2, c2 in enumerate(s2):
+ distances_ = [i2 + 1]
+ for i1, c1 in enumerate(s1):
+ if c1 == c2:
+ distances_.append(distances[i1])
+ else:
+ distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
+ distances = distances_
+ return distances[-1]
+
+
+def anls_compute(groundtruth, prediction, threshold=0.5):
+ dist = levenshtein_distance(groundtruth, prediction)
+ length = max(len(groundtruth.upper()), len(prediction.upper()))
+ value = 0.0 if length == 0 else float(dist) / float(length)
+ anls = 1.0 - value
+ if anls <= threshold:
+ anls = 0.0
+ return anls
+
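+# Worked example (illustrative): for ground truth "hawaii" and prediction
+# "hawai" the edit distance is 1 over length 6, so ANLS = 1 - 1/6 ≈ 0.833;
+# scores at or below the 0.5 threshold are clamped to 0.0.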
+
+def is_float_equal(
+    reference, prediction, include_percentage: bool = False, is_close: bool = False
+) -> bool:
+ def get_precision(gt_ans: float) -> int:
+ precision = 3
+ if "." in str(gt_ans):
+ precision = len(str(gt_ans).split(".")[-1])
+ return precision
+
+ reference = float(str(reference).strip().rstrip("%").strip())
+ try:
+ prediction = float(str(prediction).strip().rstrip("%").strip())
+ except Exception:
+ return False
+
+ gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference]
+ for item in gt_result:
+ try:
+ if is_close and isclose(item, prediction, rel_tol=0.01):
+ return True
+ precision = max(min(get_precision(prediction), get_precision(item)), 2)
+ if round(prediction, precision) == round(item, precision):
+ return True
+ except Exception:
+ continue
+ return False
+
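+# Illustrative behaviour: is_float_equal("12%", "0.12", include_percentage=True,
+# is_close=True) returns True, because the reference 12 is also tried as 0.12
+# and 1200, and isclose(0.12, 0.12, rel_tol=0.01) matches.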
+
+def get_clean_string(s):
+ s = str(s).lower().strip()
+
+ for suffix in ["mile", "miles", "million"]:
+ if s.endswith(suffix):
+ s = s[: -len(suffix)].strip()
+
+ s = re.sub(r"\s*\([^)]*\)", "", s).strip()
+ s = re.sub(r"^['\"]|['\"]$", "", s).strip()
+ s = s.lstrip("$").rstrip("%").strip()
+
+ return s
+
+
+def is_exact_match(s):
+ flag = False
+ # Website
+ if "https://" in s:
+ flag = True
+ # code file
+ if s.endswith((".py", ".ipynb")) or s.startswith("page"):
+ flag = True
+ # telephone number
+ if re.fullmatch(r"\b\d+(-\d+|\s\d+)?\b", s):
+ flag = True
+ # time
+ if "a.m." in s or "p.m." in s:
+ flag = True
+ # YYYY-MM-DD
+ if re.fullmatch(r"\b\d{4}[-\s]\d{2}[-\s]\d{2}\b", s):
+ flag = True
+ # YYYY-MM
+ if re.fullmatch(r"\b\d{4}[-\s]\d{2}\b", s):
+ flag = True
+ # Email address
+ if re.fullmatch(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", s):
+ flag = True
+ return flag
+
+
+def isfloat(num):
+ try:
+ float(num)
+ return True
+ except ValueError:
+ return False
+
+
+def eval_score(gt, pred, answer_type):
+ if answer_type == "Int":
+ try:
+ gt, pred = int(gt), int(float(pred))
+ except Exception:
+ pred = ""
+ score = gt == pred
+ elif answer_type == "Float":
+ try:
+ gt = float(get_clean_string(str(gt)))
+ pred = float(get_clean_string(str(pred)))
+ except Exception:
+ pred = ""
+ score = is_float_equal(gt, pred, include_percentage=True, is_close=True)
+ elif answer_type in ["Str", "None"]:
+ gt = get_clean_string(gt)
+ pred = get_clean_string(pred)
+ score = gt == pred if is_exact_match(gt) else anls_compute(gt, pred)
+ else:
+ if isinstance(gt, str) and gt.startswith("["):
+ gt = eval(gt)
+ if not isinstance(gt, list):
+ gt = [gt]
+ if isinstance(pred, str) and pred.startswith("["):
+ pred = eval(pred)
+ if not isinstance(pred, list):
+ pred = [pred]
+ print(len(gt), len(pred))
+ if len(gt) != len(pred):
+ score = 0.0
+ else:
+ gt = sorted([get_clean_string(a) for a in gt])
+ pred = sorted([get_clean_string(a) for a in pred])
+ print(gt, pred)
+ if isfloat(gt[0]) or is_exact_match(gt[0]):
+ score = "-".join(gt) == "-".join(pred)
+ else:
+ score = min(
+ [anls_compute(gt_v, pred_v) for gt_v, pred_v in zip(gt, pred, strict=False)]
+ )
+
+ return float(score)
+
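+# Informal summary of eval_score above:
+#   Int         exact integer match
+#   Float       approximate match via is_float_equal (handles % scaling)
+#   Str / None  exact match for "hard" strings (URLs, dates, emails, ...),
+#               otherwise ANLS similarity
+#   List        lengths must match; sorted, cleaned elements compared pairwise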
+
+def eval_acc_and_f1(samples):
+ evaluated_samples = [sample for sample in samples if "score" in sample]
+ if not evaluated_samples:
+ return 0.0, 0.0
+
+ acc = sum([sample["score"] for sample in evaluated_samples]) / len(evaluated_samples)
+ try:
+ recall = sum(
+ [
+ sample["score"]
+ for sample in evaluated_samples
+ if sample["answer"] != "Not answerable"
+ ]
+ ) / len([sample for sample in evaluated_samples if sample["answer"] != "Not answerable"])
+ precision = sum(
+ [
+ sample["score"]
+ for sample in evaluated_samples
+ if sample["answer"] != "Not answerable"
+ ]
+ ) / len([sample for sample in evaluated_samples if sample["pred"] != "Not answerable"])
+ f1 = 2 * recall * precision / (recall + precision) if (recall + precision) > 0.0 else 0.0
+ except Exception:
+ f1 = 0.0
+
+ return acc, f1
+
+
+def show_results(samples, show_path=None):
+ for sample in samples:
+ sample["evidence_pages"] = eval(sample["evidence_pages"])
+ sample["evidence_sources"] = eval(sample["evidence_sources"])
+
+ with open(show_path, "w") as f:
+ acc, f1 = eval_acc_and_f1(samples)
+ f.write(f"Overall Acc: {acc} | Question Number: {len(samples)}\n")
+ f.write(f"Overall F1-score: {f1} | Question Number: {len(samples)}\n")
+ f.write("-----------------------\n")
+
+ acc_single_page, _ = eval_acc_and_f1(
+ [sample for sample in samples if len(sample["evidence_pages"]) == 1]
+ )
+ acc_multi_page, _ = eval_acc_and_f1(
+ [
+ sample
+ for sample in samples
+ if len(sample["evidence_pages"]) != 1 and sample["answer"] != "Not answerable"
+ ]
+ )
+ acc_neg, _ = eval_acc_and_f1(
+ [sample for sample in samples if sample["answer"] == "Not answerable"]
+ )
+
+ f.write(
+ "Single-page | Accuracy: {} | Question Number: {}\n".format(
+ acc_single_page,
+ len([sample for sample in samples if len(sample["evidence_pages"]) == 1]),
+ )
+ )
+ f.write(
+ "Cross-page | Accuracy: {} | Question Number: {}\n".format(
+ acc_multi_page,
+ len(
+ [
+ sample
+ for sample in samples
+ if len(sample["evidence_pages"]) != 1
+ and sample["answer"] != "Not answerable"
+ ]
+ ),
+ )
+ )
+ f.write(
+ "Unanswerable | Accuracy: {} | Question Number: {}\n".format(
+ acc_neg, len([sample for sample in samples if sample["answer"] == "Not answerable"])
+ )
+ )
+ f.write("-----------------------\n")
+
+ source_sample_dict, document_type_dict = defaultdict(list), defaultdict(list)
+ for sample in samples:
+ for answer_source in sample["evidence_sources"]:
+ source_sample_dict[answer_source].append(sample)
+ document_type_dict[sample["doc_type"]].append(sample)
+ for type, sub_samples in source_sample_dict.items():
+ f.write(
+ f"Evidence Sources: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n"
+ )
+
+ f.write("-----------------------\n")
+ for type, sub_samples in document_type_dict.items():
+ f.write(
+ f"Document Type: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n"
+ )
diff --git a/evaluation/scripts/mmlongbench/eval/extract_answer.py b/evaluation/scripts/mmlongbench/eval/extract_answer.py
new file mode 100644
index 000000000..b7f7e6863
--- /dev/null
+++ b/evaluation/scripts/mmlongbench/eval/extract_answer.py
@@ -0,0 +1,33 @@
+import os
+
+import openai
+
+from dotenv import load_dotenv
+
+
+load_dotenv()
+client = openai.Client(
+ api_key=os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
+ base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
+)
+
+
+def extract_answer(question, output, prompt, model_name="gpt-4o"):
+ response = client.chat.completions.create(
+ model=model_name,
+ messages=[
+ {
+ "role": "user",
+ "content": prompt,
+ },
+ {"role": "assistant", "content": f"\n\nQuestion:{question}\nAnalysis:{output}\n"},
+ ],
+ temperature=0.0,
+ max_tokens=256,
+ top_p=1,
+ frequency_penalty=0,
+ presence_penalty=0,
+ )
+ response = response.choices[0].message.content
+
+ return response
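+
+
+# With the prompt in prompt_for_answer_extraction.md, the model reply is
+# expected to look like:
+#   Extracted answer: 10
+#   Answer format: Integer
+# and callers (e.g. eval_docs.py) split on those two markers to recover it.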
diff --git a/evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md b/evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md
new file mode 100644
index 000000000..a309c0935
--- /dev/null
+++ b/evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md
@@ -0,0 +1,35 @@
+Given the question and analysis, you are tasked to extract answers with required formats from the free-form analysis.
+- Your extracted answers should be one of the following formats: (1) Integer, (2) Float, (3) String and (4) List. If you find the analysis the question can not be answered from the given documents, type "Not answerable". Exception: If the analysis only tells you that it can not read/understand the images or documents, type "Fail to answer".
+- Please make your response as concise as possible. Also note that your response should be formatted as below:
+```
+Extracted answer: [answer]
+Answer format: [answer format]
+```
+
+Please read the following example, then extract the answer from the model response and type it at the end of the prompt.
+
+---
+Question: List the primary questions asked about the services in this report.
+Analysis: The primary questions asked about the services in the report for The Limes Residential Home are:\n\n1. Is the service safe?\n2. Is the service effective?\n3. Is the service caring?\n4. Is the service responsive?\n5. Is the service well-led?
+Extracted answer: ['Is the servife safe?', 'Is the service effective', 'Is the serve caring?', 'Is the service responsive?', 'Is the service well-led?']
+Answer format: List
+
+---
+Question: How many regulations of the HSCA 2008 are breached in all according to this report?
+Analysis: According to the report, the provider breached 10 Health and Social Care Act 2008 (Regulated Activities) Regulations in total. Here are the specifics:\n\n1. Regulation 13: Safeguarding service users from abuse and improper treatment\n2. Regulation 12: Safe care and treatment\n3. Regulation 18: Staffing\n4. Regulation 11: Need for consent\n5. Regulation 10: Dignity and respect\n6. Regulation 9: Person-centred care\n7. Regulation 17: Good governance\n8. Regulation 18 (CQC Registration Regulations 2009): Notification of other incidents\n9. Regulation 18: Failure to maintain an accurate and up-to-date care plan\n10. Regulation 11: Failure to implement the Mental Capacity Act 2005 code of practice effectively\n\nThese breaches involve issues concerning staffing, safeguarding, medicines management, dignity and respect, consent, care planning, governance, and failure to notify the CQC of incidents.
+Extracted answer: 10
+Answer format: Integer
+
+---
+Question: According to the survey that is the percentage of Chinese who are paying more or about the same attention to politics after Trump's election?
+Analysis: The survey provided does not specify the percentage of Chinese individuals specifically who are paying more or about the same attention to politics after Trump's election. The report focuses primarily on American demographics and does not include specific details about the Chinese population in relation to this question. If you need information about a different demographic or a summary of the findings from the American demographic, I can certainly help with that!
+Extracted answer: Not answerable
+Answer format: String
+
+---
+Question: How many quotations from male respondent over 50 years old are included in this report?
+Analysis: The image you've provided appears to be a screenshot of a document with multiple charts. However, the text is too small and blurry to read accurately. If you can provide a clearer image or more context, I might be able to help you with your question.
+Extracted answer: Fail to answer
+Answer format: String
+
+---
diff --git a/evaluation/scripts/mmlongbench/eval_docs.py b/evaluation/scripts/mmlongbench/eval_docs.py
new file mode 100644
index 000000000..510a0b1ed
--- /dev/null
+++ b/evaluation/scripts/mmlongbench/eval_docs.py
@@ -0,0 +1,265 @@
+import csv
+import json
+import os
+import re
+import traceback
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+from dotenv import load_dotenv
+from eval.eval_score import eval_acc_and_f1, eval_score, show_results
+from eval.extract_answer import extract_answer
+from tqdm import tqdm
+
+from memos.configs.mem_cube import GeneralMemCubeConfig
+from memos.configs.mem_os import MOSConfig
+from memos.mem_cube.general import GeneralMemCube
+from memos.mem_os.main import MOS
+
+
+load_dotenv()
+openapi_config = {
+ "model_name_or_path": "gpt-4o",
+ "top_k": 50,
+ "remove_think_prefix": True,
+ "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
+ "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
+}
+neo4j_uri = os.getenv("NEO4J_URI", "bolt://47.117.41.207:7687")
+db_name = "stx-mmlongbench-004"
+
+doc_paths = [
+ f
+ for f in os.listdir("evaluation/data/mmlongbench/documents")
+ if os.path.isfile(os.path.join("evaluation/data/mmlongbench/documents", f))
+]
+
+with open("evaluation/data/mmlongbench/samples.json") as f:
+ samples = json.load(f)
+
+RESULTS_PATH = "evaluation/data/mmlongbench/test_results.json"
+completed_pairs: set[tuple[str, str]] = set()
+
+
+def _load_existing_results():
+ global completed_pairs
+ if os.path.exists(RESULTS_PATH):
+ try:
+ with open(RESULTS_PATH, encoding="utf-8") as f:
+ existing = json.load(f)
+ for r in existing:
+ did = r.get("doc_id")
+ q = r.get("question")
+ if did and q:
+ completed_pairs.add((did, q))
+ return existing
+ except Exception:
+ return []
+ return []
+
+
+def _doc_has_pending(doc_file: str) -> bool:
+ for s in samples:
+ if s.get("doc_id") == doc_file and (doc_file, s.get("question")) not in completed_pairs:
+ return True
+ return False
+
+
+def get_user_name(doc_file):
+ csv_path = "evaluation/data/mmlongbench/user_doc_map.csv"
+ if os.path.exists(csv_path):
+ with open(csv_path, newline="", encoding="utf-8") as f:
+ reader = csv.reader(f)
+ for row in reader:
+ uid, path = row[0], row[1]
+ base = os.path.basename(path)
+ if base == doc_file or os.path.splitext(base)[0] == os.path.splitext(doc_file)[0]:
+ return uid
+ return ""
+
+
+def process_doc(doc_file):
+ user_name = get_user_name(doc_file)
+ print(user_name, doc_file)
+ config = {
+ "user_id": user_name,
+ "chat_model": {
+ "backend": "openai",
+ "config": openapi_config,
+ },
+ "mem_reader": {
+ "backend": "simple_struct",
+ "config": {
+ "llm": {"backend": "openai", "config": openapi_config},
+ "embedder": {
+ "backend": "universal_api",
+ "config": {
+ "provider": "openai",
+ "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
+ "model_name_or_path": "text-embedding-3-large",
+ "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
+ },
+ },
+ "chunker": {
+ "backend": "sentence",
+ "config": {
+ "tokenizer_or_token_counter": "gpt2",
+ "chunk_size": 512,
+ "chunk_overlap": 128,
+ "min_sentences_per_chunk": 1,
+ },
+ },
+ },
+ },
+ "max_turns_window": 20,
+ "top_k": 5,
+ "enable_textual_memory": True,
+ "enable_activation_memory": False,
+ "enable_parametric_memory": False,
+ }
+ mos_config = MOSConfig(**config)
+ mos = MOS(mos_config)
+
+ mem_cube_config = GeneralMemCubeConfig.model_validate(
+ {
+ "user_id": user_name,
+ "cube_id": user_name,
+ "text_mem": {
+ "backend": "tree_text",
+ "config": {
+ "extractor_llm": {"backend": "openai", "config": openapi_config},
+ "dispatcher_llm": {"backend": "openai", "config": openapi_config},
+ "graph_db": {
+ "backend": "neo4j",
+ "config": {
+ "uri": neo4j_uri,
+ "user": "neo4j",
+ "password": "iaarlichunyu",
+ "db_name": db_name,
+ "user_name": user_name,
+ "use_multi_db": False,
+ "auto_create": True,
+ "embedding_dimension": 3072,
+ },
+ },
+ "embedder": {
+ "backend": "universal_api",
+ "config": {
+ "provider": "openai",
+ "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
+ "model_name_or_path": "text-embedding-3-large",
+ "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
+ },
+ },
+ "reorganize": False,
+ },
+ },
+ "act_mem": {},
+ "para_mem": {},
+ }
+ )
+ mem_cube = GeneralMemCube(mem_cube_config)
+
+ temp_dir = "tmp/" + doc_file
+ if not os.path.exists(temp_dir) or not os.listdir(temp_dir):
+ mem_cube.dump(temp_dir)
+
+ mos.register_mem_cube(temp_dir, mem_cube_id=user_name)
+
+ with open("evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md") as f:
+ prompt = f.read()
+
+ samples_res = []
+ doc_samples = [s for s in samples if s.get("doc_id") == doc_file]
+ if len(doc_samples) == 0:
+ return []
+
+ for sample in tqdm(doc_samples, desc=f"Processing {doc_file}"):
+ if (doc_file, sample.get("question")) in completed_pairs:
+ continue
+ messages = sample["question"]
+ try_cnt, is_success = 0, False
+
+ while True:
+ try:
+ mos.clear_messages()
+ response = mos.chat(messages, user_name)
+ is_success = True
+ except Exception as e:
+ print(f"[{doc_file}] Error:", e)
+ traceback.print_exc()
+ try_cnt += 1
+ response = "Failed"
+ if is_success or try_cnt > 5:
+ break
+
+ sample["response"] = response
+ extracted_res = extract_answer(sample["question"], response, prompt)
+ sample["extracted_res"] = extracted_res
+
+ pred_ans = extracted_res.split("Answer format:")[0].split("Extracted answer:")[1].strip()
+ score = eval_score(sample["answer"], pred_ans, sample["answer_format"])
+
+ sample["pred"] = pred_ans
+ sample["score"] = score
+ samples_res.append(sample)
+
+ print("--------------------------------------")
+ print(f"Question: {sample['question']}")
+ print(f"Response: {sample['response']}")
+ print(f"Ground true: {sample['answer']}\tPred: {sample['pred']}\tScore: {sample['score']}")
+
+ return samples_res
+
+
+if __name__ == "__main__":
+ results = _load_existing_results()
+ total_samples = len(samples)
+ processed_samples = len(completed_pairs)
+ pending_samples = total_samples - processed_samples
+ sample_doc_ids = [s.get("doc_id") for s in samples if s.get("doc_id")]
+ all_docs_in_samples = set(sample_doc_ids)
+ processed_docs = {d for d, _ in completed_pairs}
+ with ThreadPoolExecutor(max_workers=4) as executor:
+ pending_docs = [d for d in doc_paths if _doc_has_pending(d)]
+ print("\n" + "=" * 80)
+ print("📊 评测进度统计")
+ print("=" * 80)
+ print(f"✅ 已加载历史结果: {len(results)} 条")
+ print(f"📂 数据集总样本: {total_samples}")
+ print(f"🧪 已完成样本: {processed_samples}")
+ print(f"⏳ 待处理样本: {pending_samples}")
+ print(f"📄 数据集中总文档: {len(all_docs_in_samples)}")
+ print(f"✔️ 已完成文档: {len(processed_docs)}")
+ print(f"➡️ 待处理文档(本次将运行): {len(pending_docs)}")
+ print("=" * 80 + "\n")
+ future_to_doc = {
+ executor.submit(process_doc, doc_file): doc_file for doc_file in pending_docs
+ }
+
+ for future in as_completed(future_to_doc):
+ doc_file = future_to_doc[future]
+ try:
+ res = future.result()
+ results.extend(res)
+
+ if len(res) > 0:
+ acc, f1 = eval_acc_and_f1(results)
+ print()
+ print(f"Avg acc: {acc}")
+ print(f"Avg f1: {f1}")
+
+ with open(RESULTS_PATH, "w", encoding="utf-8") as f:
+ json.dump(results, f, ensure_ascii=False, indent=2)
+ except Exception as e:
+ print(f"[{doc_file}] failed with {e}")
+
+ acc, f1 = eval_acc_and_f1(results)
+ print("--------------------------------------")
+ print(f"Final avg acc: {acc}")
+ print(f"Final avg f1: {f1}")
+
+ show_results(
+ results,
+        show_path=re.sub(r"\.json$", "_report.txt", RESULTS_PATH),
+ )
diff --git a/evaluation/scripts/mmlongbench/import_docs.py b/evaluation/scripts/mmlongbench/import_docs.py
new file mode 100644
index 000000000..540c8f960
--- /dev/null
+++ b/evaluation/scripts/mmlongbench/import_docs.py
@@ -0,0 +1,88 @@
+import asyncio
+import os
+import traceback
+import uuid
+
+from memos import log
+from memos.configs.mem_reader import SimpleStructMemReaderConfig
+from memos.configs.memory import TreeTextMemoryConfig
+from memos.mem_reader.simple_struct import SimpleStructMemReader
+from memos.memories.textual.tree import TreeTextMemory
+
+
+logger = log.get_logger(__name__)
+db_name = "stx-mmlongbench-004"
+# Create a memory reader instance
+reader_config = SimpleStructMemReaderConfig.from_json_file(
+ "examples/data/config/simple_struct_reader_config.json"
+)
+reader = SimpleStructMemReader(reader_config)
+
+tree_config = TreeTextMemoryConfig.from_json_file(
+ "examples/data/config/tree_config_shared_database.json"
+)
+tree_config.graph_db.config.db_name = db_name
+# Processing Documents
+_ppt_dir = "ppt_test_result"  # results of a previous import run; may not exist yet
+existing_names = (
+    {d for d in os.listdir(_ppt_dir) if os.path.isdir(os.path.join(_ppt_dir, d))}
+    if os.path.isdir(_ppt_dir)
+    else set()
+)
+doc_paths = []
+for f in os.listdir("evaluation/data/mmlongbench/documents"):
+ fp = os.path.join("evaluation/data/mmlongbench/documents", f)
+ if os.path.isfile(fp):
+ name = os.path.splitext(f)[0]
+ if name in existing_names:
+ continue
+ doc_paths.append(fp)
+
+print("existing_names length:", len(existing_names))
+print("doc_paths length:", len(doc_paths))
+
+
+async def process_doc(doc_path):
+ print(f"🔄 Processing document: {doc_path}")
+ # Generate random user id: 'user_' + random short hex
+ user_id = "user_" + uuid.uuid4().hex[:8]
+ # Persist mapping between user_id and doc_path
+ try:
+ os.makedirs("evaluation/data/mmlongbench", exist_ok=True)
+ with open("evaluation/data/mmlongbench/user_doc_map.csv", "a", encoding="utf-8") as f:
+ f.write(f"{user_id},{doc_path}\n")
+ except Exception as e:
+ logger.error(f"Failed to write user-doc mapping: {e}")
+
+ tree_config.graph_db.config.user_name = user_id
+ my_tree_textual_memory = TreeTextMemory(tree_config)
+ doc_memory = await reader.get_memory(
+ [doc_path], "doc", info={"user_id": user_id, "session_id": "session_" + str(uuid.uuid4())}
+ )
+
+ count = 0
+ for m_list in doc_memory:
+ count += len(m_list)
+ my_tree_textual_memory.add(m_list)
+ print("total memories: ", count)
+
+ return doc_path
+
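+# The user_id -> document mapping written above is what eval_docs.py later
+# reads (get_user_name) to query each document's memories under the right user.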
+
+async def main():
+ batch_size = 4
+ for i in range(0, len(doc_paths), batch_size):
+ batch = doc_paths[i : i + batch_size]
+ print(f"🚀 Starting batch {i // batch_size + 1} with {len(batch)} docs")
+
+ tasks = [process_doc(p) for p in batch]
+ results = await asyncio.gather(*tasks, return_exceptions=True)
+
+ for p, result in zip(batch, results, strict=False):
+ if isinstance(result, Exception):
+ print(f"❌ Error processing {p}: {result}")
+ tb_text = "".join(traceback.TracebackException.from_exception(result).format())
+ print(tb_text)
+ else:
+ print(f"✅ Finished {result}")
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/evaluation/scripts/mmlongbench/models/__init__.py b/evaluation/scripts/mmlongbench/models/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/evaluation/scripts/mmlongbench/models/minicpm_llama3.py b/evaluation/scripts/mmlongbench/models/minicpm_llama3.py
new file mode 100644
index 000000000..7f6d4b743
--- /dev/null
+++ b/evaluation/scripts/mmlongbench/models/minicpm_llama3.py
@@ -0,0 +1,56 @@
+import torch
+
+from PIL import Image
+from transformers import AutoModel, AutoTokenizer
+
+
+def init_model(cache_path):
+ model_path = (
+ cache_path
+ if (cache_path is not None and cache_path != "None")
+ else "openbmb/MiniCPM-Llama3-V-2_5"
+ )
+ model = AutoModel.from_pretrained(
+ model_path,
+ torch_dtype=torch.bfloat16,
+ low_cpu_mem_usage=True,
+ trust_remote_code=True,
+ device_map="auto",
+ ).eval()
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ model.tokenizer = tokenizer
+ return model
+
+
+def get_response_concat(model, question, image_path_list, max_new_tokens=1024, temperature=1.0):
+ msgs = []
+ system_prompt = "Answer in detail."
+ if system_prompt:
+ msgs.append({"type": "text", "value": system_prompt})
+ if isinstance(image_path_list, list):
+ msgs.extend([{"type": "image", "value": p} for p in image_path_list])
+ else:
+ msgs = [{"type": "image", "value": image_path_list}]
+ msgs.append({"type": "text", "value": question})
+
+ content = []
+ for x in msgs:
+ if x["type"] == "text":
+ content.append(x["value"])
+ elif x["type"] == "image":
+ image = Image.open(x["value"]).convert("RGB")
+ content.append(image)
+ msgs = [{"role": "user", "content": content}]
+
+ with torch.cuda.amp.autocast():
+ res = model.chat(
+ msgs=msgs,
+ context=None,
+ image=None,
+ max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ do_sample=temperature != 0.0,
+ tokenizer=model.tokenizer,
+ )
+ return res
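+
+
+# Hypothetical usage sketch (file names are placeholders, not benchmark data):
+#   model = init_model(None)  # falls back to "openbmb/MiniCPM-Llama3-V-2_5"
+#   answer = get_response_concat(
+#       model, "What does the chart on page 3 show?", ["p1.png", "p2.png", "p3.png"]
+#   )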
diff --git a/evaluation/scripts/mmlongbench/multimodal_test.py b/evaluation/scripts/mmlongbench/multimodal_test.py
new file mode 100644
index 000000000..929215229
--- /dev/null
+++ b/evaluation/scripts/mmlongbench/multimodal_test.py
@@ -0,0 +1,185 @@
+import os
+import shutil
+
+from dotenv import load_dotenv
+
+from memos.configs.mem_cube import GeneralMemCubeConfig
+from memos.configs.mem_os import MOSConfig
+from memos.mem_cube.general import GeneralMemCube
+from memos.mem_os.main import MOS
+
+
+load_dotenv()
+
+db_name = "stx-mmlongbench-002"
+user_id = "user_dc812220"
+
+# 1.1 Set openai config
+openapi_config = {
+ "model_name_or_path": "gpt-4o",
+ "top_k": 50,
+ "remove_think_prefix": True,
+ "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
+ "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
+}
+# 1.2 Set neo4j config
+neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
+
+# 1.3 Create MOS Config
+config = {
+ "user_id": user_id,
+ "chat_model": {
+ "backend": "openai",
+ "config": openapi_config,
+ },
+ "mem_reader": {
+ "backend": "simple_struct",
+ "config": {
+ "llm": {
+ "backend": "openai",
+ "config": openapi_config,
+ },
+ "embedder": {
+ "backend": "universal_api",
+ "config": {
+ "provider": "openai",
+ "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
+ "model_name_or_path": "text-embedding-3-large",
+ "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
+ },
+ },
+ "chunker": {
+ "backend": "sentence",
+ "config": {
+ "tokenizer_or_token_counter": "gpt2",
+ "chunk_size": 512,
+ "chunk_overlap": 128,
+ "min_sentences_per_chunk": 1,
+ },
+ },
+ },
+ },
+ "max_turns_window": 20,
+ "top_k": 5,
+ "enable_textual_memory": True,
+ "enable_activation_memory": False,
+ "enable_parametric_memory": False,
+}
+
+mos_config = MOSConfig(**config)
+mos = MOS(mos_config)
+
+config = GeneralMemCubeConfig.model_validate(
+ {
+ "user_id": user_id,
+ "cube_id": f"{user_id}",
+ "text_mem": {
+ "backend": "tree_text",
+ "config": {
+ "extractor_llm": {
+ "backend": "openai",
+ "config": openapi_config,
+ },
+ "dispatcher_llm": {
+ "backend": "openai",
+ "config": openapi_config,
+ },
+ "graph_db": {
+ "backend": "neo4j",
+ "config": {
+ "uri": neo4j_uri,
+ "user": "neo4j",
+ "password": "iaarlichunyu",
+ "db_name": db_name,
+ "user_name": user_id,
+ "use_multi_db": False,
+ "auto_create": True,
+ "embedding_dimension": 3072,
+ },
+ },
+ "embedder": {
+ "backend": "universal_api",
+ "config": {
+ "provider": "openai",
+ "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
+ "model_name_or_path": "text-embedding-3-large",
+ "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
+ },
+ },
+ "reorganize": False,
+ },
+ },
+ "act_mem": {},
+ "para_mem": {},
+ },
+)
+
+
+# Filter out embedding fields, keeping only necessary fields
+def filter_memory_data(memories_data):
+ filtered_data = {}
+ for key, value in memories_data.items():
+ if key == "text_mem":
+ filtered_data[key] = []
+ for mem_group in value:
+ # Check if it's the new data structure (list of TextualMemoryItem objects)
+ if "memories" in mem_group and isinstance(mem_group["memories"], list):
+ # New data structure: directly a list of TextualMemoryItem objects
+ filtered_memories = []
+ for memory_item in mem_group["memories"]:
+ # Create filtered dictionary
+ filtered_item = {
+ "id": memory_item.id,
+ "memory": memory_item.memory,
+ "metadata": {},
+ }
+ # Filter metadata, excluding embedding
+ if hasattr(memory_item, "metadata") and memory_item.metadata:
+ for attr_name in dir(memory_item.metadata):
+ if not attr_name.startswith("_") and attr_name != "embedding":
+ attr_value = getattr(memory_item.metadata, attr_name)
+ if not callable(attr_value):
+ filtered_item["metadata"][attr_name] = attr_value
+ filtered_memories.append(filtered_item)
+
+ filtered_group = {
+ "cube_id": mem_group.get("cube_id", ""),
+ "memories": filtered_memories,
+ }
+ filtered_data[key].append(filtered_group)
+ else:
+ # Old data structure: dictionary with nodes and edges
+ filtered_group = {
+ "memories": {"nodes": [], "edges": mem_group["memories"].get("edges", [])}
+ }
+ for node in mem_group["memories"].get("nodes", []):
+ filtered_node = {
+ "id": node.get("id"),
+ "memory": node.get("memory"),
+ "metadata": {
+ k: v
+ for k, v in node.get("metadata", {}).items()
+ if k != "embedding"
+ },
+ }
+ filtered_group["memories"]["nodes"].append(filtered_node)
+ filtered_data[key].append(filtered_group)
+ else:
+ filtered_data[key] = value
+ return filtered_data
+
+
+mem_cube = GeneralMemCube(config)
+
+temp_dir = f"/tmp/{user_id}"
+if os.path.exists(temp_dir):
+ shutil.rmtree(temp_dir)
+mem_cube.dump(temp_dir)
+mos.register_mem_cube(temp_dir, mem_cube_id=user_id)
+
+
+print("start answering...")
+user_query = "图8美股变化的影响是什么"
+print(f"👤 User query: {user_query}")
+response = mos.chat(user_query)
+print(f"🤖 Response: {response}")
diff --git a/evaluation/scripts/xinyu/eval/__init__.py b/evaluation/scripts/xinyu/eval/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/evaluation/scripts/xinyu/eval/eval_score_llm.py b/evaluation/scripts/xinyu/eval/eval_score_llm.py
new file mode 100644
index 000000000..f5764ce39
--- /dev/null
+++ b/evaluation/scripts/xinyu/eval/eval_score_llm.py
@@ -0,0 +1,279 @@
+import os
+import re
+import traceback
+
+from collections import defaultdict
+from math import isclose
+
+from memos.configs.mem_os import MOSConfig
+from memos.llms.factory import LLMFactory
+
+
+openapi_config = {
+ "model_name_or_path": "gpt-5-nano",
+ "top_k": 50,
+ "remove_think_prefix": True,
+ "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
+ "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
+}
+config = {
+ "user_id": "user_name",
+ "chat_model": {
+ "backend": "openai",
+ "config": openapi_config,
+ },
+ "mem_reader": {
+ "backend": "simple_struct",
+ "config": {
+ "llm": {"backend": "openai", "config": openapi_config},
+ "embedder": {
+ "backend": "universal_api",
+ "config": {
+ "provider": "openai",
+ "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
+ "model_name_or_path": "text-embedding-3-large",
+ "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
+ },
+ },
+ "chunker": {
+ "backend": "sentence",
+ "config": {
+ "tokenizer_or_token_counter": "gpt2",
+ "chunk_size": 512,
+ "chunk_overlap": 128,
+ "min_sentences_per_chunk": 1,
+ },
+ },
+ },
+ },
+ "max_turns_window": 20,
+ "top_k": 5,
+ "enable_textual_memory": True,
+ "enable_activation_memory": False,
+ "enable_parametric_memory": False,
+}
+mos_config = MOSConfig(**config)
+chat_llm = LLMFactory.from_config(mos_config.chat_model)
+
+
+def is_float_equal(
+    reference, prediction, include_percentage: bool = False, is_close: bool = False
+) -> bool:
+ def get_precision(gt_ans: float) -> int:
+ precision = 3
+ if "." in str(gt_ans):
+ precision = len(str(gt_ans).split(".")[-1])
+ return precision
+
+ reference = float(str(reference).strip().rstrip("%").strip())
+ try:
+ prediction = float(str(prediction).strip().rstrip("%").strip())
+ except Exception:
+ return False
+
+ gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference]
+ for item in gt_result:
+ try:
+ if is_close and isclose(item, prediction, rel_tol=0.01):
+ return True
+ precision = max(min(get_precision(prediction), get_precision(item)), 2)
+ if round(prediction, precision) == round(item, precision):
+ return True
+ except Exception:
+ continue
+ return False
+
+
+def get_clean_string(s):
+ s = str(s).lower().strip()
+
+ for suffix in ["mile", "miles", "million"]:
+ if s.endswith(suffix):
+ s = s[: -len(suffix)].strip()
+
+ s = re.sub(r"\s*\([^)]*\)", "", s).strip()
+ s = re.sub(r"^['\"]|['\"]$", "", s).strip()
+ s = s.lstrip("$").rstrip("%").strip()
+
+ return s
+
+
+def is_exact_match(s):
+ flag = False
+ # Website
+ if "https://" in s:
+ flag = True
+ # code file
+ if s.endswith((".py", ".ipynb")) or s.startswith("page"):
+ flag = True
+ # telephone number
+ if re.fullmatch(r"\b\d+(-\d+|\s\d+)?\b", s):
+ flag = True
+ # time
+ if "a.m." in s or "p.m." in s:
+ flag = True
+ # YYYY-MM-DD
+ if re.fullmatch(r"\b\d{4}[-\s]\d{2}[-\s]\d{2}\b", s):
+ flag = True
+ # YYYY-MM
+ if re.fullmatch(r"\b\d{4}[-\s]\d{2}\b", s):
+ flag = True
+ # Email address
+ if re.fullmatch(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", s):
+ flag = True
+ return flag
+
+
+def isfloat(num):
+ try:
+ float(num)
+ return True
+ except ValueError:
+ return False
+
+
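+# eval_score asks an LLM judge to grade a student answer against the ground
+# truth using a Chinese rubric (roughly: 0 = irrelevant/wrong, 0.5 = partially
+# correct, 0.7 = mostly correct but incomplete, 1 = fully correct). The judge
+# is asked to reply with JSON containing the keys "分数" (score) and
+# "理由" (reason), which is parsed below.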
+def eval_score(question, gt, pred):
+ prompt = """
+ 你是一个评委,根据问题和标准答案对学生的答案进行打分。打分规则如下:
+
+ 完全不对(0分):
+ 学生答案与问题无关,未展示出任何相关概念或知识。
+ 对了一部分(0.5分):
+ 学生答案提供了一些相关信息,但未能直接回答问题。
+ 答案中包含部分正确内容,但缺乏关键信息,导致整体理解不清。
+ 基本正确(0.7分):
+ 学生答案提供了大部分关键信息,不过依然距离标准答案有一定缺失。
+ 答案中包含部分关键内容,但缺乏部分信息,导致不够完整。
+ 完全正确(1分):
+ 学生答案准确地回答了问题,涵盖所有关键信息。
+ 表达清晰,逻辑合理,直接且有效地回应了问题。
+
+ 问题:{}
+
+ 标准答案:{}
+
+ 学生答案:{}
+ """
+
+ max_try = 20
+ try_i = 0
+ while try_i < max_try:
+ try:
+ llm_input_prompt_score = (
+ prompt.format(question, gt, pred)
+ + """请返回给我一个json:
+ {
+ "分数": 1,
+ "理由": "xxxx"
+ }"""
+ )
+ score = chat_llm.generate(
+ [
+ {"role": "user", "content": llm_input_prompt_score},
+ ]
+ )
+
+ print(f"score: {score}")
+ score_real = eval(score.replace("json", "").replace("\n", "").replace("```", ""))
+ return float(score_real["分数"])
+ except Exception:
+ traceback.print_exc()
+ print(f"trying num {try_i}")
+ try_i += 1
+ return -1
+
+
+def eval_acc_and_f1(samples):
+ evaluated_samples = [sample for sample in samples if "score" in sample]
+ if not evaluated_samples:
+ return 0.0, 0.0
+
+ acc = sum([sample["score"] for sample in evaluated_samples]) / len(evaluated_samples)
+ try:
+ recall = sum(
+ [
+ sample["score"]
+ for sample in evaluated_samples
+ if sample["answer"] != "Not answerable"
+ ]
+ ) / len([sample for sample in evaluated_samples if sample["answer"] != "Not answerable"])
+ precision = sum(
+ [
+ sample["score"]
+ for sample in evaluated_samples
+ if sample["answer"] != "Not answerable"
+ ]
+ ) / len([sample for sample in evaluated_samples if sample["pred"] != "Not answerable"])
+ f1 = 2 * recall * precision / (recall + precision) if (recall + precision) > 0.0 else 0.0
+ except Exception:
+ f1 = 0.0
+
+ return acc, f1
+
+
+def show_results(samples, show_path=None):
+ for sample in samples:
+ sample["evidence_pages"] = eval(sample["evidence_pages"])
+ sample["evidence_sources"] = eval(sample["evidence_sources"])
+
+ with open(show_path, "w") as f:
+ acc, f1 = eval_acc_and_f1(samples)
+ f.write(f"Overall Acc: {acc} | Question Number: {len(samples)}\n")
+ f.write(f"Overall F1-score: {f1} | Question Number: {len(samples)}\n")
+ f.write("-----------------------\n")
+
+ acc_single_page, _ = eval_acc_and_f1(
+ [sample for sample in samples if len(sample["evidence_pages"]) == 1]
+ )
+ acc_multi_page, _ = eval_acc_and_f1(
+ [
+ sample
+ for sample in samples
+ if len(sample["evidence_pages"]) != 1 and sample["answer"] != "Not answerable"
+ ]
+ )
+ acc_neg, _ = eval_acc_and_f1(
+ [sample for sample in samples if sample["answer"] == "Not answerable"]
+ )
+
+ f.write(
+ "Single-page | Accuracy: {} | Question Number: {}\n".format(
+ acc_single_page,
+ len([sample for sample in samples if len(sample["evidence_pages"]) == 1]),
+ )
+ )
+ f.write(
+ "Cross-page | Accuracy: {} | Question Number: {}\n".format(
+ acc_multi_page,
+ len(
+ [
+ sample
+ for sample in samples
+ if len(sample["evidence_pages"]) != 1
+ and sample["answer"] != "Not answerable"
+ ]
+ ),
+ )
+ )
+ f.write(
+ "Unanswerable | Accuracy: {} | Question Number: {}\n".format(
+ acc_neg, len([sample for sample in samples if sample["answer"] == "Not answerable"])
+ )
+ )
+ f.write("-----------------------\n")
+
+ source_sample_dict, document_type_dict = defaultdict(list), defaultdict(list)
+ for sample in samples:
+ for answer_source in sample["evidence_sources"]:
+ source_sample_dict[answer_source].append(sample)
+ document_type_dict[sample["doc_type"]].append(sample)
+ for type, sub_samples in source_sample_dict.items():
+ f.write(
+ f"Evidence Sources: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n"
+ )
+
+ f.write("-----------------------\n")
+ for type, sub_samples in document_type_dict.items():
+ f.write(
+ f"Document Type: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n"
+ )
diff --git a/evaluation/scripts/xinyu/eval_docs.py b/evaluation/scripts/xinyu/eval_docs.py
new file mode 100644
index 000000000..03a333201
--- /dev/null
+++ b/evaluation/scripts/xinyu/eval_docs.py
@@ -0,0 +1,228 @@
+import csv
+import json
+import os
+import re
+import traceback
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+from dotenv import load_dotenv
+
+from evaluation.scripts.mmlongbench.eval.extract_answer import extract_answer
+from evaluation.scripts.xinyu.eval.eval_score_llm import eval_acc_and_f1, eval_score, show_results
+from memos.configs.mem_cube import GeneralMemCubeConfig
+from memos.configs.mem_os import MOSConfig
+from memos.mem_cube.general import GeneralMemCube
+from memos.mem_os.main import MOS
+
+
+load_dotenv()
+openapi_config = {
+ "model_name_or_path": "gpt-4o",
+ "top_k": 50,
+ "remove_think_prefix": True,
+ "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
+ "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
+}
+neo4j_uri = os.getenv("NEO4J_URI", "bolt://47.117.41.207:7687")
+db_name = "stx-mmlongbench-003"
+doc_paths = [
+ f
+ for f in os.listdir("evaluation/data/xinyu/documents")
+ if os.path.isfile(os.path.join("evaluation/data/xinyu/documents", f))
+]
+
+with open("evaluation/data/xinyu/all_samples_with_gt.json") as f:
+ samples = json.load(f)
+
+
+def get_user_name(doc_file):
+ csv_path = "evaluation/data/xinyu/user_doc_map.csv"
+ if os.path.exists(csv_path):
+ with open(csv_path, newline="", encoding="utf-8") as f:
+ reader = csv.reader(f)
+ for row in reader:
+ uid, path = row[0], row[1]
+ base = os.path.basename(path)
+ if base == doc_file or os.path.splitext(base)[0] == os.path.splitext(doc_file)[0]:
+ return uid
+ return ""
+
+
+def process_doc(doc_file):
+ user_name = get_user_name(doc_file)
+ print(user_name, doc_file)
+ config = {
+ "user_id": user_name,
+ "chat_model": {
+ "backend": "openai",
+ "config": openapi_config,
+ },
+ "mem_reader": {
+ "backend": "simple_struct",
+ "config": {
+ "llm": {"backend": "openai", "config": openapi_config},
+ "embedder": {
+ "backend": "universal_api",
+ "config": {
+ "provider": "openai",
+ "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
+ "model_name_or_path": "text-embedding-3-large",
+ "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
+ },
+ },
+ "chunker": {
+ "backend": "sentence",
+ "config": {
+ "tokenizer_or_token_counter": "gpt2",
+ "chunk_size": 512,
+ "chunk_overlap": 128,
+ "min_sentences_per_chunk": 1,
+ },
+ },
+ },
+ },
+ "max_turns_window": 20,
+ "top_k": 5,
+ "enable_textual_memory": True,
+ "enable_activation_memory": False,
+ "enable_parametric_memory": False,
+ }
+ mos_config = MOSConfig(**config)
+ mos = MOS(mos_config)
+
+ mem_cube_config = GeneralMemCubeConfig.model_validate(
+ {
+ "user_id": user_name,
+ "cube_id": user_name,
+ "text_mem": {
+ "backend": "tree_text",
+ "config": {
+ "extractor_llm": {"backend": "openai", "config": openapi_config},
+ "dispatcher_llm": {"backend": "openai", "config": openapi_config},
+ "graph_db": {
+ "backend": "neo4j",
+ "config": {
+ "uri": neo4j_uri,
+ "user": "neo4j",
+ "password": "iaarlichunyu",
+ "db_name": db_name,
+ "user_name": user_name,
+ "use_multi_db": False,
+ "auto_create": True,
+ "embedding_dimension": 3072,
+ },
+ },
+ "embedder": {
+ "backend": "universal_api",
+ "config": {
+ "provider": "openai",
+ "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
+ "model_name_or_path": "text-embedding-3-large",
+ "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
+ },
+ },
+ "reorganize": False,
+ },
+ },
+ "act_mem": {},
+ "para_mem": {},
+ }
+ )
+ mem_cube = GeneralMemCube(mem_cube_config)
+
+ temp_dir = os.path.join("tmp", doc_file)
+
+ if (not os.path.exists(temp_dir)) or (not os.listdir(temp_dir)):
+ mem_cube.dump(temp_dir)
+
+ mos.register_mem_cube(temp_dir, mem_cube_id=user_name)
+
+ with open("evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md") as f:
+ prompt = f.read()
+
+ samples_res = []
+ doc_samples = [s for s in samples if s.get("doc_id") == doc_file]
+
+ if len(doc_samples) == 0:
+ return []
+
+ sample = doc_samples[0]
+ question_list = sample["question"]
+ answer_list = sample["answer"]
+
+ for idx, question in enumerate(question_list):
+ gt = answer_list.get(str(idx))
+
+ try_cnt, is_success = 0, False
+ while True:
+ try:
+ mos.clear_messages()
+ response = mos.chat(question, user_name)
+ is_success = True
+ except Exception as e:
+ print(f"[{doc_file}] Error:", e)
+ traceback.print_exc()
+ try_cnt += 1
+ response = "Failed"
+ if is_success or try_cnt > 5:
+ break
+
+ sample_item = dict(sample)
+ sample_item["question"] = question
+ sample_item["answer"] = gt
+ sample_item["response"] = response
+
+ extracted_res = extract_answer(sample_item["question"], response, prompt)
+ sample_item["extracted_res"] = extracted_res
+
+ print("--------------------------------------")
+ pred_ans = extracted_res.split("Answer format:")[0].split("Extracted answer:")[1].strip()
+ score = eval_score(question, gt, response)
+
+ sample_item["pred"] = pred_ans
+ sample_item["score"] = score
+ samples_res.append(sample_item)
+
+ print(f"Question: {question}")
+ print(f"Response: {sample_item['response']}")
+ print(f"Ground true: {gt}\tPred: {sample_item['pred']}\tScore: {sample_item['score']}")
+
+ print("samples_res length: ", len(samples_res))
+ return samples_res
+
+
+if __name__ == "__main__":
+ results = []
+
+ with ThreadPoolExecutor(max_workers=4) as executor:
+ future_to_doc = {executor.submit(process_doc, doc_file): doc_file for doc_file in doc_paths}
+
+ for future in as_completed(future_to_doc):
+ doc_file = future_to_doc[future]
+ try:
+ res = future.result()
+ results.extend(res)
+
+ if len(res) > 0:
+ acc, f1 = eval_acc_and_f1(results)
+ print()
+ print(f"Avg acc: {acc}")
+ print(f"Avg f1: {f1}")
+
+ with open("evaluation/data/xinyu/test_results.json", "w", encoding="utf-8") as f:
+ json.dump(results, f, ensure_ascii=False, indent=2)
+
+ except Exception as e:
+ print(f"[{doc_file}] failed with {e}")
+ traceback.print_exc()
+
+ acc, f1 = eval_acc_and_f1(results)
+ print("--------------------------------------")
+ print(f"Final avg acc: {acc}")
+ print(f"Final avg f1: {f1}")
+
+ show_results(
+ results,
+        show_path=re.sub(r"\.json$", "_report.txt", "evaluation/data/xinyu/test_results.json"),
+ )
diff --git a/evaluation/scripts/xinyu/import_docs.py b/evaluation/scripts/xinyu/import_docs.py
new file mode 100644
index 000000000..f9d2619a4
--- /dev/null
+++ b/evaluation/scripts/xinyu/import_docs.py
@@ -0,0 +1,85 @@
+import asyncio
+import os
+import traceback
+import uuid
+
+from memos import log
+from memos.configs.mem_reader import SimpleStructMemReaderConfig
+from memos.configs.memory import TreeTextMemoryConfig
+from memos.mem_reader.simple_struct import SimpleStructMemReader
+from memos.memories.textual.tree import TreeTextMemory
+
+
+logger = log.get_logger(__name__)
+db_name = "stx-mmlongbench-003"
+# Create a memory reader instance
+reader_config = SimpleStructMemReaderConfig.from_json_file(
+ "examples/data/config/simple_struct_reader_config.json"
+)
+reader = SimpleStructMemReader(reader_config)
+
+tree_config = TreeTextMemoryConfig.from_json_file(
+ "examples/data/config/tree_config_shared_database.json"
+)
+tree_config.graph_db.config.db_name = db_name
+# Processing Documents
+_ppt_dir = "ppt_test_result"  # results of a previous import run; may not exist yet
+existing_names = (
+    {d for d in os.listdir(_ppt_dir) if os.path.isdir(os.path.join(_ppt_dir, d))}
+    if os.path.isdir(_ppt_dir)
+    else set()
+)
+doc_paths = []
+for f in os.listdir("evaluation/data/xinyu/documents"):
+ fp = os.path.join("evaluation/data/xinyu/documents", f)
+ if os.path.isfile(fp):
+ name = os.path.splitext(f)[0]
+ if name in existing_names:
+ continue
+ doc_paths.append(fp)
+
+
+async def process_doc(doc_path):
+ print(f"🔄 Processing document: {doc_path}")
+ doc_file = doc_path.split("/")[-1].rsplit(".", 1)[0]
+
+ # Generate random user id: 'user_' + random short hex
+ user_id = "user_" + uuid.uuid4().hex[:8]
+ # Persist mapping between user_id and doc_path
+ with open("evaluation/data/xinyu/user_doc_map.csv", "a", encoding="utf-8") as f:
+ f.write(f"{user_id},{doc_path}\n")
+
+ tree_config.graph_db.config.user_name = user_id
+ temp_dir = "tmp/" + doc_file
+ my_tree_textual_memory = TreeTextMemory(tree_config)
+ doc_memory = await reader.get_memory(
+ [doc_path], "doc", info={"user_id": user_id, "session_id": "session_" + str(uuid.uuid4())}
+ )
+
+ count = 0
+ for m_list in doc_memory:
+ count += len(m_list)
+ my_tree_textual_memory.add(m_list)
+ print("total memories: ", count)
+
+ my_tree_textual_memory.dump(temp_dir)
+ return doc_path
+
+
+async def main():
+ batch_size = 2
+ for i in range(0, len(doc_paths), batch_size):
+ batch = doc_paths[i : i + batch_size]
+ print(f"🚀 Starting batch {i // batch_size + 1} with {len(batch)} docs")
+
+ tasks = [process_doc(p) for p in batch]
+ results = await asyncio.gather(*tasks, return_exceptions=True)
+
+ for p, result in zip(batch, results, strict=False):
+ if isinstance(result, Exception):
+ print(f"❌ Error processing {p}: {result}")
+ tb_text = "".join(traceback.TracebackException.from_exception(result).format())
+ print(tb_text)
+ else:
+ print(f"✅ Finished {result}")
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
From f377cb41556798b3bf6de940aa73808c71f620c9 Mon Sep 17 00:00:00 2001
From: stx <31013941@qq.com>
Date: Sun, 4 Jan 2026 10:51:42 +0800
Subject: [PATCH 2/4] feat: add evaluation pipeline
---
.gitignore | 2 +-
.../__init__.py => data/personamem/.gitkeep} | 0
evaluation/scripts/hotpot/data_loader.py | 78 +++
evaluation/scripts/hotpot/hotpot_eval.py | 419 ++++++++--------
.../scripts/hotpot/hotpot_evaluate_v1.py | 17 +-
evaluation/scripts/hotpot/hotpot_ingestion.py | 241 +++++++++
evaluation/scripts/hotpot/hotpot_old.py | 309 ++++++++++++
evaluation/scripts/hotpot/hotpot_search.py | 297 +++++++++++
evaluation/scripts/long_bench-v2/__init__.py | 1 +
.../long_bench-v2/longbench_v2_ingestion.py | 199 ++++++++
.../long_bench-v2/longbench_v2_metric.py | 176 +++++++
.../long_bench-v2/longbench_v2_responses.py | 319 ++++++++++++
.../long_bench-v2/longbench_v2_search.py | 273 ++++++++++
.../long_bench-v2/run_longbench_v2_eval.sh | 110 ++++
.../scripts/long_bench-v2/wait_scheduler.py | 67 +++
.../scripts/longbench_v2/longbench_v2_eval.py | 241 +++++++++
.../longbench_v2/longbench_v2_ingestion.py | 284 +++++++++++
.../scripts/longbench_v2/longbench_v2_old.py | 415 ++++++++++++++++
.../longbench_v2/longbench_v2_search.py | 284 +++++++++++
.../mmlongbench/eval_utils/__init__.py | 0
.../eval_utils/eval_score.py} | 163 +++---
.../scripts/mmlongbench/mmlongbench_eval.py | 470 ++++++++++++++++++
.../mmlongbench/mmlongbench_ingestion.py | 341 +++++++++++++
.../scripts/mmlongbench/mmlongbench_old.py | 403 +++++++++++++++
.../scripts/mmlongbench/mmlongbench_search.py | 383 ++++++++++++++
evaluation/scripts/run_hotpot_eval.sh | 48 ++
evaluation/scripts/run_longbench_v2_eval.sh | 45 ++
evaluation/scripts/run_mmlongbench_eval.sh | 44 ++
evaluation/scripts/utils/client.py | 263 ++++++----
evaluation/scripts/utils/eval_score.py | 246 +++++++++
evaluation/scripts/utils/extract_answer.py | 58 +++
evaluation/scripts/utils/metrics.py | 56 +++
.../utils/prompt_for_answer_extraction.md | 35 ++
evaluation/scripts/utils/prompts.py | 49 ++
evaluation/scripts/xinyu/eval_docs.py | 228 ---------
evaluation/scripts/xinyu/import_docs.py | 85 ----
36 files changed, 5920 insertions(+), 729 deletions(-)
rename evaluation/{scripts/xinyu/eval/__init__.py => data/personamem/.gitkeep} (100%)
create mode 100644 evaluation/scripts/hotpot/data_loader.py
create mode 100644 evaluation/scripts/hotpot/hotpot_ingestion.py
create mode 100644 evaluation/scripts/hotpot/hotpot_old.py
create mode 100644 evaluation/scripts/hotpot/hotpot_search.py
create mode 100644 evaluation/scripts/long_bench-v2/__init__.py
create mode 100644 evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py
create mode 100644 evaluation/scripts/long_bench-v2/longbench_v2_metric.py
create mode 100644 evaluation/scripts/long_bench-v2/longbench_v2_responses.py
create mode 100644 evaluation/scripts/long_bench-v2/longbench_v2_search.py
create mode 100755 evaluation/scripts/long_bench-v2/run_longbench_v2_eval.sh
create mode 100644 evaluation/scripts/long_bench-v2/wait_scheduler.py
create mode 100644 evaluation/scripts/longbench_v2/longbench_v2_eval.py
create mode 100644 evaluation/scripts/longbench_v2/longbench_v2_ingestion.py
create mode 100644 evaluation/scripts/longbench_v2/longbench_v2_old.py
create mode 100644 evaluation/scripts/longbench_v2/longbench_v2_search.py
create mode 100644 evaluation/scripts/mmlongbench/eval_utils/__init__.py
rename evaluation/scripts/{xinyu/eval/eval_score_llm.py => mmlongbench/eval_utils/eval_score.py} (63%)
create mode 100644 evaluation/scripts/mmlongbench/mmlongbench_eval.py
create mode 100644 evaluation/scripts/mmlongbench/mmlongbench_ingestion.py
create mode 100644 evaluation/scripts/mmlongbench/mmlongbench_old.py
create mode 100644 evaluation/scripts/mmlongbench/mmlongbench_search.py
create mode 100755 evaluation/scripts/run_hotpot_eval.sh
create mode 100755 evaluation/scripts/run_longbench_v2_eval.sh
create mode 100755 evaluation/scripts/run_mmlongbench_eval.sh
create mode 100644 evaluation/scripts/utils/eval_score.py
create mode 100644 evaluation/scripts/utils/extract_answer.py
create mode 100644 evaluation/scripts/utils/metrics.py
create mode 100644 evaluation/scripts/utils/prompt_for_answer_extraction.md
delete mode 100644 evaluation/scripts/xinyu/eval_docs.py
delete mode 100644 evaluation/scripts/xinyu/import_docs.py
diff --git a/.gitignore b/.gitignore
index 8319a4d2f..c6b9130e6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,7 +11,7 @@ tmp/
**settings.json**
evaluation/*tmp/
evaluation/results
-evaluation/.env
+.env
!evaluation/configs-example/*.json
evaluation/configs/*
**tree_textual_memory_locomo**
diff --git a/evaluation/scripts/xinyu/eval/__init__.py b/evaluation/data/personamem/.gitkeep
similarity index 100%
rename from evaluation/scripts/xinyu/eval/__init__.py
rename to evaluation/data/personamem/.gitkeep
diff --git a/evaluation/scripts/hotpot/data_loader.py b/evaluation/scripts/hotpot/data_loader.py
new file mode 100644
index 000000000..871981036
--- /dev/null
+++ b/evaluation/scripts/hotpot/data_loader.py
@@ -0,0 +1,78 @@
+import json
+
+from pathlib import Path
+
+from datasets import load_dataset
+
+
+def load_hotpot_data(data_dir: Path | str) -> list[dict]:
+ """
+ Load HotpotQA dataset.
+ If dev_distractor_gold.json exists in data_dir, load it.
+ Otherwise, download from Hugging Face, convert to standard format, save, and load.
+ """
+ data_dir = Path(data_dir)
+ data_dir.mkdir(parents=True, exist_ok=True)
+ file_path = data_dir / "dev_distractor_gold.json"
+
+ if file_path.exists():
+ print(f"Loading local dataset from {file_path}")
+ try:
+ with open(file_path, encoding="utf-8") as f:
+ return json.load(f)
+ except Exception as e:
+ print(f"Failed to load local file: {e}. Re-downloading...")
+
+ print("Downloading HotpotQA dataset from Hugging Face...")
+ try:
+ dataset = load_dataset(
+ "hotpotqa/hotpot_qa", "distractor", split="validation", trust_remote_code=True
+ )
+ except Exception as e:
+ print(f"Failed to download dataset: {e}")
+ raise
+
+ print(f"Processing and saving dataset to {file_path}...")
+ items = []
+ for item in dataset:
+ # Convert HF format to Standard format
+ # ID
+ qid = item.get("id") or item.get("_id")
+
+ # Supporting Facts
+ sp = item.get("supporting_facts")
+ if isinstance(sp, dict):
+ sp_titles = sp.get("title", [])
+ sp_sent_ids = sp.get("sent_id", [])
+ sp_list = list(zip(sp_titles, sp_sent_ids, strict=False))
+ else:
+ sp_list = sp or []
+
+ # Context
+ ctx = item.get("context")
+ if isinstance(ctx, dict):
+ ctx_titles = ctx.get("title", [])
+ ctx_sentences = ctx.get("sentences", [])
+ ctx_list = list(zip(ctx_titles, ctx_sentences, strict=False))
+ else:
+ ctx_list = ctx or []
+
+ new_item = {
+ "_id": qid,
+ "question": item.get("question"),
+ "answer": item.get("answer"),
+ "supporting_facts": sp_list,
+ "context": ctx_list,
+ "type": item.get("type"),
+ "level": item.get("level"),
+ }
+ items.append(new_item)
+
+ try:
+ with open(file_path, "w", encoding="utf-8") as f:
+ json.dump(items, f, ensure_ascii=False, indent=2)
+ print(f"Saved {len(items)} items to {file_path}")
+ except Exception as e:
+ print(f"Failed to save dataset: {e}")
+
+ return items
diff --git a/evaluation/scripts/hotpot/hotpot_eval.py b/evaluation/scripts/hotpot/hotpot_eval.py
index 05ff52349..80315a65f 100644
--- a/evaluation/scripts/hotpot/hotpot_eval.py
+++ b/evaluation/scripts/hotpot/hotpot_eval.py
@@ -1,224 +1,231 @@
+import argparse
+import importlib.util
import json
import os
-import uuid
+import time
from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
from dotenv import load_dotenv
+from openai import OpenAI
from tqdm import tqdm
-from memos.configs.mem_cube import GeneralMemCubeConfig
-from memos.configs.mem_os import MOSConfig
-from memos.mem_cube.general import GeneralMemCube
-from memos.mem_os.main import MOS
+from evaluation.scripts.hotpot.data_loader import load_hotpot_data
+from evaluation.scripts.utils.extract_answer import extract_answer, parse_extracted_answer
+from evaluation.scripts.utils.metrics import Metrics
+from evaluation.scripts.utils.prompts import HOTPOT_ANSWER_PROMPT
load_dotenv()
-db_name = "stx-hotpot-001"
-
-
-user_name = str(uuid.uuid4())
-
-# 1.1 Set openai config
-openapi_config = {
- "model_name_or_path": "gpt-4o-mini",
- "temperature": 0.8,
- "max_tokens": 1024,
- "top_p": 0.9,
- "top_k": 50,
- "remove_think_prefix": True,
- "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
- "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
-}
-# 1.2 Set neo4j config
-neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
-
-# 1.3 Create MOS Config
-config = {
- "user_id": user_name,
- "chat_model": {
- "backend": "openai",
- "config": openapi_config,
- },
- "mem_reader": {
- "backend": "simple_struct",
- "config": {
- "llm": {
- "backend": "openai",
- "config": openapi_config,
- },
- "embedder": {
- "backend": "universal_api",
- "config": {
- "provider": "openai",
- "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
- "model_name_or_path": "text-embedding-3-large",
- "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
- },
- },
- "chunker": {
- "backend": "sentence",
- "config": {
- "tokenizer_or_token_counter": "gpt2",
- "chunk_size": 512,
- "chunk_overlap": 128,
- "min_sentences_per_chunk": 1,
- },
- },
- },
- },
- "max_turns_window": 20,
- "top_k": 5,
- "enable_textual_memory": True,
- "enable_activation_memory": False,
- "enable_parametric_memory": False,
-}
-
-mos_config = MOSConfig(**config)
-# you can set PRO_MODE to True to enable CoT enhancement mos_config.PRO_MODE = True
-mos = MOS(mos_config)
-
-
-# Filter out embedding fields, keeping only necessary fields
-def filter_memory_data(memories_data):
- filtered_data = {}
- for key, value in memories_data.items():
- if key == "text_mem":
- filtered_data[key] = []
- for mem_group in value:
- # Check if it's the new data structure (list of TextualMemoryItem objects)
- if "memories" in mem_group and isinstance(mem_group["memories"], list):
- # New data structure: directly a list of TextualMemoryItem objects
- filtered_memories = []
- for memory_item in mem_group["memories"]:
- # Create filtered dictionary
- filtered_item = {
- "id": memory_item.id,
- "memory": memory_item.memory,
- "metadata": {},
- }
- # Filter metadata, excluding embedding
- if hasattr(memory_item, "metadata") and memory_item.metadata:
- for attr_name in dir(memory_item.metadata):
- if not attr_name.startswith("_") and attr_name != "embedding":
- attr_value = getattr(memory_item.metadata, attr_name)
- if not callable(attr_value):
- filtered_item["metadata"][attr_name] = attr_value
- filtered_memories.append(filtered_item)
-
- filtered_group = {
- "cube_id": mem_group.get("cube_id", ""),
- "memories": filtered_memories,
- }
- filtered_data[key].append(filtered_group)
- else:
- # Old data structure: dictionary with nodes and edges
- filtered_group = {
- "memories": {"nodes": [], "edges": mem_group["memories"].get("edges", [])}
- }
- for node in mem_group["memories"].get("nodes", []):
- filtered_node = {
- "id": node.get("id"),
- "memory": node.get("memory"),
- "metadata": {
- k: v
- for k, v in node.get("metadata", {}).items()
- if k != "embedding"
- },
- }
- filtered_group["memories"]["nodes"].append(filtered_node)
- filtered_data[key].append(filtered_group)
+
+def llm_response(
+ oai_client, chat_model: str, context: str, question: str, question_date: str | None = None
+) -> str:
+ prompt = HOTPOT_ANSWER_PROMPT.format(question=question, context=context)
+ resp = oai_client.chat.completions.create(
+ model=chat_model,
+ messages=[{"role": "system", "content": prompt}],
+ temperature=0,
+ )
+ return resp.choices[0].message.content or ""
+
+
+def _load_json_list(path: Path) -> list[dict]:
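+ """Load search results saved either as a bare list or as a {"results": [...]} object."""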
+ data = json.loads(path.read_text(encoding="utf-8"))
+ if isinstance(data, list):
+ return data
+ if isinstance(data, dict) and isinstance(data.get("results"), list):
+ return data.get("results") or []
+ raise ValueError(f"Invalid json format: {path}")
+
+
+def _save_pred(
+ pred_path: Path, pred_answers: dict, pred_sp: dict, perf: dict | None = None
+) -> None:
+ pred_path.parent.mkdir(parents=True, exist_ok=True)
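+ # Write to a temp file first, then atomically replace, so an interrupted run cannot corrupt existing results.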
+ tmp = pred_path.with_suffix(pred_path.suffix + ".tmp")
+ safe_pred_answers = {
+ k: (v if isinstance(v, str) else ("" if v is None else str(v)))
+ for k, v in pred_answers.items()
+ }
+ obj = {"answer": safe_pred_answers, "sp": pred_sp}
+ if perf is not None:
+ obj["perf"] = perf
+ tmp.write_text(json.dumps(obj, ensure_ascii=False, indent=2), encoding="utf-8")
+ os.replace(tmp, pred_path)
+
+
+def run_eval(pred_path: Path, gold_path: Path):
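+ # Dynamically load the official HotpotQA evaluator so its eval() can be reused without packaging it as a module.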
+ spec = importlib.util.spec_from_file_location(
+ "hotpot_eval_v1", "evaluation/scripts/hotpot/hotpot_evaluate_v1.py"
+ )
+ m = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(m)
+ metrics = m.eval(str(pred_path), str(gold_path))
+
+ # Save metrics to pred_path (beginning of file)
+ try:
+ results_path = pred_path
+ current_data = {}
+ if results_path.exists():
+ with open(results_path, encoding="utf-8") as f:
+ current_data = json.load(f)
+
+ if isinstance(current_data, list):
+ new_data = [metrics, *current_data]
+
+ elif isinstance(current_data, dict):
+ # Put metrics at the beginning
+ new_data = metrics.copy()
+ for k, v in current_data.items():
+ if k not in new_data:
+ new_data[k] = v
else:
- filtered_data[key] = value
- return filtered_data
-
-
-config = GeneralMemCubeConfig.model_validate(
- {
- "user_id": user_name,
- "cube_id": f"{user_name}",
- "text_mem": {
- "backend": "tree_text",
- "config": {
- "extractor_llm": {
- "backend": "openai",
- "config": openapi_config,
- },
- "dispatcher_llm": {
- "backend": "openai",
- "config": openapi_config,
- },
- "graph_db": {
- "backend": "neo4j",
- "config": {
- "uri": neo4j_uri,
- "user": "neo4j",
- "password": "iaarlichunyu",
- "db_name": db_name,
- "auto_create": True,
- },
- },
- "embedder": {
- "backend": "universal_api",
- "config": {
- "provider": "openai",
- "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
- "model_name_or_path": "text-embedding-3-large",
- "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
- },
- },
- "reorganize": True,
- },
+ new_data = metrics
+
+ with open(results_path, "w", encoding="utf-8") as f:
+ json.dump(new_data, f, indent=2, ensure_ascii=False)
+ except Exception as e:
+ print(f"Failed to save metrics to {results_path}: {e}")
+
+
+def evaluate_one(oai_client, row: dict, chat_model: str) -> tuple[str, str, list]:
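+ """Answer one search-result row with the chat model and return (qid, extracted_answer, sp_list)."""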
+ qid = str(row.get("_id"))
+ question = row.get("question") or ""
+ context = row.get("context") or ""
+ sp_list = row.get("sp") or []
+
+ raw_answer = llm_response(oai_client, chat_model, context=context, question=question)
+ extracted_res = extract_answer(question, raw_answer)
+ answer = parse_extracted_answer(extracted_res, raw_answer)
+ return qid, answer, sp_list
+
+
+def main(argv: list[str] | None = None) -> None:
+ parser = argparse.ArgumentParser(
+ description="HotpotQA eval (OpenAI only, read search results)."
+ )
+ parser.add_argument(
+ "--lib",
+ type=str,
+ default="memos",
+ choices=["memos", "mem0", "supermemory"],
+ )
+ parser.add_argument("--workers", type=int, default=8)
+ parser.add_argument("--max_samples", type=int, default=None)
+ parser.add_argument("--version-dir", "-v", default=None, help="Version directory name")
+ parser.add_argument("--chat-model", default=None, help="Chat model name")
+ parser.add_argument("--search-mode", default="fine", help="Search mode")
+
+ args = parser.parse_args(argv)
+
+ output_dir = Path(f"evaluation/data/hotpot/{args.version_dir}")
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ if args.lib == "memos":
+ search_path = output_dir / f"{args.lib}_{args.search_mode}_search_results.json"
+ pred_path = output_dir / f"{args.lib}_{args.search_mode}_search_eval_results.json"
+ else:
+ search_path = output_dir / f"{args.lib}_search_results.json"
+ pred_path = output_dir / f"{args.lib}_eval_results.json"
+ gold_path = Path("evaluation/data/hotpot/dev_distractor_gold.json")
+
+ if not search_path.exists():
+ raise FileNotFoundError(f"Search results not found: {search_path}")
+
+ if not gold_path.exists():
+ load_hotpot_data("evaluation/data/hotpot")
+
+ pred_answers: dict[str, str] = {}
+ pred_sp: dict[str, list] = {}
+ if pred_path.exists():
+ try:
+ prev = json.loads(pred_path.read_text(encoding="utf-8"))
+ if isinstance(prev, dict) and isinstance(prev.get("answer"), dict):
+ pred_answers.update(prev["answer"])
+ if isinstance(prev, dict) and isinstance(prev.get("sp"), dict):
+ pred_sp.update(prev["sp"])
+ except Exception as e:
+ print(f"[Eval] failed to load existing pred: {e}")
+
+ rows = _load_json_list(search_path)
+ if args.max_samples is not None:
+ rows = rows[: args.max_samples]
+
+ pending = [r for r in rows if str(r.get("_id")) not in pred_answers]
+ print(f"[Eval] lib={args.lib} total={len(rows)} pending={len(pending)} workers={args.workers}")
+ if not pending:
+ run_eval(pred_path, gold_path)
+ return
+
+ oai_client = OpenAI(
+ api_key=os.getenv("CHAT_MODEL_API_KEY"), base_url=os.getenv("CHAT_MODEL_BASE_URL")
+ )
+
+ processed = len(pred_answers)
+
+ metrics = Metrics()
+ start_time = time.time()
+
+ print("[Response model]: ", args.chat_model)
+ with ThreadPoolExecutor(max_workers=args.workers) as executor:
+
+ def do_eval(row):
+ st = time.perf_counter()
+ try:
+ res = evaluate_one(oai_client, row, chat_model)
+ dur = time.perf_counter() - st
+ metrics.record(dur, True)
+ return res
+ except Exception as e:
+ dur = time.perf_counter() - st
+ metrics.record(dur, False, str(e))
+ raise e
+
+ futures = [executor.submit(do_eval, row) for row in pending]
+ for idx, f in enumerate(
+ tqdm(as_completed(futures), total=len(futures), desc="Evaluating"), 1
+ ):
+ try:
+ qid, answer, sp_list = f.result()
+ pred_answers[qid] = answer
+ pred_sp[qid] = sp_list
+ processed += 1
+ if idx % 20 == 0:
+ _save_pred(pred_path, pred_answers, pred_sp)
+ except Exception as e:
+ print(f"[Eval] Error: {e}")
+
+ _save_pred(pred_path, pred_answers, pred_sp)
+
+ # Save performance metrics (merge into pred json)
+ total_duration = time.time() - start_time
+ summary = metrics.summary()
+ perf_obj = {
+ "summary": summary,
+ "total_duration": total_duration,
+ "config": {
+ "workers": args.workers,
+ "chat_model": args.chat_model or os.getenv("CHAT_MODEL"),
+ "lib": args.lib,
},
- "act_mem": {},
- "para_mem": {},
- },
-)
-
-mem_cube = GeneralMemCube(config)
-
-
-mos.register_mem_cube(f"/tmp/{user_name}", mem_cube_id=user_name)
-
+ }
+ _save_pred(pred_path, pred_answers, pred_sp, perf=perf_obj)
+ run_eval(pred_path, gold_path)
-with open("evaluation/data/hotpot/hotpot_dev_distractor_v1.json") as f:
- data = json.load(f)
+ print("\n" + "=" * 60)
+ print("Evaluation finished! Statistics:")
+ print("=" * 60)
+ print(f"Total duration: {total_duration:.2f}s")
+ print(f"Success: {summary['counts']['success']} / Failed: {summary['counts']['failed']}")
+ if summary["errors"]:
+ print("\nError stats:")
+ for error, count in sorted(summary["errors"].items(), key=lambda x: x[1], reverse=True)[:5]:
+ print(f" [{count} times] {error[:100]}...")
-def build_context_text(context_list):
- parts = []
- for title, sentences in context_list:
- text = " ".join(s.strip() for s in sentences if s.strip())
- parts.append(f"{title}: {text}")
- return "\n".join(parts)
-
-
-def build_and_ask(item):
- qid = item["_id"]
- question = item["question"]
-
- for title, sentences in item["context"]:
- text = " ".join(s.strip() for s in sentences if s.strip())
- memory_content = f"{title}: {text}"
- mos.add(memory_content=memory_content)
-
- answer = mos.chat(question).strip()
- return qid, answer
-
-
-pred_answers = {}
-
-with ThreadPoolExecutor(max_workers=5) as executor:
- futures = {executor.submit(build_and_ask, item): item for item in data}
- for future in tqdm(as_completed(futures), total=len(futures)):
- try:
- qid, answer = future.result()
- pred_answers[qid] = answer
- except Exception as e:
- print(f"Error: {e}")
-
-predictions = {"answer": pred_answers, "sp": []}
-with open("evaluation/data/hotpot/output/dev_distractor_pred.json", "w") as f:
- json.dump(predictions, f, ensure_ascii=False, indent=2)
+if __name__ == "__main__":
+ main()
diff --git a/evaluation/scripts/hotpot/hotpot_evaluate_v1.py b/evaluation/scripts/hotpot/hotpot_evaluate_v1.py
index d4d6e71e1..19f09996d 100644
--- a/evaluation/scripts/hotpot/hotpot_evaluate_v1.py
+++ b/evaluation/scripts/hotpot/hotpot_evaluate_v1.py
@@ -96,6 +96,9 @@ def eval(prediction_file, gold_file):
with open(gold_file) as f:
gold = json.load(f)
+ evaluated_ids = set((prediction.get("answer") or {}).keys())
+ gold = [dp for dp in gold if (dp.get("_id") or dp.get("id")) in evaluated_ids]
+
metrics = {
"em": 0,
"f1": 0,
@@ -114,12 +117,10 @@ def eval(prediction_file, gold_file):
cur_id = dp["_id"]
can_eval_joint = True
if cur_id not in prediction["answer"]:
- print(f"missing answer {cur_id}")
can_eval_joint = False
else:
em, prec, recall = update_answer(metrics, prediction["answer"][cur_id], dp["answer"])
if cur_id not in prediction["sp"]:
- print(f"missing sp fact {cur_id}")
can_eval_joint = False
else:
sp_em, sp_prec, sp_recall = update_sp(
@@ -140,11 +141,15 @@ def eval(prediction_file, gold_file):
metrics["joint_prec"] += joint_prec
metrics["joint_recall"] += joint_recall
+ print("=========Eval Results===========")
n = len(gold)
- for k in metrics:
- metrics[k] /= n
-
- print(metrics)
+ if n > 0:
+ for k in metrics:
+ metrics[k] /= n
+ print(metrics)
+ else:
+ print(metrics)
+ return metrics
if __name__ == "__main__":
diff --git a/evaluation/scripts/hotpot/hotpot_ingestion.py b/evaluation/scripts/hotpot/hotpot_ingestion.py
new file mode 100644
index 000000000..adbd3471c
--- /dev/null
+++ b/evaluation/scripts/hotpot/hotpot_ingestion.py
@@ -0,0 +1,241 @@
+import argparse
+import json
+import os
+import time
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+
+from dotenv import load_dotenv
+from tqdm import tqdm
+
+from evaluation.scripts.hotpot.data_loader import load_hotpot_data
+from evaluation.scripts.utils.metrics import Metrics
+
+
+load_dotenv()
+
+
+def retry_operation(func, *args, retries=5, delay=2, **kwargs):
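+ """Call func with retries and exponential backoff; re-raise the last error after the final attempt."""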
+ for attempt in range(retries):
+ try:
+ return func(*args, **kwargs)
+ except Exception as e:
+ if attempt < retries - 1:
+ func_name = getattr(func, "__name__", "Operation")
+ print(f"[Retry] {func_name} failed: {e}. Retrying in {delay}s...")
+ time.sleep(delay)
+ delay *= 2
+ else:
+ raise e
+
+
+def _get_lib_client(lib: str):
+ if lib == "mem0":
+ from evaluation.scripts.utils.client import Mem0Client
+
+ return Mem0Client(enable_graph=False)
+ if lib == "supermemory":
+ from evaluation.scripts.utils.client import SupermemoryClient
+
+ return SupermemoryClient()
+ from evaluation.scripts.utils.client import MemosApiClient
+
+ return MemosApiClient()
+
+
+def _load_added_ids(records_path: Path) -> set[str]:
+ if not records_path.exists():
+ return set()
+ try:
+ obj = json.loads(records_path.read_text(encoding="utf-8"))
+ ids = obj.get("added_ids") if isinstance(obj, dict) else None
+ if isinstance(ids, list):
+ return {str(x) for x in ids if x}
+ except Exception:
+ return set()
+ return set()
+
+
+def _save_added_ids(records_path: Path, added_ids: set[str], perf: dict | None = None) -> None:
+ records_path.parent.mkdir(parents=True, exist_ok=True)
+ tmp = records_path.with_suffix(records_path.suffix + ".tmp")
+ obj = {"added_ids": sorted(added_ids)}
+ if perf is not None:
+ obj["perf"] = perf
+ tmp.write_text(json.dumps(obj, ensure_ascii=False, indent=2), encoding="utf-8")
+ os.replace(tmp, records_path)
+
+
+def _build_memory_texts(ctx: dict | list | None) -> list[str]:
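+ # Serialize each context sentence as a JSON record ({"idx", "title", "sentence"}) so supporting facts can be reconstructed at search time.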
+ tasks: list[str] = []
+ if not ctx:
+ return tasks
+ for item in ctx:
+ if not isinstance(item, list | tuple) or len(item) != 2:
+ continue
+ title, sentences = item
+ if not isinstance(sentences, list):
+ continue
+ for idx, sentence in enumerate(sentences):
+ tasks.append(
+ json.dumps({"idx": idx, "title": title, "sentence": sentence}, ensure_ascii=False)
+ )
+ return tasks
+
+
+def add_context_memories(
+ client,
+ lib: str,
+ user_id: str,
+ ctx: dict | list | None,
+ mode: str = "fine",
+ async_mode: str = "sync",
+) -> None:
+ tasks = _build_memory_texts(ctx)
+ if not tasks:
+ return
+
+ if lib == "memos":
+ messages = [{"type": "text", "text": content} for content in tasks]
+ writable_cube_ids = [user_id]
+ retry_operation(
+ client.add,
+ messages=messages,
+ user_id=user_id,
+ writable_cube_ids=writable_cube_ids,
+ source_type="batch_import",
+ mode=mode,
+ async_mode=async_mode,
+ )
+ return
+
+ if lib == "mem0":
+ ts = int(time.time())
+ messages = [{"role": "user", "content": content} for content in tasks]
+ retry_operation(client.add, messages=messages, user_id=user_id, timestamp=ts, batch_size=10)
+ return
+
+ if lib == "supermemory":
+ for content in tasks:
+ retry_operation(client.add, content=content, user_id=user_id)
+
+
+def ingest_one(
+ client, lib: str, item: dict, version_dir: str, mode: str = "fine", async_mode: str = "sync"
+) -> str:
+ qid = item.get("_id") or item.get("id")
+ ctx = item.get("context")
+
+ user_id = version_dir + "_" + str(qid)
+ add_context_memories(client, lib, user_id, ctx, mode=mode, async_mode=async_mode)
+ return str(qid)
+
+
+def main(argv: list[str] | None = None) -> None:
+ parser = argparse.ArgumentParser(description="HotpotQA ingestion (add only).")
+ parser.add_argument(
+ "--lib",
+ type=str,
+ default="memos",
+ choices=["memos", "mem0", "supermemory"],
+ )
+ parser.add_argument("--workers", type=int, default=8)
+ parser.add_argument("--limit", type=int, default=None)
+ parser.add_argument("--version-dir", "-v", default=None, help="Version directory name")
+ parser.add_argument(
+ "--mode", default="fine", choices=["fine", "fast"], help="Processing mode (default: fine)"
+ )
+ parser.add_argument(
+ "--async-mode", default="sync", choices=["sync", "async"], help="Async mode (default: sync)"
+ )
+
+ args = parser.parse_args(argv)
+
+ print("=" * 60)
+ print("hotpotQA Product Add Concurrent Tool")
+ print("=" * 60)
+
+ output_dir = Path("evaluation/data/hotpot")
+ if args.version_dir:
+ output_dir = output_dir / args.version_dir
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ items_list = load_hotpot_data("evaluation/data/hotpot")
+ if args.limit is not None:
+ items_list = items_list[: args.limit]
+
+ records_path = output_dir / f"{args.lib}_added_records.json"
+ added_ids = _load_added_ids(records_path)
+ pending_items = []
+ for it in items_list:
+ qid = it.get("_id") or it.get("id")
+ if str(qid) not in added_ids:
+ pending_items.append(it)
+
+ print(f"[Add] lib={args.lib} total={len(items_list)} pending={len(pending_items)}")
+ if not pending_items:
+ return
+
+ client = _get_lib_client(args.lib)
+ metrics = Metrics()
+
+ def do_ingest(item):
+ start_time = time.perf_counter()
+ try:
+ sid = ingest_one(client, args.lib, item, args.version_dir, args.mode, args.async_mode)
+ duration = time.perf_counter() - start_time
+ metrics.record(duration, True)
+ return sid
+ except Exception as e:
+ duration = time.perf_counter() - start_time
+ metrics.record(duration, False, str(e))
+ raise e
+
+ start_time = time.time()
+ with ThreadPoolExecutor(max_workers=args.workers) as executor:
+ futures = [executor.submit(do_ingest, it) for it in pending_items]
+ for idx, f in enumerate(tqdm(as_completed(futures), total=len(futures), desc="Adding"), 1):
+ try:
+ sid = f.result()
+ if sid:
+ added_ids.add(str(sid))
+ if idx % 20 == 0:
+ _save_added_ids(records_path, added_ids)
+ except Exception as e:
+ print(f"[Add] Error: {e}")
+
+ _save_added_ids(records_path, added_ids)
+ print(f"[Add] saved records to {records_path}")
+
+ total_duration = time.time() - start_time
+ summary = metrics.summary()
+ perf_obj = {
+ "summary": summary,
+ "total_duration": total_duration,
+ "config": {
+ "workers": args.workers,
+ "mode": args.mode,
+ "async_mode": args.async_mode,
+ "lib": args.lib,
+ },
+ }
+ _save_added_ids(records_path, added_ids, perf=perf_obj)
+
+ print("\n" + "=" * 60)
+ print("Ingestion finished! Statistics:")
+ print("=" * 60)
+ print(f"Total duration: {total_duration:.2f}s")
+ print(f"Success: {summary['counts']['success']} / Failed: {summary['counts']['failed']}")
+
+ if summary["stats"]:
+ stats = summary["stats"]
+ qps = stats["count"] / total_duration if total_duration > 0 else 0
+ print(f"QPS: {qps:.2f}")
+ print("Latency stats (ms):")
+ print(f" Mean: {stats['mean']:.2f}")
+ print(f" Median: {stats['median']:.2f}")
+ print(f" Min: {stats['min']:.2f}")
+ print(f" Max: {stats['max']:.2f}")
+ print(f" P95: {stats['p95']:.2f}")
+ print(f" P99: {stats['p99']:.2f}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/evaluation/scripts/hotpot/hotpot_old.py b/evaluation/scripts/hotpot/hotpot_old.py
new file mode 100644
index 000000000..1e986431d
--- /dev/null
+++ b/evaluation/scripts/hotpot/hotpot_old.py
@@ -0,0 +1,309 @@
+import importlib.util
+import json
+import os
+import sys
+import time
+import traceback
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+from dotenv import load_dotenv
+from openai import OpenAI
+from tqdm import tqdm
+
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from utils.client import MemosApiClient
+from utils.prompts import LME_ANSWER_PROMPT, MEMOS_CONTEXT_TEMPLATE
+
+from evaluation.scripts.hotpot.data_loader import load_hotpot_data
+from memos.reranker.strategies.dialogue_common import extract_texts_and_sp_from_sources
+
+
+load_dotenv()
+os.environ["SEARCH_MODE"] = os.environ.get("SEARCH_MODE", "fine")
+client = MemosApiClient()
+oai_client = OpenAI(
+ api_key=os.getenv("CHAT_MODEL_API_KEY"), base_url=os.getenv("CHAT_MODEL_BASE_URL")
+)
+
+pred_answers = {}
+pred_sp = {}
+output_dir = "evaluation/data/hotpot/output"
+os.makedirs(output_dir, exist_ok=True)
+pred_path = os.path.join(output_dir, "dev_distractor_pred.json")
+gold_path = "evaluation/data/hotpot/dev_distractor_gold.json"
+
+
+def add_context_memories(user_id: str, ctx: dict | list | None):
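+ # Flatten each context paragraph into "title: sentence [#idx]" memories and add them under this question's user id.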
+ tasks = []
+ if isinstance(ctx, dict):
+ titles = ctx.get("title") or []
+ sentences_list = ctx.get("sentences") or []
+ for title, sentences in zip(titles, sentences_list, strict=False):
+ for idx, sentence in enumerate(sentences):
+ memory_content = f"{title}: {sentence} [#{idx}]"
+ tasks.append(memory_content)
+ elif isinstance(ctx, list):
+ for item in ctx:
+ if isinstance(item, list) and len(item) >= 2:
+ title = item[0]
+ sentences = item[1]
+ for idx, sentence in enumerate(sentences):
+ memory_content = f"{title}: {sentence} [#{idx}]"
+ tasks.append(memory_content)
+
+ if not tasks:
+ return
+
+ iso = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
+ messages = [{"role": "user", "content": content, "created_at": iso} for content in tasks]
+ client.add(messages=messages, user_id=user_id, conv_id=user_id)
+
+
+def memos_search(user_id: str, query: str, top_k):
+ results = client.search(query=query, user_id=user_id, top_k=top_k)
+ memories = results["text_mem"][0]["memories"]
+ print("Search memories:", len(memories))
+
+ context = "\n".join([i["memory"] for i in memories]) + f"\n{results.get('pref_string', '')}"
+ context = MEMOS_CONTEXT_TEMPLATE.format(user_id=user_id, memories=context)
+
+ # Extract supporting facts (sp) from raw sources
+ sp_list: list[list[str | int]] = []
+ for m in memories:
+ sources = (m.get("metadata", {}) or {}).get("sources") or []
+ texts, sps = extract_texts_and_sp_from_sources(sources)
+ for t, s in sps:
+ sp_list.append([t, s])
+
+ # De-duplicate while preserving order
+ seen = set()
+ dedup_sp = []
+ for t, s in sp_list:
+ key = (t, s)
+ if key not in seen:
+ seen.add(key)
+ dedup_sp.append([t, s])
+
+ return context, dedup_sp
+
+
+def llm_response(context: str, question: str, question_date: str | None = None) -> str:
+ prompt = LME_ANSWER_PROMPT.format(
+ question=question, question_date=question_date or "", context=context
+ )
+ resp = oai_client.chat.completions.create(
+ model=os.getenv("CHAT_MODEL"),
+ messages=[{"role": "system", "content": prompt}],
+ temperature=0,
+ )
+ return resp.choices[0].message.content or ""
+
+
+def extract_answer(question: str, output: str, model_name: str | None = None) -> str:
+ try:
+ response = oai_client.chat.completions.create(
+ model=model_name or os.getenv("CHAT_MODEL"),
+ messages=[
+ {
+ "role": "user",
+ "content": (
+ "You are an answer extractor. Given a question and a verbose response, "
+ "return ONLY the concise final answer suitable for HotpotQA exact match.\n\n"
+ "Rules:\n"
+ "- If the question asks yes/no, answer strictly 'yes' or 'no'.\n"
+ "- Otherwise, output the shortest noun phrase/entity/date/number that answers the question.\n"
+ "- No explanations, no punctuation beyond what's necessary for the answer.\n\n"
+ f"Question: {question}\nVerbose response: {output}\nFinal answer:"
+ ),
+ }
+ ],
+ temperature=0.0,
+ max_tokens=64,
+ top_p=1,
+ )
+ ans = (response.choices[0].message.content or "").strip()
+ return ans
+ except Exception:
+ text = (output or "").lower()
+ if " yes" in text or text.startswith("yes"):
+ return "yes"
+ if " no" in text or text.startswith("no"):
+ return "no"
+ for sep in ["\n", ". ", ".", "?", "!"]:
+ if sep in output:
+ cand = output.split(sep)[0].strip()
+ if cand:
+ return cand
+ return (output or "").strip()
+
+
+def build_context_text(context_list):
+ parts = []
+ for title, sentences in context_list:
+ text = " ".join(s.strip() for s in sentences if s.strip())
+ parts.append(f"{title}: {text}")
+ return "\n".join(parts)
+
+
+def ingest_context(item):
+ qid = item.get("_id") or item.get("id")
+ ctx = item.get("context")
+ add_context_memories(qid, ctx)
+ return qid
+
+
+def search_and_ask(item):
+ qid = item.get("_id") or item.get("id")
+ question = item["question"]
+ try:
+ context, sp_list = memos_search(qid, question, top_k=7)
+ raw_answer = llm_response(context=context, question=question, question_date="")
+ answer = extract_answer(question, raw_answer) or ""
+ print("Question:", question)
+ print("Answer (raw):", raw_answer)
+ print("Answer (final):", answer)
+ pred_sp[qid] = sp_list
+ return qid, answer
+ except Exception as e:
+ print(f"[Question {qid}] Error:", e)
+ traceback.print_exc()
+ return qid, ""
+
+
+def write_gold(data, out_path: str | None = None):
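+ # Convert the validation split into the gold format expected by hotpot_evaluate_v1 and write it atomically.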
+ split = data.get("validation")
+ items_list = [split[i] for i in range(len(split))]
+ out = []
+ for it in items_list:
+ qid = it.get("_id") or it.get("id")
+ sp = it.get("supporting_facts")
+ if isinstance(sp, dict):
+ titles = sp.get("title") or []
+ sent_ids = sp.get("sent_id") or []
+ sp_list = [[t, s] for t, s in zip(titles, sent_ids, strict=False)]
+ else:
+ sp_list = sp or []
+ ctx = it.get("context")
+ if isinstance(ctx, dict):
+ titles = ctx.get("title") or []
+ sentences = ctx.get("sentences") or []
+ ctx_list = [[t, s] for t, s in zip(titles, sentences, strict=False)]
+ else:
+ ctx_list = ctx or []
+ out.append(
+ {
+ "_id": qid,
+ "question": it.get("question"),
+ "answer": it.get("answer"),
+ "supporting_facts": sp_list,
+ "context": ctx_list,
+ }
+ )
+ target_path = out_path or gold_path
+ tmp_path = target_path + ".tmp"
+ try:
+ with open(tmp_path, "w", encoding="utf-8") as f:
+ json.dump(out, f, ensure_ascii=False)
+ f.flush()
+ os.fsync(f.fileno())
+ os.replace(tmp_path, target_path)
+ except Exception as e:
+ print("Failed to save gold:", e)
+
+
+def run_eval(pred_file: str | None = None, gold_file: str | None = None):
+ spec = importlib.util.spec_from_file_location(
+ "hotpot_eval_v1", "evaluation/scripts/hotpot/hotpot_evaluate_v1.py"
+ )
+ m = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(m)
+ m.eval(pred_file or pred_path, gold_file or gold_path)
+
+
+def save_pred():
+ tmp_path = pred_path + ".tmp"
+ try:
+ safe_pred_answers = {
+ k: (v if isinstance(v, str) else ("" if v is None else str(v)))
+ for k, v in pred_answers.items()
+ }
+ with open(tmp_path, "w", encoding="utf-8") as f:
+ json.dump({"answer": safe_pred_answers, "sp": pred_sp}, f, ensure_ascii=False, indent=2)
+ f.flush()
+ os.fsync(f.fileno())
+ os.replace(tmp_path, pred_path)
+ except Exception as e:
+ print("Failed to save:", e)
+
+
+def main():
+ interval = 10
+ items_list = load_hotpot_data("evaluation/data/hotpot")
+
+ if os.path.exists(pred_path):
+ try:
+ with open(pred_path, encoding="utf-8") as f:
+ prev = json.load(f)
+ if isinstance(prev, dict) and isinstance(prev.get("answer"), dict):
+ prev_ans = {
+ k: (v if isinstance(v, str) else ("" if v is None else str(v)))
+ for k, v in prev["answer"].items()
+ }
+ pred_answers.update(prev_ans)
+ if isinstance(prev, dict) and isinstance(prev.get("sp"), dict):
+ pred_sp.update(prev["sp"])
+ except Exception as e:
+ print("Failed to read historical predictions:", e)
+
+ processed = len(pred_answers)
+ print("Starting evaluation, total samples:", len(items_list))
+ print("Existing predictions:", processed)
+
+ pending_items = []
+ for it in items_list:
+ qid = it.get("_id") or it.get("id")
+ if qid not in pred_answers:
+ pending_items.append(it)
+
+ if pending_items:
+ print(f"[Step1: Ingest] start, items={len(pending_items)}")
+ with ThreadPoolExecutor(max_workers=8) as executor:
+ futures = {
+ executor.submit(ingest_context, item): idx for idx, item in enumerate(pending_items)
+ }
+ for future in tqdm(as_completed(futures), total=len(futures), desc="Ingest"):
+ try:
+ future.result()
+ except Exception as e:
+ print("Ingest thread failed:", e)
+ traceback.print_exc()
+
+ print(f"[Step2: QA] start, items={len(pending_items)}")
+ with ThreadPoolExecutor(max_workers=8) as executor:
+ futures = {
+ executor.submit(search_and_ask, item): idx for idx, item in enumerate(pending_items)
+ }
+ for future in tqdm(as_completed(futures), total=len(futures), desc="QA"):
+ try:
+ qid, answer = future.result()
+ except Exception as e:
+ print("QA thread failed:", e)
+ traceback.print_exc()
+ continue
+ pred_answers[qid] = answer
+ processed += 1
+ if processed % 10 == 0:
+ print("Completed:", processed, "Remaining:", len(items_list) - processed)
+ save_pred()
+ if processed % interval == 0:
+ print("Stage evaluation, current progress:", processed)
+ run_eval()
+
+ run_eval()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/evaluation/scripts/hotpot/hotpot_search.py b/evaluation/scripts/hotpot/hotpot_search.py
new file mode 100644
index 000000000..6542d4539
--- /dev/null
+++ b/evaluation/scripts/hotpot/hotpot_search.py
@@ -0,0 +1,297 @@
+import argparse
+import json
+import os
+import time
+import traceback
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+
+from tqdm import tqdm
+
+from evaluation.scripts.hotpot.data_loader import load_hotpot_data
+from evaluation.scripts.utils.metrics import Metrics
+
+
+def retry_operation(func, *args, retries=5, delay=2, **kwargs):
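+ """Call func with retries and exponential backoff, unwrapping {"data": ...} responses when present."""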
+ for attempt in range(retries):
+ try:
+ result = func(*args, **kwargs)
+ if isinstance(result, dict) and "data" in result:
+ return result["data"]
+ return result
+ except Exception as e:
+ if attempt < retries - 1:
+ func_name = getattr(func, "__name__", "Operation")
+ print(f"[Retry] {func_name} failed: {e}. Retrying in {delay}s...")
+ time.sleep(delay)
+ delay *= 2
+ else:
+ raise e
+
+
+def _get_lib_client(lib: str):
+ if lib == "mem0":
+ from evaluation.scripts.utils.client import Mem0Client
+
+ return Mem0Client(enable_graph=False)
+ if lib == "supermemory":
+ from evaluation.scripts.utils.client import SupermemoryClient
+
+ return SupermemoryClient()
+ from evaluation.scripts.utils.client import MemosApiClient
+
+ return MemosApiClient()
+
+
+def _load_existing_results(output_path: Path) -> tuple[list[dict], set[str]]:
+ if not output_path.exists():
+ return [], set()
+ try:
+ data = json.loads(output_path.read_text(encoding="utf-8"))
+ if isinstance(data, list):
+ ids = {str(r.get("_id")) for r in data if r.get("_id")}
+ return data, ids
+ if isinstance(data, dict) and isinstance(data.get("results"), list):
+ rows = data.get("results") or []
+ ids = {str(r.get("_id")) for r in rows if r.get("_id")}
+ return rows, ids
+ except Exception:
+ return [], set()
+ return [], set()
+
+
+def _save_json_list(path: Path, rows: list[dict]) -> None:
+ path.parent.mkdir(parents=True, exist_ok=True)
+ tmp = path.with_suffix(path.suffix + ".tmp")
+ tmp.write_text(json.dumps({"results": rows}, ensure_ascii=False, indent=2), encoding="utf-8")
+ os.replace(tmp, path)
+
+
+def get_sources_info(sources):
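+ """Parse JSON-encoded source strings into sentence texts and de-duplicated [title, sent_idx] supporting-fact pairs."""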
+ seen = set()
+ dedup_sp = []
+ mem_texts = []
+
+ for source in sources:
+ if isinstance(source, str):
+ try:
+ obj = json.loads(source)
+ except json.JSONDecodeError:
+ continue
+
+ title = obj.get("title")
+ idx = obj.get("idx")
+ sentence = obj.get("sentence")
+
+ if title is None or idx is None:
+ continue
+
+ key = (title, idx)
+ if key not in seen:
+ seen.add(key)
+ dedup_sp.append([title, idx])
+ mem_texts.append(sentence)
+
+ return mem_texts, dedup_sp
+
+
+def memos_search(
+ client, user_id: str, query: str, top_k: int, search_mode: str
+) -> tuple[str, list[list[str | int]]]:
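+ # Search this question's cube, keep the memory texts, and rebuild supporting facts from the JSON payloads in each memory's sources.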
+ readable_cube_ids = [user_id]
+ results = retry_operation(
+ client.search,
+ query=query,
+ user_id=user_id,
+ readable_cube_ids=readable_cube_ids,
+ top_k=top_k,
+ mode=search_mode,
+ )
+ memories = results["text_mem"][0]["memories"]
+ mem_texts = [i["memory"] for i in memories]
+
+ sources = []
+ for m in memories:
+ source = (m.get("metadata", {}) or {}).get("sources") or []
+ for s in source:
+ source_txt = json.loads(s["content"])
+ sources.append(json.loads(source_txt)["content"])
+ sources.extend(source)
+
+ _, dedup_sp = get_sources_info(sources)
+ return mem_texts, dedup_sp
+
+
+def mem0_search(client, user_id: str, query: str, top_k: int) -> tuple[str, list[list[str | int]]]:
+ res = retry_operation(client.search, query, user_id, top_k)
+ sources = [m.get("memory", "") for m in res.get("results", []) if m.get("memory")]
+ mem_texts, dedup_sp = get_sources_info(sources)
+ return mem_texts, dedup_sp
+
+
+def supermemory_search(
+ client, user_id: str, query: str, top_k: int
+) -> tuple[str, list[list[str | int]]]:
+ sources = retry_operation(client.search, query, user_id, top_k)
+ mem_texts, dedup_sp = get_sources_info(sources)
+ return mem_texts, dedup_sp
+
+
+def search_one(
+ client, lib: str, item: dict, top_k: int, version_dir: str, search_mode: str
+) -> dict:
+ qid = item.get("_id") or item.get("id")
+ question = item.get("question") or ""
+ user_id = version_dir + "_" + str(qid)
+
+ if lib == "memos":
+ memories, sp_list = memos_search(client, user_id, str(question), top_k, search_mode)
+ elif lib == "mem0":
+ memories, sp_list = mem0_search(client, user_id, str(question), top_k)
+ elif lib == "supermemory":
+ memories, sp_list = supermemory_search(client, user_id, str(question), top_k)
+ else:
+ memories, sp_list = [], []
+
+ return {
+ "_id": str(qid),
+ "question": question,
+ "answer": item.get("answer"),
+ "memories": memories,
+ "sp": sp_list,
+ }
+
+
+def main(argv: list[str] | None = None) -> None:
+ parser = argparse.ArgumentParser(description="HotpotQA search (search only).")
+ parser.add_argument(
+ "--lib",
+ type=str,
+ default="memos",
+ choices=["memos", "mem0", "supermemory"],
+ )
+ parser.add_argument("--workers", type=int, default=8)
+ parser.add_argument("--top-k", type=int, default=7)
+ parser.add_argument(
+ "--limit", type=int, default=None, help="Limit number of samples (was max_samples)"
+ )
+ parser.add_argument("--version-dir", "-v", default=None, help="Version directory name")
+ parser.add_argument("--search-mode", default="fine", help="Search mode")
+
+ args = parser.parse_args(argv)
+
+ limit = args.limit
+
+ items_list = load_hotpot_data("evaluation/data/hotpot")
+ if limit is not None:
+ items_list = items_list[:limit]
+
+ output_dir = Path(f"evaluation/data/hotpot/{args.version_dir}")
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ if args.lib == "memos":
+ output_path = output_dir / f"{args.lib}_{args.search_mode}_search_results.json"
+ else:
+ output_path = output_dir / f"{args.lib}_search_results.json"
+
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+
+ results, processed_ids = _load_existing_results(output_path)
+ pending_items = []
+ for it in items_list:
+ qid = it.get("_id") or it.get("id")
+ if str(qid) not in processed_ids:
+ pending_items.append(it)
+
+ print(
+ f"[Search] lib={args.lib} total={len(items_list)} pending={len(pending_items)} top_k={args.top_k}"
+ )
+ if not pending_items:
+ return
+
+ client = _get_lib_client(args.lib)
+ metrics = Metrics()
+ start_time = time.time()
+
+ with ThreadPoolExecutor(max_workers=args.workers) as executor:
+
+ def do_search(item):
+ st = time.perf_counter()
+ try:
+ r = search_one(
+ client, args.lib, item, args.top_k, args.version_dir, args.search_mode
+ )
+ dur = time.perf_counter() - st
+ metrics.record(dur, True)
+ return r
+ except Exception as e:
+ dur = time.perf_counter() - st
+ metrics.record(dur, False, str(e))
+ raise e
+
+ futures = [executor.submit(do_search, it) for it in pending_items]
+ for idx, f in enumerate(
+ tqdm(as_completed(futures), total=len(futures), desc="Searching"), 1
+ ):
+ try:
+ r = f.result()
+ results.append(r)
+ if idx % 20 == 0:
+ _save_json_list(output_path, results)
+ except Exception as e:
+ print(f"[Search] Error: {e}")
+ traceback.print_exc()
+
+ _save_json_list(output_path, results)
+ print(f"[Search] saved {len(results)} rows to {output_path}")
+
+ # Save performance metrics
+ total_duration = time.time() - start_time
+ summary = metrics.summary()
+ # Merge perf into results json file
+ combined_obj = {
+ "results": results,
+ "perf": {
+ "summary": summary,
+ "total_duration": total_duration,
+ "config": {
+ "workers": args.workers,
+ "top_k": args.top_k,
+ "limit": limit,
+ "search_mode": args.search_mode,
+ "lib": args.lib,
+ },
+ },
+ }
+ tmp = output_path.with_suffix(output_path.suffix + ".tmp")
+ tmp.write_text(json.dumps(combined_obj, ensure_ascii=False, indent=2), encoding="utf-8")
+ os.replace(tmp, output_path)
+
+ print("\n" + "=" * 60)
+ print("Search finished! Statistics:")
+ print("=" * 60)
+ print(f"Total duration: {total_duration:.2f}s")
+ print(f"Success: {summary['counts']['success']} / Failed: {summary['counts']['failed']}")
+
+ if summary["stats"]:
+ stats = summary["stats"]
+ qps = stats["count"] / total_duration if total_duration > 0 else 0
+ print(f"QPS: {qps:.2f}")
+ print("Latency stats (ms):")
+ print(f" Mean: {stats['mean']:.2f}")
+ print(f" Median: {stats['median']:.2f}")
+ print(f" Min: {stats['min']:.2f}")
+ print(f" Max: {stats['max']:.2f}")
+ print(f" P95: {stats['p95']:.2f}")
+ print(f" P99: {stats['p99']:.2f}")
+
+ if summary["errors"]:
+ print("\nError stats:")
+ for error, count in sorted(summary["errors"].items(), key=lambda x: x[1], reverse=True)[:5]:
+ print(f" [{count} times] {error[:100]}...")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/evaluation/scripts/long_bench-v2/__init__.py b/evaluation/scripts/long_bench-v2/__init__.py
new file mode 100644
index 000000000..786c0ce03
--- /dev/null
+++ b/evaluation/scripts/long_bench-v2/__init__.py
@@ -0,0 +1 @@
+# LongBench v2 evaluation scripts
diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py b/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py
new file mode 100644
index 000000000..5a5c11968
--- /dev/null
+++ b/evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py
@@ -0,0 +1,199 @@
+import argparse
+import json
+import os
+import sys
+import threading
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+from dotenv import load_dotenv
+from tqdm import tqdm
+
+
+ROOT_DIR = os.path.dirname(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts")
+
+sys.path.insert(0, ROOT_DIR)
+sys.path.insert(0, EVAL_SCRIPTS_DIR)
+
+
+def ingest_sample(
+ client, sample, sample_idx, frame, version, success_records, record_file, file_lock
+):
+ """Ingest a single LongBench v2 sample as memories."""
+ # Skip if already processed
+ if str(sample_idx) in success_records:
+ return True
+
+ user_id = f"longbench_v2_{sample_idx}_{version}"
+ conv_id = f"longbench_v2_{sample_idx}_{version}"
+
+ # Get context and convert to messages
+ context = sample.get("context", "")
+
+ # For memos, we ingest the context as raw document content
+ messages = [
+ {
+ "type": "file",
+ "file": {
+ "file_data": context,
+ "file_id": str(sample_idx),
+ },
+ }
+ ]
+
+ if "memos-api" in frame:
+ try:
+ client.add(messages=messages, user_id=user_id, conv_id=conv_id, batch_size=1)
+ print(f"✅ [{frame}] Ingested sample {sample_idx}")
+ # Record successful ingestion (thread-safe)
+ with file_lock, open(record_file, "a") as f:
+ f.write(f"{sample_idx}\n")
+ f.flush()
+ return True
+ except Exception as e:
+ print(f"❌ [{frame}] Error ingesting sample {sample_idx}: {e}")
+ return False
+
+ return False
+
+
+def load_dataset_from_local():
+ """Load LongBench v2 dataset from local JSON file."""
+ data_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
+ "data",
+ "long_bench_v2",
+ )
+
+ filepath = os.path.join(data_dir, "data.json")
+
+ if not os.path.exists(filepath):
+ raise FileNotFoundError(f"Dataset file not found: {filepath}")
+
+ # Load JSON file
+ with open(filepath, encoding="utf-8") as f:
+ samples = json.load(f)
+
+ return samples
+
+
+def main(frame, version="default", num_workers=10, max_samples=None):
+ """Main ingestion function."""
+ load_dotenv()
+
+ print("\n" + "=" * 80)
+ print(f"🚀 LONGBENCH V2 INGESTION - {frame.upper()} v{version}".center(80))
+ print("=" * 80 + "\n")
+
+ # Load dataset from local file
+ try:
+ dataset = load_dataset_from_local()
+ print(f"Loaded {len(dataset)} samples from LongBench v2")
+ except FileNotFoundError as e:
+ print(f"❌ Error loading dataset: {e}")
+ return
+ except Exception as e:
+ print(f"❌ Error loading dataset: {e}")
+ return
+
+ # Limit samples if specified
+ if max_samples:
+ dataset = dataset[:max_samples]
+ print(f"Limited to {len(dataset)} samples")
+
+ # Initialize checkpoint file for resume functionality
+ checkpoint_dir = os.path.join(
+ ROOT_DIR, "evaluation", "results", "long_bench_v2", f"{frame}-{version}"
+ )
+ os.makedirs(checkpoint_dir, exist_ok=True)
+ record_file = os.path.join(checkpoint_dir, "success_records.txt")
+
+ # Load existing success records for resume
+ success_records = set()
+ if os.path.exists(record_file):
+ with open(record_file) as f:
+ for line in f:
+ line = line.strip()
+ if line:
+ success_records.add(line)
+ print(f"📋 Found {len(success_records)} already processed samples (resume mode)")
+ else:
+ print("📋 Starting fresh ingestion (no checkpoint found)")
+
+ # Initialize client
+ client = None
+ if frame == "memos-api":
+ from utils.client import MemosApiClient
+
+ client = MemosApiClient()
+ else:
+ print(f"❌ Unsupported frame: {frame}")
+ return
+
+ # Ingest samples
+ success_count = len(success_records) # Start with already processed count
+ file_lock = threading.Lock() # Lock for thread-safe file writing
+ with ThreadPoolExecutor(max_workers=num_workers) as executor:
+ futures = []
+ for idx, sample in enumerate(dataset):
+ future = executor.submit(
+ ingest_sample,
+ client,
+ sample,
+ idx,
+ frame,
+ version,
+ success_records,
+ record_file,
+ file_lock,
+ )
+ futures.append(future)
+
+ for future in tqdm(
+ as_completed(futures),
+ total=len(futures),
+ desc="Ingesting LongBench v2",
+ ):
+ try:
+ if future.result():
+ success_count += 1
+ except Exception as e:
+ print(f"Error processing sample: {e}")
+
+ print(f"\n{'=' * 80}")
+ print(f"✅ INGESTION COMPLETE: {success_count}/{len(dataset)} samples ingested".center(80))
+ print(f"{'=' * 80}\n")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lib",
+ type=str,
+ choices=["memos-api", "memos-api-online"],
+ default="memos-api",
+ )
+ parser.add_argument(
+ "--version",
+ type=str,
+ default="default",
+ help="Version identifier for saving results",
+ )
+ parser.add_argument(
+ "--workers",
+ type=int,
+ default=2,
+ help="Number of parallel workers",
+ )
+ parser.add_argument(
+ "--max_samples",
+ type=int,
+ default=None,
+ help="Maximum number of samples to process (default: all)",
+ )
+ args = parser.parse_args()
+
+ main(args.lib, args.version, args.workers, args.max_samples)
diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_metric.py b/evaluation/scripts/long_bench-v2/longbench_v2_metric.py
new file mode 100644
index 000000000..af324c9c7
--- /dev/null
+++ b/evaluation/scripts/long_bench-v2/longbench_v2_metric.py
@@ -0,0 +1,176 @@
+import argparse
+import json
+import os
+
+
+def calculate_accuracy(responses):
+ """Calculate accuracy metrics for LongBench v2.
+
+ Logic is aligned with longbench_stx.print_metrics, but returns a dict
+ and additionally computes by_domain statistics.
+ """
+ total = len(responses)
+ if total == 0:
+ return {}
+
+ # Counters (aligned with longbench_stx.print_metrics)
+ easy = hard = short = medium = long = 0
+ easy_acc = hard_acc = short_acc = medium_acc = long_acc = 0
+ total_prompt_tokens = 0
+
+ for pred in responses:
+ acc = int(pred.get("judge", False))
+ diff = pred.get("difficulty", "easy")
+ length = pred.get("length", "short")
+
+ pt = pred.get("prompt_tokens")
+ if isinstance(pt, int | float):
+ total_prompt_tokens += int(pt)
+
+ if diff == "easy":
+ easy += 1
+ easy_acc += acc
+ else:
+ hard += 1
+ hard_acc += acc
+
+ if length == "short":
+ short += 1
+ short_acc += acc
+ elif length == "medium":
+ medium += 1
+ medium_acc += acc
+ else:
+ long += 1
+ long_acc += acc
+
+ o_acc = round(100 * (easy_acc + hard_acc) / total, 2)
+ e_acc = round(100 * easy_acc / easy, 2) if easy > 0 else 0.0
+ h_acc = round(100 * hard_acc / hard, 2) if hard > 0 else 0.0
+ s_acc = round(100 * short_acc / short, 2) if short > 0 else 0.0
+ m_acc = round(100 * medium_acc / medium, 2) if medium > 0 else 0.0
+ l_acc = round(100 * long_acc / long, 2) if long > 0 else 0.0
+
+ # Additional by-domain stats (extra vs. stx)
+ domain_stats = {}
+ for r in responses:
+ domain = r.get("domain", "Unknown")
+ if domain not in domain_stats:
+ domain_stats[domain] = {"total": 0, "correct": 0}
+ domain_stats[domain]["total"] += 1
+ if r.get("judge", False):
+ domain_stats[domain]["correct"] += 1
+
+ domain_acc = {
+ domain: round(100 * stats["correct"] / stats["total"], 2)
+ for domain, stats in domain_stats.items()
+ }
+
+ return {
+ "overall": o_acc,
+ "easy": e_acc,
+ "hard": h_acc,
+ "short": s_acc,
+ "medium": m_acc,
+ "long": l_acc,
+ "by_domain": domain_acc,
+ "total_samples": total,
+ "correct_samples": easy_acc + hard_acc,
+ "total_prompt_tokens": total_prompt_tokens,
+ "avg_prompt_tokens": round(total_prompt_tokens / total, 2) if total > 0 else 0.0,
+ }
+
+
+def main(frame, version="default"):
+ """Main metric calculation function."""
+ print("\n" + "=" * 80)
+ print(f"📊 LONGBENCH V2 METRICS CALCULATION - {frame.upper()} v{version}".center(80))
+ print("=" * 80 + "\n")
+
+ # Load responses
+ responses_path = f"results/long_bench_v2/{frame}-{version}/{frame}_longbench_v2_responses.json"
+ if not os.path.exists(responses_path):
+ print(f"❌ Responses not found: {responses_path}")
+ print("Please run longbench_v2_responses.py first")
+ return
+
+ with open(responses_path, encoding="utf-8") as f:
+ responses = json.load(f)
+
+ # Only keep entries that actually have search results:
+ # - For new pipeline: non-empty memories_used list
+ # - For older runs: non-empty search_context string
+ def _has_search_results(r: dict) -> bool:
+ mems = r.get("memories_used")
+ if isinstance(mems, list) and any(str(m).strip() for m in mems):
+ return True
+ ctx = str(r.get("search_context", "")).strip()
+ return ctx != ""
+
+ filtered = [r for r in responses if _has_search_results(r)]
+
+ # Calculate metrics (handle case where no samples have search results)
+ if not filtered:
+ print("⚠️ No responses with valid search results were found. Metrics will be zeroed.")
+ metrics = {
+ "overall": 0.0,
+ "easy": 0.0,
+ "hard": 0.0,
+ "short": 0.0,
+ "medium": 0.0,
+ "long": 0.0,
+ "by_domain": {},
+ "total_samples": 0,
+ "correct_samples": 0,
+ "total_prompt_tokens": 0,
+ "avg_prompt_tokens": 0.0,
+ }
+ else:
+ metrics = calculate_accuracy(filtered)
+
+ # Save metrics
+ output_path = f"results/long_bench_v2/{frame}-{version}/{frame}_longbench_v2_metrics.json"
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+
+ with open(output_path, "w", encoding="utf-8") as f:
+ json.dump(metrics, f, ensure_ascii=False, indent=4)
+
+ print(f"\n{'=' * 80}")
+ print(f"✅ METRICS CALCULATION COMPLETE: Results saved to {output_path}".center(80))
+ print(f"{'=' * 80}\n")
+
+ # Print summary table
+ print("\n📊 Summary of Results:")
+ print("-" * 80)
+ print(f"{'Overall Accuracy':<30s}: {metrics['overall']:.2f}%")
+ print(f"{'Easy':<30s}: {metrics['easy']:.2f}%")
+ print(f"{'Hard':<30s}: {metrics['hard']:.2f}%")
+ print(f"{'Short':<30s}: {metrics['short']:.2f}%")
+ print(f"{'Medium':<30s}: {metrics['medium']:.2f}%")
+ print(f"{'Long':<30s}: {metrics['long']:.2f}%")
+ print(f"{'Avg Prompt Tokens':<30s}: {metrics.get('avg_prompt_tokens', 0.0):.2f}")
+ print("\nBy Domain:")
+ for domain, acc in metrics["by_domain"].items():
+ print(f" {domain:<28s}: {acc:.1f}%")
+ print(f"\nTotal Samples: {metrics['total_samples']}")
+ print(f"Correct: {metrics['correct_samples']}")
+ print("-" * 80)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lib",
+ type=str,
+ choices=["memos-api", "memos-api-online"],
+ default="memos-api",
+ )
+ parser.add_argument(
+ "--version",
+ type=str,
+ default="default",
+ help="Version identifier for loading results",
+ )
+ args = parser.parse_args()
+
+ main(args.lib, args.version)
diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_responses.py b/evaluation/scripts/long_bench-v2/longbench_v2_responses.py
new file mode 100644
index 000000000..686062c5f
--- /dev/null
+++ b/evaluation/scripts/long_bench-v2/longbench_v2_responses.py
@@ -0,0 +1,319 @@
+import argparse
+import json
+import os
+import re
+import sys
+import threading
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from time import time
+
+from dotenv import load_dotenv
+from openai import OpenAI
+from tqdm import tqdm
+
+
+ROOT_DIR = os.path.dirname(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts")
+
+sys.path.insert(0, ROOT_DIR)
+sys.path.insert(0, EVAL_SCRIPTS_DIR)
+
+
+# RAG-style prompt template aligned with longbench_stx.TEMPLATE_RAG
+TEMPLATE_RAG = """Please read the following retrieved text chunks and answer the question below.
+
+
+$DOC$
+
+
+What is the correct answer to this question: $Q$
+Choices:
+(A) $C_A$
+(B) $C_B$
+(C) $C_C$
+(D) $C_D$
+
+Format your response as follows: "The correct answer is (insert answer here)"."""
+
+
+def extract_answer(response):
+ """Extract answer from response (A, B, C, or D).
+
+ Logic is kept consistent with longbench_stx.extract_answer.
+ """
+ response = response.replace("*", "")
+ # Try to find "The correct answer is (X)" pattern
+ match = re.search(r"The correct answer is \(([A-D])\)", response)
+ if match:
+ return match.group(1)
+ else:
+ match = re.search(r"The correct answer is ([A-D])", response)
+ if match:
+ return match.group(1)
+ return None
+
+
+def llm_answer(llm_client, memories, question, choices):
+ """Generate response using RAG-style prompt, aligned with longbench_stx.llm_answer.
+
+ Returns:
+ tuple[str, int | None]: (response_text, prompt_tokens)
+ """
+ # Join memories to form the retrieved context document
+ doc_content = "\n\n".join([f"Retrieved chunk {idx + 1}: {m}" for idx, m in enumerate(memories)])
+
+ prompt = (
+ TEMPLATE_RAG.replace("$DOC$", doc_content)
+ .replace("$Q$", question)
+ .replace("$C_A$", choices.get("A", ""))
+ .replace("$C_B$", choices.get("B", ""))
+ .replace("$C_C$", choices.get("C", ""))
+ .replace("$C_D$", choices.get("D", ""))
+ )
+
+ try:
+ response = llm_client.chat.completions.create(
+ model=os.getenv("CHAT_MODEL"),
+ messages=[{"role": "user", "content": prompt}],
+ temperature=0.1,
+ max_tokens=12800,
+ )
+ text = response.choices[0].message.content or ""
+ prompt_tokens = None
+ usage = getattr(response, "usage", None)
+ if usage is not None:
+ # openai>=1.x style: usage.prompt_tokens
+ pt = getattr(usage, "prompt_tokens", None)
+ if isinstance(pt, int):
+ prompt_tokens = pt
+ else:
+ # fallback for dict-like usage
+ try:
+ prompt_tokens = int(usage.get("prompt_tokens")) # type: ignore[call-arg]
+ except Exception:
+ prompt_tokens = None
+ return text, prompt_tokens
+ except Exception as e:
+ print(f"Error generating response: {e}")
+ return "", None
+
+
+def process_sample(search_result, llm_client, success_records, record_file, file_lock):
+ """Process a single sample: generate answer.
+
+ This mirrors longbench_stx.evaluate_sample but consumes precomputed search results
+ produced by longbench_v2_search.py.
+ """
+ # Use sample_idx when available, otherwise fall back to _id so that
+ # we can work with stx-style search results that only have _id.
+ sample_idx = search_result.get("sample_idx")
+ sample_key = str(sample_idx) if sample_idx is not None else str(search_result.get("_id", ""))
+
+ # Skip if already processed
+ if sample_key and sample_key in success_records:
+ return None
+
+ start = time()
+
+ question = search_result.get("question", "")
+ choices = {
+ "A": search_result.get("choice_A", "") or "",
+ "B": search_result.get("choice_B", "") or "",
+ "C": search_result.get("choice_C", "") or "",
+ "D": search_result.get("choice_D", "") or "",
+ }
+
+ # Prefer memories saved by longbench_v2_search; fall back to reconstructing
+ # from raw search_results if needed (for old search jsons).
+ memories = search_result.get("memories_used")
+ if memories is None:
+ raw = search_result.get("search_results") or {}
+ memories = []
+ if isinstance(raw, dict) and raw.get("text_mem"):
+ text_mem = raw["text_mem"]
+ if text_mem and text_mem[0].get("memories"):
+ memories = [
+ m.get("memory", "") for m in text_mem[0]["memories"] if isinstance(m, dict)
+ ]
+
+ # Ensure we have a list, even if empty
+ memories = memories or []
+
+    # Skip if the question or the retrieved memories are missing
+ if not question:
+ return None
+ if not memories:
+ return None
+
+ # Generate answer
+ response, prompt_tokens = llm_answer(llm_client, memories, str(question), choices)
+
+ # Extract answer (A, B, C, or D)
+ pred = extract_answer(response)
+
+ response_duration_ms = (time() - start) * 1000
+
+ result = {
+ # Preserve sample_idx if present for backward compatibility
+ "sample_idx": search_result.get("sample_idx"),
+ "_id": search_result.get("_id"),
+ "domain": search_result.get("domain"),
+ "sub_domain": search_result.get("sub_domain"),
+ "difficulty": search_result.get("difficulty"),
+ "length": search_result.get("length"),
+ "question": question,
+ "choice_A": choices["A"],
+ "choice_B": choices["B"],
+ "choice_C": choices["C"],
+ "choice_D": choices["D"],
+ "answer": search_result.get("answer"),
+ "pred": pred,
+ "response": response,
+ "judge": pred == search_result.get("answer") if pred else False,
+ "prompt_tokens": prompt_tokens,
+ # Keep full retrieved memories list for inspection / debugging
+ "memories_used": memories,
+ # Preserve full search results payload (e.g., list of memories)
+ "search_results": search_result.get("search_results"),
+ "response_duration_ms": response_duration_ms,
+ "search_duration_ms": search_result.get("search_duration_ms", 0),
+ }
+
+ # Record successful processing (thread-safe)
+ if sample_key:
+ with file_lock, open(record_file, "a") as f:
+ f.write(f"{sample_key}\n")
+ f.flush()
+
+ return result
+
+
+def main(frame, version="default", num_workers=10):
+ """Main response generation function."""
+ load_dotenv()
+
+ print("\n" + "=" * 80)
+ print(f"🚀 LONGBENCH V2 RESPONSE GENERATION - {frame.upper()} v{version}".center(80))
+ print("=" * 80 + "\n")
+
+ # Initialize checkpoint file for resume functionality
+ checkpoint_dir = os.path.join(
+ ROOT_DIR, "evaluation", "results", "long_bench_v2", f"{frame}-{version}"
+ )
+ os.makedirs(checkpoint_dir, exist_ok=True)
+ record_file = os.path.join(checkpoint_dir, "response_success_records.txt")
+ search_path = os.path.join(checkpoint_dir, f"{frame}_longbench_v2_search_results.json")
+ output_path = os.path.join(checkpoint_dir, f"{frame}_longbench_v2_responses.json")
+
+ # Load search results
+ if not os.path.exists(search_path):
+ print(f"❌ Search results not found: {search_path}")
+ print("Please run longbench_v2_search.py first")
+ return
+
+ with open(search_path, encoding="utf-8") as f:
+ search_results = json.load(f)
+
+ # Load existing results and success records for resume
+ existing_results: dict[str, dict] = {}
+ success_records: set[str] = set()
+ if os.path.exists(output_path):
+ with open(output_path, encoding="utf-8") as f:
+ existing_results_list = json.load(f)
+ for result in existing_results_list:
+ # Use sample_idx if present, otherwise _id as the unique key
+ sample_idx = result.get("sample_idx")
+ key = str(sample_idx) if sample_idx is not None else str(result.get("_id", ""))
+ if key:
+ existing_results[key] = result
+ success_records.add(key)
+ print(f"📋 Found {len(existing_results)} existing responses (resume mode)")
+ else:
+ print("📋 Starting fresh response generation (no checkpoint found)")
+
+ # Load additional success records from checkpoint file
+ if os.path.exists(record_file):
+ with open(record_file) as f:
+ for line in f:
+ line = line.strip()
+ if line and line not in success_records:
+ success_records.add(line)
+ print(f"📋 Total {len(success_records)} samples already processed")
+
+ # Initialize LLM client
+ llm_client = OpenAI(
+ api_key=os.getenv("CHAT_MODEL_API_KEY"),
+ base_url=os.getenv("CHAT_MODEL_BASE_URL"),
+ )
+ print(f"🔌 Using OpenAI client with model: {os.getenv('CHAT_MODEL')}")
+
+ # Process all samples concurrently using ThreadPoolExecutor
+ new_results = []
+ file_lock = threading.Lock() # Lock for thread-safe file writing
+ with ThreadPoolExecutor(max_workers=num_workers) as executor:
+ futures = [
+ executor.submit(
+ process_sample, sample, llm_client, success_records, record_file, file_lock
+ )
+ for sample in search_results
+ ]
+
+ for future in tqdm(
+ as_completed(futures),
+ total=len(futures),
+ desc="Generating responses",
+ ):
+ result = future.result()
+ if result:
+ new_results.append(result)
+ # Update existing results with new result (keyed by sample_idx or _id)
+ sample_idx = result.get("sample_idx")
+ key = str(sample_idx) if sample_idx is not None else str(result.get("_id", ""))
+ if key:
+ existing_results[key] = result
+
+ # Merge and save all results
+ all_responses = list(existing_results.values())
+
+ # Sort by sample_idx when available, otherwise by _id for stability
+ def _sort_key(x: dict):
+ if x.get("sample_idx") is not None:
+ return ("0", int(x.get("sample_idx")))
+ return ("1", str(x.get("_id", "")))
+
+ all_responses.sort(key=_sort_key)
+
+ with open(output_path, "w", encoding="utf-8") as f:
+ json.dump(all_responses, f, ensure_ascii=False, indent=2)
+
+ print(f"\n{'=' * 80}")
+ print(f"✅ RESPONSE GENERATION COMPLETE: Results saved to {output_path}".center(80))
+ print(f"{'=' * 80}\n")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lib",
+ type=str,
+ choices=["memos-api", "memos-api-online"],
+ default="memos-api",
+ )
+ parser.add_argument(
+ "--version",
+ type=str,
+ default="default",
+ help="Version identifier for loading results",
+ )
+ parser.add_argument(
+ "--workers",
+ type=int,
+ default=10,
+ help="Number of parallel workers",
+ )
+ args = parser.parse_args()
+
+ main(args.lib, args.version, args.workers)
diff --git a/evaluation/scripts/long_bench-v2/longbench_v2_search.py b/evaluation/scripts/long_bench-v2/longbench_v2_search.py
new file mode 100644
index 000000000..2347e5d66
--- /dev/null
+++ b/evaluation/scripts/long_bench-v2/longbench_v2_search.py
@@ -0,0 +1,273 @@
+import argparse
+import json
+import os
+import sys
+import threading
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from time import time
+
+from dotenv import load_dotenv
+from tqdm import tqdm
+
+
+ROOT_DIR = os.path.dirname(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts")
+
+sys.path.insert(0, ROOT_DIR)
+sys.path.insert(0, EVAL_SCRIPTS_DIR)
+
+
+def memos_api_search(client, query, user_id, top_k, frame):
+ """Search using memos API."""
+ start = time()
+ search_results = client.search(query=query, user_id=user_id, top_k=top_k)
+
+ # Extract raw memory texts in the same way as longbench_stx.memos_search
+ memories_texts: list[str] = []
+ if (
+ (frame == "memos-api" or frame == "memos-api-online")
+ and isinstance(search_results, dict)
+ and "text_mem" in search_results
+ ):
+ text_mem = search_results.get("text_mem") or []
+ if text_mem and text_mem[0].get("memories"):
+ memories = text_mem[0]["memories"]
+ for m in memories:
+ if not isinstance(m, dict):
+ continue
+ # tags may be at top-level or inside metadata
+ tags = m.get("tags") or m.get("metadata", {}).get("tags") or []
+ # Skip fast-mode memories
+ if any(isinstance(t, str) and "mode:fast" in t for t in tags):
+ continue
+ mem_text = m.get("memory", "")
+ if str(mem_text).strip():
+ memories_texts.append(mem_text)
+
+ duration_ms = (time() - start) * 1000
+ return memories_texts, duration_ms, search_results
+
+
+def process_sample(
+ client, sample, sample_idx, frame, version, top_k, success_records, record_file, file_lock
+):
+ """Process a single sample: search for relevant memories."""
+ # Skip if already processed
+ if str(sample_idx) in success_records:
+ return None
+
+ user_id = f"longbench_v2_{sample_idx}_{version}"
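+    # Assumes memories were ingested under this same user_id pattern at import time.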
+ query = sample.get("question", "")
+
+ if not query:
+ return None
+
+ memories_used, duration_ms, search_results = memos_api_search(
+ client, query, user_id, top_k, frame
+ )
+
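+    # Samples with no usable memories are skipped (and not checkpointed), so they will be retried on the next run.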
+ if not (isinstance(memories_used, list) and any(str(m).strip() for m in memories_used)):
+ return None
+
+ result = {
+ "sample_idx": sample_idx,
+ "_id": sample.get("_id"),
+ "domain": sample.get("domain"),
+ "sub_domain": sample.get("sub_domain"),
+ "difficulty": sample.get("difficulty"),
+ "length": sample.get("length"),
+ "question": query,
+ "choice_A": sample.get("choice_A"),
+ "choice_B": sample.get("choice_B"),
+ "choice_C": sample.get("choice_C"),
+ "choice_D": sample.get("choice_D"),
+ "answer": sample.get("answer"),
+ # Raw memories used for RAG answering (aligned with longbench_stx)
+ "memories_used": memories_used,
+ # Preserve full search results payload for debugging / analysis
+ "search_results": search_results,
+ "search_duration_ms": duration_ms,
+ }
+
+ # Record successful processing (thread-safe)
+ with file_lock, open(record_file, "a") as f:
+ f.write(f"{sample_idx}\n")
+ f.flush()
+
+ return result
+
+
+def load_dataset_from_local():
+ """Load LongBench v2 dataset from local JSON file."""
+ data_dir = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
+ "data",
+ "long_bench_v2",
+ )
+
+ filepath = os.path.join(data_dir, "data.json")
+
+ if not os.path.exists(filepath):
+ raise FileNotFoundError(f"Dataset file not found: {filepath}")
+
+ # Load JSON file
+ with open(filepath, encoding="utf-8") as f:
+ samples = json.load(f)
+
+ return samples
+
+
+def main(frame, version="default", num_workers=10, top_k=20, max_samples=None):
+ """Main search function."""
+ load_dotenv()
+
+ print("\n" + "=" * 80)
+ print(f"🚀 LONGBENCH V2 SEARCH - {frame.upper()} v{version}".center(80))
+ print("=" * 80 + "\n")
+
+ # Load dataset from local file
+ try:
+ dataset = load_dataset_from_local()
+ print(f"Loaded {len(dataset)} samples from LongBench v2")
+    except Exception as e:
+        print(f"❌ Error loading dataset: {e}")
+        return
+
+ # Limit samples if specified
+ if max_samples:
+ dataset = dataset[:max_samples]
+ print(f"Limited to {len(dataset)} samples")
+
+ # Initialize checkpoint file for resume functionality
+ checkpoint_dir = os.path.join(
+ ROOT_DIR, "evaluation", "results", "long_bench_v2", f"{frame}-{version}"
+ )
+ os.makedirs(checkpoint_dir, exist_ok=True)
+ record_file = os.path.join(checkpoint_dir, "search_success_records.txt")
+ output_path = os.path.join(checkpoint_dir, f"{frame}_longbench_v2_search_results.json")
+
+ # Load existing results and success records for resume
+ existing_results = {}
+ success_records = set()
+ if os.path.exists(output_path):
+ with open(output_path, encoding="utf-8") as f:
+ existing_results_list = json.load(f)
+ for result in existing_results_list:
+ sample_idx = result.get("sample_idx")
+ if sample_idx is not None:
+ existing_results[sample_idx] = result
+ success_records.add(str(sample_idx))
+ print(f"📋 Found {len(existing_results)} existing search results (resume mode)")
+ else:
+ print("📋 Starting fresh search (no checkpoint found)")
+
+ # Load additional success records from checkpoint file
+ if os.path.exists(record_file):
+ with open(record_file) as f:
+ for line in f:
+ line = line.strip()
+ if line and line not in success_records:
+ success_records.add(line)
+ print(f"📋 Total {len(success_records)} samples already processed")
+
+ # Initialize client
+ client = None
+ if frame == "memos-api":
+ from utils.client import MemosApiClient
+
+ client = MemosApiClient()
+ elif frame == "memos-api-online":
+ from utils.client import MemosApiOnlineClient
+
+ client = MemosApiOnlineClient()
+ else:
+ print(f"❌ Unsupported frame: {frame}")
+ return
+
+ # Process samples
+ new_results = []
+ file_lock = threading.Lock() # Lock for thread-safe file writing
+ with ThreadPoolExecutor(max_workers=num_workers) as executor:
+ futures = []
+ for idx, sample in enumerate(dataset):
+ future = executor.submit(
+ process_sample,
+ client,
+ sample,
+ idx,
+ frame,
+ version,
+ top_k,
+ success_records,
+ record_file,
+ file_lock,
+ )
+ futures.append(future)
+
+ for future in tqdm(
+ as_completed(futures),
+ total=len(futures),
+ desc="Searching LongBench v2",
+ ):
+ result = future.result()
+ if result:
+ new_results.append(result)
+ # Update existing results with new result
+ sample_idx = result.get("sample_idx")
+ if sample_idx is not None:
+ existing_results[sample_idx] = result
+
+ # Merge and save all results
+ search_results = list(existing_results.values())
+ # Sort by sample_idx to maintain order
+ search_results.sort(key=lambda x: x.get("sample_idx", 0))
+
+ with open(output_path, "w", encoding="utf-8") as f:
+ json.dump(search_results, f, ensure_ascii=False, indent=2)
+
+ print(f"\n{'=' * 80}")
+ print(f"✅ SEARCH COMPLETE: Results saved to {output_path}".center(80))
+ print(f"{'=' * 80}\n")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lib",
+ type=str,
+ choices=["memos-api", "memos-api-online"],
+ default="memos-api",
+ )
+ parser.add_argument(
+ "--version",
+ type=str,
+ default="default",
+ help="Version identifier for saving results",
+ )
+ parser.add_argument(
+ "--workers",
+ type=int,
+ default=1,
+ help="Number of parallel workers",
+ )
+ parser.add_argument(
+ "--top_k",
+ type=int,
+ default=20,
+ help="Number of results to retrieve in search queries",
+ )
+ parser.add_argument(
+ "--max_samples",
+ type=int,
+ default=None,
+ help="Maximum number of samples to process (default: all)",
+ )
+ args = parser.parse_args()
+
+ main(args.lib, args.version, args.workers, args.top_k, args.max_samples)
diff --git a/evaluation/scripts/long_bench-v2/run_longbench_v2_eval.sh b/evaluation/scripts/long_bench-v2/run_longbench_v2_eval.sh
new file mode 100755
index 000000000..917c57bfb
--- /dev/null
+++ b/evaluation/scripts/long_bench-v2/run_longbench_v2_eval.sh
@@ -0,0 +1,110 @@
+#!/bin/bash
+
+# Common parameters for all scripts
+LIB="memos-api"
+VERSION="long-bench-v2-1208-1556-async"
+WORKERS=10
+TOPK=20
+MAX_SAMPLES="" # Empty means all samples
+WAIT_INTERVAL=2 # seconds between polls
+WAIT_TIMEOUT=900 # seconds per user
+
+# Parse command line arguments
+while [[ $# -gt 0 ]]; do
+ case $1 in
+ --lib)
+ LIB="$2"
+ shift 2
+ ;;
+ --version)
+ VERSION="$2"
+ shift 2
+ ;;
+ --workers)
+ WORKERS="$2"
+ shift 2
+ ;;
+ --top_k)
+ TOPK="$2"
+ shift 2
+ ;;
+ --max_samples)
+ MAX_SAMPLES="$2"
+ shift 2
+ ;;
+ *)
+ echo "Unknown option: $1"
+ exit 1
+ ;;
+ esac
+done
+
+# Build max_samples argument
+MAX_SAMPLES_ARG=""
+if [ -n "$MAX_SAMPLES" ]; then
+ MAX_SAMPLES_ARG="--max_samples $MAX_SAMPLES"
+fi
+
+echo "Running LongBench v2 evaluation with:"
+echo " LIB: $LIB"
+echo " VERSION: $VERSION"
+echo " WORKERS: $WORKERS"
+echo " TOPK: $TOPK"
+echo " MAX_SAMPLES: ${MAX_SAMPLES:-all}"
+echo ""
+
+# Step 2: Search
+echo ""
+echo "=========================================="
+echo "Step 2: Running longbench_v2_search.py..."
+echo "=========================================="
+python scripts/long_bench-v2/longbench_v2_search.py \
+ --lib $LIB \
+ --version $VERSION \
+ --top_k $TOPK \
+ --workers $WORKERS \
+ $MAX_SAMPLES_ARG
+
+if [ $? -ne 0 ]; then
+ echo "Error running longbench_v2_search.py"
+ exit 1
+fi
+
+# Step 3: Response Generation
+echo ""
+echo "=========================================="
+echo "Step 3: Running longbench_v2_responses.py..."
+echo "=========================================="
+python scripts/long_bench-v2/longbench_v2_responses.py \
+ --lib $LIB \
+ --version $VERSION \
+ --workers $WORKERS
+
+if [ $? -ne 0 ]; then
+ echo "Error running longbench_v2_responses.py"
+ exit 1
+fi
+
+# Step 4: Metrics Calculation
+echo ""
+echo "=========================================="
+echo "Step 4: Running longbench_v2_metric.py..."
+echo "=========================================="
+python scripts/long_bench-v2/longbench_v2_metric.py \
+ --lib $LIB \
+ --version $VERSION
+
+if [ $? -ne 0 ]; then
+ echo "Error running longbench_v2_metric.py"
+ exit 1
+fi
+
+echo ""
+echo "=========================================="
+echo "All steps completed successfully!"
+echo "=========================================="
+echo ""
+echo "Results are saved in: results/long_bench-v2/$LIB-$VERSION/"
+echo " - Search results: ${LIB}_longbench_v2_search_results.json"
+echo " - Responses: ${LIB}_longbench_v2_responses.json"
+echo " - Metrics: ${LIB}_longbench_v2_metrics.json"
diff --git a/evaluation/scripts/long_bench-v2/wait_scheduler.py b/evaluation/scripts/long_bench-v2/wait_scheduler.py
new file mode 100644
index 000000000..716869a11
--- /dev/null
+++ b/evaluation/scripts/long_bench-v2/wait_scheduler.py
@@ -0,0 +1,67 @@
+import os
+import time
+
+import requests
+
+from dotenv import load_dotenv
+
+
+def wait_until_completed(params: dict, interval: float = 2.0, timeout: float = 600.0):
+ """
+ Keep polling /product/scheduler/status until status == 'completed' (or terminal).
+
+ params: dict passed as query params, e.g. {"user_id": "xxx"} or {"user_id": "xxx", "task_id": "..."}
+ interval: seconds between polls
+ timeout: max seconds to wait before raising TimeoutError
+ """
+ load_dotenv()
+ base_url = os.getenv("MEMOS_URL")
+ if not base_url:
+ raise RuntimeError("MEMOS_URL not set in environment")
+
+ url = f"{base_url}/product/scheduler/status"
+ start = time.time()
+ active_states = {"waiting", "pending", "in_progress"}
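+    # Polling stops once no item reports an active status (anything else is treated as terminal).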
+
+ while True:
+ resp = requests.get(url, params=params, timeout=10)
+ resp.raise_for_status()
+ data = resp.json()
+
+ items = data.get("data", []) if isinstance(data, dict) else []
+ statuses = [item.get("status") for item in items if isinstance(item, dict)]
+ status_set = set(statuses)
+
+ # Print current status snapshot
+ print(f"Current status: {status_set or 'empty'}")
+
+ # Completed if no active states remain
+ if not status_set or status_set.isdisjoint(active_states):
+ print("Task completed!")
+ return data
+
+ if (time.time() - start) > timeout:
+ raise TimeoutError(f"Timeout after {timeout}s; last statuses={status_set or 'empty'}")
+
+ time.sleep(interval)
+
+
+if __name__ == "__main__":
+ import argparse
+ import json
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--user_id", default="longbench_v2_0_long-bench-v2-1208-2119-async", help="User ID to query"
+ )
+ parser.add_argument("--task_id", help="Optional task_id to query")
+ parser.add_argument("--interval", type=float, default=2.0, help="Poll interval seconds")
+ parser.add_argument("--timeout", type=float, default=600.0, help="Timeout seconds")
+ args = parser.parse_args()
+
+ params = {"user_id": args.user_id}
+ if args.task_id:
+ params["task_id"] = args.task_id
+
+ result = wait_until_completed(params, interval=args.interval, timeout=args.timeout)
+ print(json.dumps(result, indent=2, ensure_ascii=False))
diff --git a/evaluation/scripts/longbench_v2/longbench_v2_eval.py b/evaluation/scripts/longbench_v2/longbench_v2_eval.py
new file mode 100644
index 000000000..ae88c34f9
--- /dev/null
+++ b/evaluation/scripts/longbench_v2/longbench_v2_eval.py
@@ -0,0 +1,241 @@
+import argparse
+import json
+import os
+import re
+import time
+import traceback
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+
+from dotenv import load_dotenv
+from openai import OpenAI
+from tqdm import tqdm
+
+from evaluation.scripts.utils.prompts import LONGBENCH_V2_ANSWER_PROMPT
+
+
+load_dotenv()
+
+
+def retry_operation(func, *args, retries=5, delay=2, **kwargs):
+ for attempt in range(retries):
+ try:
+ return func(*args, **kwargs)
+ except Exception as e:
+ traceback.print_exc()
+ if attempt < retries - 1:
+ func_name = getattr(func, "__name__", "Operation")
+ print(f"[Retry] {func_name} failed: {e}. Retrying in {delay}s...")
+ time.sleep(delay)
+ delay *= 2
+ else:
+ raise e
+
+
+def extract_answer(response: str) -> str | None:
+ response = response.replace("*", "")
+ match = re.search(r"The correct answer is \(([A-D])\)", response)
+ if match:
+ return match.group(1)
+ match = re.search(r"The correct answer is ([A-D])", response)
+ if match:
+ return match.group(1)
+ return None
+
+
+def llm_answer(
+ oai_client, model_name, memories: list[str], question: str, choices: dict
+) -> tuple[str, int]:
+ doc_content = "\n\n".join([f"Retrieved chunk {idx + 1}: {m}" for idx, m in enumerate(memories)])
+ prompt = (
+ LONGBENCH_V2_ANSWER_PROMPT.replace("$DOC$", doc_content)
+ .replace("$Q$", question)
+ .replace("$C_A$", choices.get("A", ""))
+ .replace("$C_B$", choices.get("B", ""))
+ .replace("$C_C$", choices.get("C", ""))
+ .replace("$C_D$", choices.get("D", ""))
+ )
+ messages = [{"role": "user", "content": prompt}]
+ resp = retry_operation(
+ oai_client.chat.completions.create,
+ model=model_name,
+ messages=messages,
+ temperature=0.1,
+ max_tokens=12800,
+ )
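+    # Assumes the backend populates resp.usage; some OpenAI-compatible servers may omit it.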
+ return resp.choices[0].message.content or "", resp.usage.prompt_tokens
+
+
+def print_metrics(results: list[dict], duration: float) -> None:
+ easy, hard, short, medium, long = 0, 0, 0, 0, 0
+ easy_acc, hard_acc, short_acc, medium_acc, long_acc = 0, 0, 0, 0, 0
+ total_tokens = 0
+
+ for pred in results:
+ acc = int(pred.get("judge", False))
+ diff = pred.get("difficulty", "easy")
+ length = pred.get("length", "short")
+ tokens = pred.get("prompt_tokens", 0)
+ total_tokens += tokens
+
+ if diff == "easy":
+ easy += 1
+ easy_acc += acc
+ else:
+ hard += 1
+ hard_acc += acc
+
+ if length == "short":
+ short += 1
+ short_acc += acc
+ elif length == "medium":
+ medium += 1
+ medium_acc += acc
+ else:
+ long += 1
+ long_acc += acc
+
+ total = len(results)
+ if total == 0:
+ print("No results to calculate metrics.")
+ return
+
+ o_acc = round(100 * (easy_acc + hard_acc) / total, 2)
+ e_acc = round(100 * easy_acc / easy, 2) if easy > 0 else 0
+ h_acc = round(100 * hard_acc / hard, 2) if hard > 0 else 0
+ s_acc = round(100 * short_acc / short, 2) if short > 0 else 0
+ m_acc = round(100 * medium_acc / medium, 2) if medium > 0 else 0
+ l_acc = round(100 * long_acc / long, 2) if long > 0 else 0
+ avg_tokens = round(total_tokens / total, 2)
+
+ print("\n" + "=" * 60)
+ print(f"{'Metric':<15} | {'Count':<10} | {'Accuracy (%)':<10}")
+ print("-" * 60)
+ print(f"{'Overall':<15} | {total:<10} | {o_acc:<10}")
+ print(f"{'Easy':<15} | {easy:<10} | {e_acc:<10}")
+ print(f"{'Hard':<15} | {hard:<10} | {h_acc:<10}")
+ print(f"{'Short':<15} | {short:<10} | {s_acc:<10}")
+ print(f"{'Medium':<15} | {medium:<10} | {m_acc:<10}")
+ print(f"{'Long':<15} | {long:<10} | {l_acc:<10}")
+ print("-" * 60)
+ print(f"{'Avg Tokens':<15} | {total:<10} | {avg_tokens:<10}")
+ print(f"Total Duration: {duration:.2f} seconds")
+ print("=" * 60 + "\n")
+
+
+def _load_json_list(path: Path) -> list[dict]:
+ data = json.loads(path.read_text(encoding="utf-8"))
+ if not isinstance(data, list):
+ raise ValueError(f"Invalid json format: {path}")
+ return data
+
+
+def _save_json_list(path: Path, rows: list[dict]) -> None:
+ path.parent.mkdir(parents=True, exist_ok=True)
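+    # Write to a temp file first, then atomically replace, so an interrupted run cannot corrupt the checkpoint.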
+ tmp = path.with_suffix(path.suffix + ".tmp")
+ tmp.write_text(json.dumps(rows, ensure_ascii=False, indent=2), encoding="utf-8")
+ os.replace(tmp, path)
+
+
+def evaluate_one(oai_client, model_name, row: dict) -> dict:
+ question = row.get("question") or ""
+ choices = row.get("choices") or {}
+ memories = row.get("memories_used") or []
+ response, prompt_tokens = llm_answer(
+ oai_client, model_name, list(memories), str(question), dict(choices)
+ )
+ pred = extract_answer(response)
+ judge = pred == row.get("answer")
+ out = dict(row)
+ out["response"] = response
+ out["pred"] = pred
+ out["judge"] = judge
+ out["prompt_tokens"] = prompt_tokens
+    out.pop("memories_used", None)
+ return out
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="LongBench-v2 eval Tool")
+ parser.add_argument(
+ "--lib",
+ "-b",
+ required=True,
+ choices=["memos", "mem0", "supermemory"],
+ help="Product name to evaluate",
+ )
+ parser.add_argument("--workers", "-w", type=int, default=20, help="Number of parallel threads")
+ parser.add_argument(
+ "--top-k", "-k", type=int, default=20, help="Top k results to use (default: 20)"
+ )
+ parser.add_argument("--version-dir", "-v", default=None, help="Version directory name")
+ parser.add_argument("--search_results_path", type=str, default=None)
+ parser.add_argument("--output_path", type=str, default=None)
+ parser.add_argument("--chat-model", type=str, default=None, help="Chat model for evaluation")
+ args = parser.parse_args()
+
+ print("=" * 60)
+ print("LongBench-v2 Product Eval Tool")
+ print("=" * 60)
+
+ start_time = time.time()
+
+ output_dir = os.path.join("evaluation/data/longbench_v2", args.version_dir)
+ search_filename = f"{args.lib}_search_results.json"
+ search_path = Path(os.path.join(output_dir, search_filename))
+
+ if not search_path.exists():
+ raise FileNotFoundError(f"Search results not found: {search_path}")
+
+ search_rows = _load_json_list(search_path)
+ output_filename = f"{args.lib}_eval_results.json"
+ output_path = Path(os.path.join(output_dir, output_filename))
+
+ results: list[dict] = []
+ processed_ids: set[str] = set()
+
+ # Resume from checkpoint
+ if output_path.exists():
+ try:
+ existing = _load_json_list(output_path)
+ results = existing
+ processed_ids = {str(r.get("_id")) for r in results if r.get("_id")}
+ print(f"Loaded {len(results)} existing results from checkpoint.")
+ except Exception as e:
+ print(f"Error loading checkpoint: {e}")
+
+ pending = [r for r in search_rows if str(r.get("_id")) not in processed_ids]
+ print(f"[Eval] total={len(search_rows)} pending={len(pending)} workers={args.workers}")
+ if not pending:
+ print_metrics(results, time.time() - start_time)
+ return
+
+ print("[Response model]: ", args.chat_model)
+ oai_client = OpenAI(
+ api_key=os.getenv("CHAT_MODEL_API_KEY"), base_url=os.getenv("CHAT_MODEL_BASE_URL")
+ )
+
+ with ThreadPoolExecutor(max_workers=args.workers) as executor:
+ futures = [
+ executor.submit(evaluate_one, oai_client, args.chat_model, row) for row in pending
+ ]
+ for idx, f in enumerate(
+ tqdm(as_completed(futures), total=len(futures), desc="Evaluating"), start=1
+ ):
+ try:
+ res = f.result()
+ results.append(res)
+ if idx % 10 == 0:
+ _save_json_list(output_path, results)
+ except Exception as e:
+ print(f"Evaluation Error: {e}")
+ traceback.print_exc()
+
+ _save_json_list(output_path, results)
+ print(f"Saved {len(results)} results to {output_path}")
+ print_metrics(results, time.time() - start_time)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/evaluation/scripts/longbench_v2/longbench_v2_ingestion.py b/evaluation/scripts/longbench_v2/longbench_v2_ingestion.py
new file mode 100644
index 000000000..ec5257717
--- /dev/null
+++ b/evaluation/scripts/longbench_v2/longbench_v2_ingestion.py
@@ -0,0 +1,284 @@
+import argparse
+import json
+import os
+import time
+import traceback
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+
+from datasets import load_dataset
+from dotenv import load_dotenv
+from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
+from tqdm import tqdm
+
+from evaluation.scripts.utils.metrics import Metrics
+
+
+load_dotenv()
+
+
+def retry_operation(func, *args, retries=5, delay=2, **kwargs):
+ for attempt in range(retries):
+ try:
+ return func(*args, **kwargs)
+ except Exception as e:
+ if attempt < retries - 1:
+ func_name = getattr(func, "__name__", "Operation")
+ print(f"[Retry] {func_name} failed: {e}. Retrying in {delay}s...")
+ time.sleep(delay)
+ delay *= 2
+ else:
+ raise e
+
+
+def _get_lib_client(lib: str):
+ if lib == "mem0":
+ from evaluation.scripts.utils.client import Mem0Client
+
+ return Mem0Client(enable_graph=False)
+ if lib == "supermemory":
+ from evaluation.scripts.utils.client import SupermemoryClient
+
+ return SupermemoryClient()
+ from evaluation.scripts.utils.client import MemosApiClient
+
+ return MemosApiClient()
+
+
+def _load_dataset_jsonl(dataset_path: Path) -> list[dict]:
+ if not dataset_path.exists():
+ dataset = load_dataset("zai-org/LongBench-v2", split="train")
+ dataset_path.parent.mkdir(parents=True, exist_ok=True)
+ with open(dataset_path, "w", encoding="utf-8") as f:
+ for i in range(len(dataset)):
+ s = dataset[i]
+ row = {
+ "_id": s.get("_id") or s.get("id") or str(i),
+ "domain": s.get("domain"),
+ "sub_domain": s.get("sub_domain"),
+ "difficulty": s.get("difficulty"),
+ "length": s.get("length"),
+ "question": s.get("question"),
+ "choice_A": s.get("choice_A"),
+ "choice_B": s.get("choice_B"),
+ "choice_C": s.get("choice_C"),
+ "choice_D": s.get("choice_D"),
+ "answer": s.get("answer"),
+ "context": s.get("context") or s.get("document") or s.get("documents"),
+ }
+ f.write(json.dumps(row, ensure_ascii=False) + "\n")
+ print(f"Successfully saved dataset to {dataset_path}")
+
+ samples: list[dict] = []
+ with open(dataset_path, encoding="utf-8") as f:
+ for line in f:
+ line = line.strip()
+ if not line:
+ continue
+ samples.append(json.loads(line))
+ return samples
+
+
+def _load_added_ids(records_path: Path) -> set[str]:
+ if not records_path.exists():
+ return set()
+ try:
+ obj = json.loads(records_path.read_text(encoding="utf-8"))
+ ids = obj.get("added_ids") if isinstance(obj, dict) else None
+ if isinstance(ids, list):
+ return {str(x) for x in ids if x}
+ except Exception:
+ return set()
+ return set()
+
+
+def _save_added_ids(records_path: Path, added_ids: set[str]) -> None:
+ records_path.parent.mkdir(parents=True, exist_ok=True)
+ tmp = records_path.with_suffix(records_path.suffix + ".tmp")
+ tmp.write_text(
+ json.dumps({"added_ids": sorted(added_ids)}, ensure_ascii=False, indent=2),
+ encoding="utf-8",
+ )
+ os.replace(tmp, records_path)
+
+
+def ingest_context(
+ client, sample: dict, lib: str, mode: str = "fine", async_mode: str = "sync"
+) -> str:
+ sample_id = str(sample.get("_id"))
+ user_id = sample_id
+ context = sample.get("context") or ""
+ chunker = RecursiveCharacterTextSplitter.from_language(
+ language=Language.PYTHON, chunk_size=5120, chunk_overlap=128
+ )
+ chunks = [p for p in chunker.split_text(context or "") if p.strip()]
+
+ if lib == "memos":
+ messages = [{"type": "text", "text": p} for p in chunks]
+ writable_cube_ids = [user_id]
+ retry_operation(
+ client.add,
+ messages=messages,
+ user_id=user_id,
+ writable_cube_ids=writable_cube_ids,
+ source_type="batch_import",
+ mode=mode,
+ async_mode=async_mode,
+ )
+        return sample_id
+
+ if lib == "mem0":
+ messages = [{"role": "user", "content": p} for p in chunks]
+ ts = int(time.time())
+ retry_operation(client.add, messages=messages, user_id=user_id, timestamp=ts, batch_size=10)
+        return sample_id
+
+ if lib == "supermemory":
+ retry_operation(client.add, content=context, user_id=user_id)
+
+ return sample_id
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="LongBench-v2 Product Add Concurrent Script",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+ parser.add_argument("--lib", "-b", required=True, help="Product name to evaluate")
+
+ parser.add_argument(
+ "--api-url",
+ default="http://127.0.0.1:8001",
+ help="MemOS API URL (default: http://127.0.0.1:8001)",
+ )
+
+ parser.add_argument("--workers", "-w", type=int, default=5, help="Concurrency (default: 10)")
+
+ parser.add_argument(
+ "--timeout", type=float, default=1200, help="Request timeout in seconds (default: 120)"
+ )
+
+ parser.add_argument(
+ "--mode", default="fine", choices=["fine", "fast"], help="Processing mode (default: fine)"
+ )
+
+ parser.add_argument(
+ "--async-mode", default="sync", choices=["sync", "async"], help="Async mode (default: sync)"
+ )
+
+ parser.add_argument("--version-dir", "-v", default=None, help="Version directory name")
+
+ parser.add_argument(
+ "--dataset_path",
+ "-p",
+ default="evaluation/data/longbench_v2/longbenchv2_train.json",
+ help="Dataset path",
+ )
+
+ return parser.parse_args()
+
+
+def main() -> None:
+ args = parse_args()
+ print("=" * 60)
+ print("LongBench-v2 Product Add Concurrent Tool")
+ print("=" * 60)
+
+ dataset_path = Path(args.dataset_path)
+ dataset = _load_dataset_jsonl(dataset_path)
+
+ version_output_dir = os.path.join("evaluation/data/longbench_v2", args.version_dir)
+ os.makedirs(version_output_dir, exist_ok=True)
+ output_path = os.path.join(version_output_dir, f"{args.lib}_add_results.json")
+ output_path = Path(output_path)
+
+ added_ids = _load_added_ids(output_path)
+ pending = [s for s in dataset if str(s.get("_id")) not in added_ids]
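+    # Resume support: samples whose _id is already recorded in the add-results file are skipped.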
+ print(
+ f"[Add] lib={args.lib} total={len(dataset)} pending={len(pending)} workers={args.workers}"
+ )
+ if not pending:
+ return
+
+ client = _get_lib_client(args.lib)
+ metrics = Metrics()
+
+ def do_ingest(sample):
+ start_time = time.perf_counter()
+ try:
+ sid = ingest_context(client, sample, args.lib, args.mode, args.async_mode)
+ duration = time.perf_counter() - start_time
+ metrics.record(duration, True)
+ return sid
+ except Exception as e:
+ duration = time.perf_counter() - start_time
+ metrics.record(duration, False, str(e))
+ raise e
+
+ start_time = time.time()
+ with ThreadPoolExecutor(max_workers=args.workers) as executor:
+ futures = [executor.submit(do_ingest, sample) for sample in pending]
+ for f in tqdm(as_completed(futures), total=len(futures), desc="Adding"):
+ try:
+ sid = f.result()
+ if sid:
+ added_ids.add(str(sid))
+ if len(added_ids) % 10 == 0:
+ _save_added_ids(output_path, added_ids)
+ except Exception as e:
+ print(f"[Add] Error: {e}")
+ traceback.print_exc()
+
+ _save_added_ids(output_path, added_ids)
+ print(f"[Add] saved records to {output_path}")
+
+ total_duration = time.time() - start_time
+ perf_out = Path(version_output_dir) / f"{args.lib}_add_perf.json"
+
+ summary = metrics.summary()
+
+ with open(perf_out, "w", encoding="utf-8") as f:
+ json.dump(
+ {
+ "summary": summary,
+ "total_duration": total_duration,
+ "config": {
+ "workers": args.workers,
+ "mode": args.mode,
+ "async_mode": args.async_mode,
+ "dataset_path": args.dataset_path,
+ },
+ },
+ f,
+ ensure_ascii=False,
+ indent=2,
+ )
+ print(f"[Add] saved performance metrics to {perf_out}")
+
+ print("\n" + "=" * 60)
+ print("Ingestion finished! Statistics:")
+ print("=" * 60)
+ print(f"Total duration: {total_duration:.2f}s")
+ print(f"Success: {summary['counts']['success']} / Failed: {summary['counts']['failed']}")
+
+ if summary["stats"]:
+ stats = summary["stats"]
+ qps = stats["count"] / total_duration if total_duration > 0 else 0
+ print(f"QPS: {qps:.2f}")
+ print("Latency stats (ms):")
+ print(f" Mean: {stats['mean']:.2f}")
+ print(f" Median: {stats['median']:.2f}")
+ print(f" Min: {stats['min']:.2f}")
+ print(f" Max: {stats['max']:.2f}")
+ print(f" P95: {stats['p95']:.2f}")
+ print(f" P99: {stats['p99']:.2f}")
+
+ if summary["errors"]:
+ print("\nError stats:")
+ for error, count in sorted(summary["errors"].items(), key=lambda x: x[1], reverse=True)[:5]:
+ print(f" [{count} times] {error[:100]}...")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/evaluation/scripts/longbench_v2/longbench_v2_old.py b/evaluation/scripts/longbench_v2/longbench_v2_old.py
new file mode 100644
index 000000000..5c6eba616
--- /dev/null
+++ b/evaluation/scripts/longbench_v2/longbench_v2_old.py
@@ -0,0 +1,415 @@
+import argparse
+import json
+import os
+import re
+import sys
+import time
+import traceback
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from datetime import datetime
+from pathlib import Path
+
+from dotenv import load_dotenv
+from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
+from openai import OpenAI
+from tqdm import tqdm
+
+
+TEMPLATE_RAG = """Please read the following retrieved text chunks and answer the question below.
+
+
+$DOC$
+
+
+What is the correct answer to this question: $Q$
+Choices:
+(A) $C_A$
+(B) $C_B$
+(C) $C_C$
+(D) $C_D$
+
+Format your response as follows: "The correct answer is (insert answer here)"."""
+
+
+PROJECT_ROOT = Path(__file__).resolve().parents[3]
+SCRIPTS_ROOT = Path(__file__).resolve().parents[1]
+SRC_ROOT = PROJECT_ROOT / "src"
+sys.path.append(str(SCRIPTS_ROOT))
+sys.path.append(str(SRC_ROOT))
+load_dotenv()
+
+
+def retry_operation(func, *args, retries=5, delay=2, **kwargs):
+ for attempt in range(retries):
+ try:
+ return func(*args, **kwargs)
+ except Exception as e:
+ traceback.print_exc()
+ if attempt < retries - 1:
+ func_name = getattr(func, "__name__", "Operation")
+ print(f"[Retry] {func_name} failed: {e}. Retrying in {delay}s...")
+ time.sleep(delay)
+ delay *= 2
+ else:
+ raise e
+
+
+def _get_lib_client(lib: str):
+ if lib == "mem0":
+ from utils.client import Mem0Client # type: ignore
+
+ return Mem0Client(enable_graph=False)
+ if lib == "supermemory":
+ from utils.client import SupermemoryClient # type: ignore
+
+ return SupermemoryClient()
+ from utils.client import MemosApiClient # type: ignore
+
+ return MemosApiClient()
+
+
+def _get_clients(lib: str = "memos"):
+ client = _get_lib_client(lib)
+ openai_client = OpenAI(
+ api_key=os.getenv("CHAT_MODEL_API_KEY"), base_url=os.getenv("CHAT_MODEL_BASE_URL")
+ )
+ return client, openai_client
+
+
+def _dump_dataset_to_local():
+ from datasets import load_dataset
+
+ dataset = load_dataset("zai-org/LongBench-v2", split="train")
+ out_dir = Path("evaluation/data/longbenchV2")
+ out_dir.mkdir(parents=True, exist_ok=True)
+ out_path = out_dir / "longbenchv2_train.json"
+ with open(out_path, "w", encoding="utf-8") as f:
+ for i in range(len(dataset)):
+ s = dataset[i]
+ row = {
+ "_id": s.get("_id") or s.get("id") or str(i),
+ "domain": s.get("domain"),
+ "sub_domain": s.get("sub_domain"),
+ "difficulty": s.get("difficulty"),
+ "length": s.get("length"),
+ "question": s.get("question"),
+ "choice_A": s.get("choice_A"),
+ "choice_B": s.get("choice_B"),
+ "choice_C": s.get("choice_C"),
+ "choice_D": s.get("choice_D"),
+ "answer": s.get("answer"),
+ "context": s.get("context") or s.get("document") or s.get("documents"),
+ }
+ f.write(json.dumps(row, ensure_ascii=False) + "\n")
+ print(f"Saved dataset to {out_path}")
+ return dataset
+
+
+def add_context(client, user_id: str, context: str, lib: str) -> None:
+ iso = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
+ chunker = RecursiveCharacterTextSplitter.from_language(
+ language=Language.PYTHON, chunk_size=5120, chunk_overlap=128
+ )
+ paragraphs = [p for p in chunker.split_text(context or "") if p.strip()]
+
+ if lib == "memos":
+ messages = [{"role": "user", "content": p, "created_at": iso} for p in paragraphs]
+ try:
+ retry_operation(client.add, messages=messages, user_id=user_id, conv_id=user_id)
+ print(f"[Add-memos]: successfully added {len(messages)} chunks to user {user_id}")
+ except Exception as e:
+ print(f"[Add-memos] failed: {e}")
+
+ elif lib == "mem0":
+ messages = [{"role": "user", "content": p} for p in paragraphs]
+ ts = int(time.time())
+ try:
+ retry_operation(
+ client.add, messages=messages, user_id=user_id, timestamp=ts, batch_size=10
+ )
+ print(f"[Add-mem0] user={user_id} total={len(messages)}")
+ except Exception as e:
+ print(f"[Add-mem0] failed: {e}")
+
+ elif lib == "supermemory":
+ iso = datetime.utcnow().isoformat() + "Z"
+ content = "\n".join([f"{iso} user: {p}" for p in paragraphs])
+ try:
+ retry_operation(client.add, content=content, user_id=user_id)
+ print(f"[Add-supermemory] user={user_id} total_chars={len(content)}")
+ except Exception as e:
+ print(f"[Add-supermemory] failed: {e}")
+
+
+def memos_search(client, user_id: str, query: str, top_k: int = 30) -> list[str]:
+ results = retry_operation(client.search, query=query, user_id=user_id, top_k=top_k)
+ memories = results["text_mem"][0]["memories"]
+ mem_texts = [m["memory"] for m in memories]
+ print(f"[Search-memos] user={user_id} top_k={top_k} memories={len(memories)}")
+ return mem_texts
+
+
+def mem0_search(client, user_id: str, query: str, top_k: int = 30) -> list[str]:
+ res = retry_operation(client.search, query, user_id, top_k)
+ results = res.get("results", [])
+ mem_texts = [m.get("memory", "") for m in results if m.get("memory")]
+ print(f"[Search-mem0] user={user_id} top_k={top_k} memories={len(mem_texts)}")
+ return mem_texts
+
+
+def supermemory_search(client, user_id: str, query: str, top_k: int = 30) -> list[str]:
+ chunk_list = retry_operation(client.search, query, user_id, top_k)
+ print(f"[Search-supermemory] user={user_id} top_k={top_k} memories={len(chunk_list)}")
+ return chunk_list
+
+
+def extract_answer(response: str) -> str | None:
+ response = response.replace("*", "")
+ match = re.search(r"The correct answer is \(([A-D])\)", response)
+ if match:
+ return match.group(1)
+ else:
+ match = re.search(r"The correct answer is ([A-D])", response)
+ if match:
+ return match.group(1)
+ else:
+ return None
+
+
+def llm_answer(oai_client, memories: list[str], question: str, choices: dict) -> tuple[str, int]:
+ # Join memories to form the retrieved context document
+ doc_content = "\n\n".join([f"Retrieved chunk {idx + 1}: {m}" for idx, m in enumerate(memories)])
+
+ prompt = (
+ TEMPLATE_RAG.replace("$DOC$", doc_content)
+ .replace("$Q$", question)
+ .replace("$C_A$", choices.get("A", ""))
+ .replace("$C_B$", choices.get("B", ""))
+ .replace("$C_C$", choices.get("C", ""))
+ .replace("$C_D$", choices.get("D", ""))
+ )
+
+ messages = [
+ {"role": "user", "content": prompt},
+ ]
+ resp = retry_operation(
+ oai_client.chat.completions.create,
+ model=os.getenv("CHAT_MODEL"),
+ messages=messages,
+ temperature=0.1,
+ max_tokens=12800,
+ )
+ return resp.choices[0].message.content or "", resp.usage.prompt_tokens
+
+
+def ingest_sample(client, sample: dict, lib: str) -> None:
+ sample_id = str(sample.get("_id"))
+ user_id = sample_id
+ context = sample.get("context") or ""
+ add_context(client, user_id, str(context), lib)
+
+
+def evaluate_sample(client, oai_client, sample: dict, top_k: int, lib: str) -> dict:
+ sample_id = str(sample.get("_id"))
+ user_id = sample_id
+ question = sample.get("question") or ""
+ choices = {
+ "A": sample.get("choice_A") or "",
+ "B": sample.get("choice_B") or "",
+ "C": sample.get("choice_C") or "",
+ "D": sample.get("choice_D") or "",
+ }
+
+ if lib == "memos":
+ memories = memos_search(client, user_id, str(question), top_k=top_k)
+ elif lib == "mem0":
+ memories = mem0_search(client, user_id, str(question), top_k=top_k)
+ elif lib == "supermemory":
+ memories = supermemory_search(client, user_id, str(question), top_k=top_k)
+ else:
+ memories = []
+
+ response, prompt_tokens = llm_answer(oai_client, memories, str(question), choices)
+ pred = extract_answer(response)
+ judge = pred == sample.get("answer")
+ print("[Question]:", question)
+ print("[Choices]:", choices)
+ print("[Raw response]:", response)
+ print("[Answer]:", pred)
+ print("[Ground truth]:", sample.get("answer"))
+ print("[Prompt Tokens]:", prompt_tokens)
+
+ out = {
+ "_id": sample_id,
+ "domain": sample.get("domain"),
+ "sub_domain": sample.get("sub_domain"),
+ "difficulty": sample.get("difficulty"),
+ "length": sample.get("length"),
+ "question": question,
+ "choice_A": choices["A"],
+ "choice_B": choices["B"],
+ "choice_C": choices["C"],
+ "choice_D": choices["D"],
+ "answer": sample.get("answer"),
+ "memories_used": memories,
+ "response": response,
+ "pred": pred,
+ "judge": judge,
+ "prompt_tokens": prompt_tokens,
+ }
+ return out
+
+
+def print_metrics(results: list[dict], duration: float) -> None:
+ easy, hard, short, medium, long = 0, 0, 0, 0, 0
+ easy_acc, hard_acc, short_acc, medium_acc, long_acc = 0, 0, 0, 0, 0
+ total_tokens = 0
+
+ for pred in results:
+ acc = int(pred.get("judge", False))
+ diff = pred.get("difficulty", "easy")
+ length = pred.get("length", "short")
+ tokens = pred.get("prompt_tokens", 0)
+ total_tokens += tokens
+
+ if diff == "easy":
+ easy += 1
+ easy_acc += acc
+ else:
+ hard += 1
+ hard_acc += acc
+
+ if length == "short":
+ short += 1
+ short_acc += acc
+ elif length == "medium":
+ medium += 1
+ medium_acc += acc
+ else:
+ long += 1
+ long_acc += acc
+
+ total = len(results)
+ if total == 0:
+ print("No results to calculate metrics.")
+ return
+
+ o_acc = round(100 * (easy_acc + hard_acc) / total, 2)
+ e_acc = round(100 * easy_acc / easy, 2) if easy > 0 else 0
+ h_acc = round(100 * hard_acc / hard, 2) if hard > 0 else 0
+ s_acc = round(100 * short_acc / short, 2) if short > 0 else 0
+ m_acc = round(100 * medium_acc / medium, 2) if medium > 0 else 0
+ l_acc = round(100 * long_acc / long, 2) if long > 0 else 0
+ avg_tokens = round(total_tokens / total, 2)
+
+ print("\n" + "=" * 60)
+ print(f"{'Metric':<15} | {'Count':<10} | {'Accuracy (%)':<10}")
+ print("-" * 60)
+ print(f"{'Overall':<15} | {total:<10} | {o_acc:<10}")
+ print(f"{'Easy':<15} | {easy:<10} | {e_acc:<10}")
+ print(f"{'Hard':<15} | {hard:<10} | {h_acc:<10}")
+ print(f"{'Short':<15} | {short:<10} | {s_acc:<10}")
+ print(f"{'Medium':<15} | {medium:<10} | {m_acc:<10}")
+ print(f"{'Long':<15} | {long:<10} | {l_acc:<10}")
+ print("-" * 60)
+ print(f"{'Avg Tokens':<15} | {total:<10} | {avg_tokens:<10}")
+ print(f"Total Duration: {duration:.2f} seconds")
+ print("=" * 60 + "\n")
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Evaluate LongBench-v2 with different backends.")
+ parser.add_argument(
+ "--lib",
+ type=str,
+ default="memos",
+ choices=["memos", "mem0", "supermemory"],
+ help="Backend library to use (default: memos)",
+ )
+ args = parser.parse_args()
+
+ start_time = time.time()
+ print("[Response model]: ", os.getenv("CHAT_MODEL"))
+
+ client, oai_client = _get_clients(lib=args.lib)
+ dataset = _dump_dataset_to_local()
+ results: list[dict] = []
+ os.makedirs("evaluation/data/longbenchV2", exist_ok=True)
+ out_json = Path(f"evaluation/data/longbenchV2/test/{args.lib}_cot_results.json")
+
+ # Checkpoint loading
+ processed_ids = set()
+ if out_json.exists():
+ try:
+ with open(out_json, encoding="utf-8") as f:
+ existing_results = json.load(f)
+ if isinstance(existing_results, list):
+ results = existing_results
+ processed_ids = {r.get("_id") for r in results if r.get("_id")}
+ print(f"Loaded {len(results)} existing results from checkpoint.")
+ except Exception as e:
+ print(f"Error loading checkpoint: {e}")
+
+ # Filter dataset to skip processed samples
+    remaining_dataset = [
+        s
+        for i, s in enumerate(dataset)
+        if (s.get("_id") or s.get("id") or str(i)) not in processed_ids
+    ]
+
+ # Concurrency settings
+ max_workers = 4
+ print(f"Starting evaluation with {max_workers} workers using backend: {args.lib}")
+ print(f"Total dataset size: {len(dataset)}")
+ print(f"Already processed: {len(processed_ids)}")
+ print(f"Remaining to process: {len(remaining_dataset)}")
+
+ if not remaining_dataset:
+ print("All samples have been processed.")
+ print_metrics(results, time.time() - start_time)
+ return
+
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ # Phase 1: Ingestion
+ print("Phase 1: Ingesting context...")
+ ingest_futures = [
+ executor.submit(ingest_sample, client, sample, args.lib) for sample in remaining_dataset
+ ]
+ for f in tqdm(as_completed(ingest_futures), total=len(ingest_futures), desc="Ingesting"):
+ try:
+ f.result()
+ except Exception as e:
+ print(f"Ingestion Error: {e}")
+
+ # Phase 2: Evaluation
+ print("Phase 2: Evaluating...")
+ futures = [
+ executor.submit(evaluate_sample, client, oai_client, sample, 30, args.lib)
+ for sample in remaining_dataset
+ ]
+
+ # Use tqdm for progress bar
+ for f in tqdm(as_completed(futures), total=len(futures), desc="Evaluating"):
+ try:
+ res = f.result()
+ results.append(res)
+
+ # Save intermediate results every 10 samples
+ if len(results) % 10 == 0:
+ out_json.write_text(
+ json.dumps(results, ensure_ascii=False, indent=2), encoding="utf-8"
+ )
+ print_metrics(results, time.time() - start_time)
+ except Exception as e:
+ print(f"Evaluation Error: {e}")
+ traceback.print_exc()
+
+ # Final save
+ out_json.write_text(json.dumps(results, ensure_ascii=False, indent=2), encoding="utf-8")
+ print(f"Saved {len(results)} results to {out_json}")
+ print_metrics(results, time.time() - start_time)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/evaluation/scripts/longbench_v2/longbench_v2_search.py b/evaluation/scripts/longbench_v2/longbench_v2_search.py
new file mode 100644
index 000000000..a224ea3fd
--- /dev/null
+++ b/evaluation/scripts/longbench_v2/longbench_v2_search.py
@@ -0,0 +1,284 @@
+import argparse
+import json
+import os
+import time
+import traceback
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+
+from dotenv import load_dotenv
+from tqdm import tqdm
+
+from evaluation.scripts.utils.metrics import Metrics
+
+
+load_dotenv()
+
+
+def retry_operation(func, *args, retries=5, delay=2, **kwargs):
+ for attempt in range(retries):
+ try:
+ return func(*args, **kwargs)
+ except Exception as e:
+ if attempt < retries - 1:
+ func_name = getattr(func, "__name__", "Operation")
+ print(f"[Retry] {func_name} failed: {e}. Retrying in {delay}s...")
+ time.sleep(delay)
+ delay *= 2
+ else:
+ raise e
+
+
+def _get_lib_client(lib: str):
+ if lib == "mem0":
+ from evaluation.scripts.utils.client import Mem0Client
+
+ return Mem0Client(enable_graph=False)
+ if lib == "supermemory":
+ from evaluation.scripts.utils.client import SupermemoryClient
+
+ return SupermemoryClient()
+ from evaluation.scripts.utils.client import MemosApiClient
+
+ return MemosApiClient()
+
+
+def _load_dataset_jsonl(dataset_path: Path) -> list[dict]:
+ samples: list[dict] = []
+ with open(dataset_path, encoding="utf-8") as f:
+ for line in f:
+ line = line.strip()
+ if not line:
+ continue
+ samples.append(json.loads(line))
+ return samples
+
+
+def memos_search(client, user_id: str, query: str, top_k: int = 30) -> list[str]:
+ results = retry_operation(client.search, query=query, user_id=user_id, top_k=top_k)
+ memories = results["text_mem"][0]["memories"]
+ return [m["memory"] for m in memories]
+
+
+def mem0_search(client, user_id: str, query: str, top_k: int = 30) -> list[str]:
+ res = retry_operation(client.search, query, user_id, top_k)
+ results = res.get("results", [])
+ return [m.get("memory", "") for m in results if m.get("memory")]
+
+
+def supermemory_search(client, user_id: str, query: str, top_k: int = 30) -> list[str]:
+ return retry_operation(client.search, query, user_id, top_k)
+
+
+def _load_existing_results(output_path: Path) -> tuple[list[dict], set[str]]:
+ if not output_path.exists():
+ return [], set()
+ try:
+ data = json.loads(output_path.read_text(encoding="utf-8"))
+ if isinstance(data, list):
+ ids = {str(r.get("_id")) for r in data if r.get("_id")}
+ return data, ids
+ if isinstance(data, dict) and isinstance(data.get("results"), list):
+ rows = data.get("results") or []
+ ids = {str(r.get("_id")) for r in rows if r.get("_id")}
+ return rows, ids
+ except Exception:
+ return [], set()
+ return [], set()
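+
+# _load_existing_results enables resume: rows already present in the output file are
+# kept and their `_id`s are skipped on the next run. Both a bare list of rows and the
+# {"results": [...], ...} wrapper written by this script are accepted.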
+
+
+def _save_json_list(path: Path, rows: list[dict]) -> None:
+ path.parent.mkdir(parents=True, exist_ok=True)
+ tmp = path.with_suffix(path.suffix + ".tmp")
+ tmp.write_text(json.dumps({"results": rows}, ensure_ascii=False, indent=2), encoding="utf-8")
+ os.replace(tmp, path)
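+
+# _save_json_list writes to a ".tmp" sibling first and then os.replace()s it over the
+# target, so an interrupted run should not leave a truncated results file behind.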
+
+
+def search_one(client, sample: dict, lib: str, top_k: int) -> dict:
+ sample_id = str(sample.get("_id"))
+ user_id = sample_id
+ question = sample.get("question") or ""
+ choices = {
+ "A": sample.get("choice_A") or "",
+ "B": sample.get("choice_B") or "",
+ "C": sample.get("choice_C") or "",
+ "D": sample.get("choice_D") or "",
+ }
+
+ if lib == "memos":
+ memories = memos_search(client, user_id, str(question), top_k=top_k)
+ elif lib == "mem0":
+ memories = mem0_search(client, user_id, str(question), top_k=top_k)
+ elif lib == "supermemory":
+ memories = supermemory_search(client, user_id, str(question), top_k=top_k)
+ else:
+ memories = []
+ print(f"[{lib} Search] sample_id: {sample_id} search memories: {len(memories)}")
+
+ return {
+ "_id": sample_id,
+ "domain": sample.get("domain"),
+ "sub_domain": sample.get("sub_domain"),
+ "difficulty": sample.get("difficulty"),
+ "length": sample.get("length"),
+ "question": question,
+ "choices": choices,
+ "answer": sample.get("answer"),
+ "memories_used": memories,
+ }
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="Longbench-v2 Product Search Concurrent Script",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+
+ parser.add_argument("--lib", "-b", required=True, help="Product name to evaluate")
+
+ parser.add_argument(
+ "--dataset-path",
+ "-s",
+ default="evaluation/data/longbench_v2/longbenchv2_train.json",
+ help="Path to JSON file containing samples",
+ )
+
+ parser.add_argument(
+ "--api-url",
+ default="http://127.0.0.1:8001",
+ help="API service address (default: http://127.0.0.1:8001)",
+ )
+
+ parser.add_argument("--workers", "-c", type=int, default=5, help="Concurrency (default: 5)")
+
+ parser.add_argument(
+ "--timeout", type=float, default=120.0, help="Request timeout in seconds (default: 120)"
+ )
+
+ parser.add_argument(
+ "--top-k",
+ "-k",
+ type=int,
+ default=20,
+ help="Number of results to return per search (default: 20)",
+ )
+
+ parser.add_argument("--version-dir", "-v", default=None, help="Version directory name")
+
+ parser.add_argument(
+ "--limit",
+ "-l",
+ type=int,
+ default=None,
+ help="Limit number of samples to process (for testing, default all)",
+ )
+
+ parser.add_argument(
+ "--mode", "-m", type=str, default="fast", help="Search mode (default: fast)"
+ )
+
+ return parser.parse_args()
+
+
+def main() -> None:
+ args = parse_args()
+
+ print("=" * 60)
+ print("Longbench-v2 Product Search Concurrent Tool")
+ print("=" * 60)
+
+ dataset_path = Path(args.dataset_path)
+ if not dataset_path.exists():
+ raise FileNotFoundError(f"Dataset file not found: {dataset_path}")
+ dataset = _load_dataset_jsonl(dataset_path)
+ if args.limit is not None:
+ dataset = dataset[: args.limit]
+
+ output_dir = os.path.join("evaluation/data/longbench_v2", args.version_dir)
+ os.makedirs(output_dir, exist_ok=True)
+ output_filename = f"{args.lib}_search_results.json"
+ output_path = Path(os.path.join(output_dir, output_filename))
+
+ results, processed_ids = _load_existing_results(output_path)
+ pending = [s for s in dataset if str(s.get("_id")) not in processed_ids]
+    if not pending:
+        print("All samples have already been processed.")
+        return
+
+ client = _get_lib_client(args.lib)
+ metrics = Metrics()
+ start_time = time.time()
+
+ with ThreadPoolExecutor(max_workers=args.workers) as executor:
+
+ def do_search(sample: dict) -> dict:
+ st = time.perf_counter()
+ r = search_one(client, sample, args.lib, args.top_k)
+ dur = time.perf_counter() - st
+ r["duration_ms"] = dur * 1000
+ metrics.record(dur, True)
+ return r
+
+ futures = [executor.submit(do_search, sample) for sample in pending]
+ for idx, f in enumerate(
+ tqdm(as_completed(futures), total=len(futures), desc="Searching"), start=1
+ ):
+ try:
+ r = f.result()
+ results.append(r)
+ if idx % 10 == 0:
+ _save_json_list(output_path, results)
+ except Exception as e:
+ metrics.record(0.0, False, str(e))
+ print(f"[Search] Error: {e}")
+ traceback.print_exc()
+
+ _save_json_list(output_path, results)
+ print(f"[Search] saved {len(results)} rows to {output_path}")
+
+ total_duration = time.time() - start_time
+ summary = metrics.summary()
+ combined_obj = {
+ "results": results,
+ "perf": {
+ "summary": summary,
+ "total_duration": total_duration,
+ "config": {
+ "workers": args.workers,
+ "top_k": args.top_k,
+ "dataset_path": str(dataset_path),
+ "limit": args.limit,
+ "mode": args.mode,
+ },
+ },
+ }
+ tmp = output_path.with_suffix(output_path.suffix + ".tmp")
+ tmp.write_text(json.dumps(combined_obj, ensure_ascii=False, indent=2), encoding="utf-8")
+ os.replace(tmp, output_path)
+
+ print("\n" + "=" * 60)
+ print("Search finished! Statistics:")
+ print("=" * 60)
+ print(f"Total duration: {total_duration:.2f}s")
+ print(f"Success: {summary['counts']['success']} / Failed: {summary['counts']['failed']}")
+
+ if summary["stats"]:
+ stats = summary["stats"]
+ qps = stats["count"] / total_duration if total_duration > 0 else 0
+ print(f"QPS: {qps:.2f}")
+ print("Latency stats (ms):")
+ print(f" Mean: {stats['mean']:.2f}")
+ print(f" Median: {stats['median']:.2f}")
+ print(f" Min: {stats['min']:.2f}")
+ print(f" Max: {stats['max']:.2f}")
+ print(f" P95: {stats['p95']:.2f}")
+ print(f" P99: {stats['p99']:.2f}")
+
+ if summary["errors"]:
+ print("\nError stats:")
+ for error, count in sorted(summary["errors"].items(), key=lambda x: x[1], reverse=True)[:5]:
+ print(f" [{count} times] {error[:100]}...")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/evaluation/scripts/mmlongbench/eval_utils/__init__.py b/evaluation/scripts/mmlongbench/eval_utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/evaluation/scripts/xinyu/eval/eval_score_llm.py b/evaluation/scripts/mmlongbench/eval_utils/eval_score.py
similarity index 63%
rename from evaluation/scripts/xinyu/eval/eval_score_llm.py
rename to evaluation/scripts/mmlongbench/eval_utils/eval_score.py
index f5764ce39..02ef6eb53 100644
--- a/evaluation/scripts/xinyu/eval/eval_score_llm.py
+++ b/evaluation/scripts/mmlongbench/eval_utils/eval_score.py
@@ -1,59 +1,33 @@
-import os
import re
-import traceback
from collections import defaultdict
from math import isclose
-from memos.configs.mem_os import MOSConfig
-from memos.llms.factory import LLMFactory
-
-
-openapi_config = {
- "model_name_or_path": "gpt-5-nano",
- "top_k": 50,
- "remove_think_prefix": True,
- "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
- "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
-}
-config = {
- "user_id": "user_name",
- "chat_model": {
- "backend": "openai",
- "config": openapi_config,
- },
- "mem_reader": {
- "backend": "simple_struct",
- "config": {
- "llm": {"backend": "openai", "config": openapi_config},
- "embedder": {
- "backend": "universal_api",
- "config": {
- "provider": "openai",
- "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
- "model_name_or_path": "text-embedding-3-large",
- "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
- },
- },
- "chunker": {
- "backend": "sentence",
- "config": {
- "tokenizer_or_token_counter": "gpt2",
- "chunk_size": 512,
- "chunk_overlap": 128,
- "min_sentences_per_chunk": 1,
- },
- },
- },
- },
- "max_turns_window": 20,
- "top_k": 5,
- "enable_textual_memory": True,
- "enable_activation_memory": False,
- "enable_parametric_memory": False,
-}
-mos_config = MOSConfig(**config)
-chat_llm = LLMFactory.from_config(mos_config.chat_model)
+
+def levenshtein_distance(s1, s2):
+ if len(s1) > len(s2):
+ s1, s2 = s2, s1
+
+ distances = range(len(s1) + 1)
+ for i2, c2 in enumerate(s2):
+ distances_ = [i2 + 1]
+ for i1, c1 in enumerate(s1):
+ if c1 == c2:
+ distances_.append(distances[i1])
+ else:
+ distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
+ distances = distances_
+ return distances[-1]
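+
+# Example: levenshtein_distance("kitten", "sitting") == 3
+# (substitute k->s, substitute e->i, append g).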
+
+
+def anls_compute(groundtruth, prediction, threshold=0.5):
+ dist = levenshtein_distance(groundtruth, prediction)
+ length = max(len(groundtruth.upper()), len(prediction.upper()))
+ value = 0.0 if length == 0 else float(dist) / float(length)
+ anls = 1.0 - value
+ if anls <= threshold:
+ anls = 0.0
+ return anls
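+
+# Worked example: anls_compute("paris", "pariss") -> edit distance 1, max length 6,
+# so ANLS = 1 - 1/6 ≈ 0.83, which clears the 0.5 threshold and is kept as-is;
+# a prediction like "rome" against "paris" falls below the threshold and scores 0.0.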
def is_float_equal(
@@ -132,55 +106,48 @@ def isfloat(num):
return False
-def eval_score(question, gt, pred):
- prompt = """
- 你是一个评委,根据问题和标准答案对学生的答案进行打分。打分规则如下:
-
- 完全不对(0分):
- 学生答案与问题无关,未展示出任何相关概念或知识。
- 对了一部分(0.5分):
- 学生答案提供了一些相关信息,但未能直接回答问题。
- 答案中包含部分正确内容,但缺乏关键信息,导致整体理解不清。
- 基本正确(0.7分):
- 学生答案提供了大部分关键信息,不过依然距离标准答案有一定缺失。
- 答案中包含部分关键内容,但缺乏部分信息,导致不够完整。
- 完全正确(1分):
- 学生答案准确地回答了问题,涵盖所有关键信息。
- 表达清晰,逻辑合理,直接且有效地回应了问题。
-
- 问题:{}
-
- 标准答案:{}
-
- 学生答案:{}
- """
-
- max_try = 20
- try_i = 0
- while try_i < max_try:
+def eval_score(gt, pred, answer_type):
+ if answer_type == "Int":
try:
- llm_input_prompt_score = (
- prompt.format(question, gt, pred)
- + """请返回给我一个json:
- {
- "分数": 1,
- "理由": "xxxx"
- }"""
- )
- score = chat_llm.generate(
- [
- {"role": "user", "content": llm_input_prompt_score},
- ]
- )
-
- print(f"score: {score}")
- score_real = eval(score.replace("json", "").replace("\n", "").replace("```", ""))
- return float(score_real["分数"])
+ gt, pred = int(gt), int(float(pred))
+ except Exception:
+ pred = ""
+ score = gt == pred
+ elif answer_type == "Float":
+ try:
+ gt = float(get_clean_string(str(gt)))
+ pred = float(get_clean_string(str(pred)))
except Exception:
- traceback.print_exc()
- print(f"trying num {try_i}")
- try_i += 1
- return -1
+ pred = ""
+ score = is_float_equal(gt, pred, include_percentage=True, is_close=True)
+ elif answer_type in ["Str", "None"]:
+ gt = get_clean_string(gt)
+ pred = get_clean_string(pred)
+ score = gt == pred if is_exact_match(gt) else anls_compute(gt, pred)
+ else:
+ if isinstance(gt, str) and gt.startswith("["):
+ gt = eval(gt)
+ if not isinstance(gt, list):
+ gt = [gt]
+ if isinstance(pred, str) and pred.startswith("["):
+ pred = eval(pred)
+ if not isinstance(pred, list):
+ pred = [pred]
+ print(len(gt), len(pred))
+ if len(gt) != len(pred):
+ score = 0.0
+ else:
+ gt = sorted([get_clean_string(a) for a in gt])
+ pred = sorted([get_clean_string(a) for a in pred])
+ print(gt, pred)
+ if isfloat(gt[0]) or is_exact_match(gt[0]):
+ score = "-".join(gt) == "-".join(pred)
+ else:
+ score = min(
+ [anls_compute(gt_v, pred_v) for gt_v, pred_v in zip(gt, pred, strict=False)]
+ )
+
+ return float(score)
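+
+# eval_score dispatches on the dataset's answer_format: "Int"/"Float" use numeric
+# equality (with percentage tolerance for floats), "Str"/"None" use exact match or
+# ANLS, and list answers are compared element-wise after sorting. For instance,
+# eval_score("42", "42.0", "Int") returns 1.0, while a wrong free-form string answer
+# typically scores 0.0 once it drops below the ANLS threshold.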
def eval_acc_and_f1(samples):
diff --git a/evaluation/scripts/mmlongbench/mmlongbench_eval.py b/evaluation/scripts/mmlongbench/mmlongbench_eval.py
new file mode 100644
index 000000000..f83e39e2d
--- /dev/null
+++ b/evaluation/scripts/mmlongbench/mmlongbench_eval.py
@@ -0,0 +1,470 @@
+import base64
+import json
+import mimetypes
+import os
+import re
+import sys
+import time
+import traceback
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from datetime import datetime
+from pathlib import Path
+from typing import Any
+
+import openai
+
+from dotenv import load_dotenv
+from tqdm import tqdm
+
+from evaluation.scripts.utils.eval_score import eval_acc_and_f1, eval_score, show_results
+from evaluation.scripts.utils.extract_answer import extract_answer
+from evaluation.scripts.utils.prompts import MMLONGBENCH_ANSWER_PROMPT
+from memos.log import get_logger
+
+
+logger = get_logger(__name__)
+
+
+load_dotenv()
+
+# Initialize OpenAI Client
+oai_client = openai.Client(
+ api_key=os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
+ base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
+)
+
+
+def _encode_image_to_data_url(image_path: str) -> str | None:
+ """Encode local image file to base64 data URL for OpenAI-compatible image messages.
+
+ Returns a data URL like: data:image/jpeg;base64,<...>
+ """
+ try:
+ mime, _ = mimetypes.guess_type(image_path)
+ if not mime:
+ # default to jpeg
+ mime = "image/jpeg"
+ with open(image_path, "rb") as f:
+ b64 = base64.b64encode(f.read()).decode("ascii")
+ return f"data:{mime};base64,{b64}"
+ except Exception as e:
+ logger.warning(f"Failed to encode image '{image_path}' to data URL: {e}")
+ return None
+
+
+# @lru_cache(maxsize=1)
+def build_images_index() -> dict[str, str]:
+ """Scan `./ppt_test_result` recursively and index images by filename.
+
+ New structure example:
+ ./ppt_test_result//extracted/file_*//auto/images/*.{png,jpg,jpeg,webp,gif}
+
+ Also compatible with previous layouts. Returns mapping:
+ basename (e.g. img_123.jpg) -> absolute path
+ """
+ base_dir = Path("/Users/tianxingshi/Desktop/lcy/ppt_test_result")
+ index: dict[str, str] = {}
+ if not base_dir.exists():
+ return index
+
+ # Recursively find any `auto/images` directories under ppt_test_result
+ for images_dir in base_dir.rglob("auto/images"):
+ if images_dir.is_dir():
+ for img_file in images_dir.iterdir():
+ if img_file.is_file():
+ index[img_file.name] = str(img_file.resolve())
+ return index
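+
+# The returned index maps bare filenames to absolute paths, e.g. (hypothetical entry)
+#     "img_123.jpg" -> ".../ppt_test_result/extracted/file_x/auto/images/img_123.jpg"
+# so later lookups only need the basename mentioned in a memory's source text.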
+
+
+index_dict = build_images_index()
+
+
+def get_images(sources: list) -> list[str]:
+ """Extract image absolute paths from metadata sources.
+
+    Supports "[Image: images/<name> - ...]" markers or any 'images/...jpg' substring.
+ Falls back to scanning the ppt_test_result index to resolve basenames.
+ """
+ if not sources:
+ return []
+
+    # Candidates are resolved against the module-level image index (index_dict)
+
+ found: list[str] = []
+
+ md_img_pattern = re.compile(r"\[Image: images/\s*(.+?)\s*-")
+ images_substr_pattern = re.compile(r"images/[^\s)]+\.(?:png|jpg|jpeg|webp)", re.IGNORECASE)
+ for src in sources:
+ if not src:
+ continue
+ # 1) markdown image syntax
+ for m in md_img_pattern.findall(src):
+ candidate = m.strip()
+ # if it's a relative like 'images/xxx.jpg', resolve via index
+ basename = os.path.basename(candidate)
+ if basename in index_dict:
+ found.append(index_dict[basename])
+ else:
+ # try direct path (absolute or relative)
+ p = Path(candidate)
+ if not p.is_absolute():
+ p = Path.cwd() / p
+ if p.exists():
+ found.append(str(p.resolve()))
+
+ # 2) any 'images/xxx.jpg' substring
+ for m in images_substr_pattern.findall(src):
+ candidate = m.strip()
+ basename = os.path.basename(candidate)
+ if basename in index_dict:
+ found.append(index_dict[basename])
+ else:
+ p = Path(candidate)
+ if not p.is_absolute():
+ p = Path.cwd() / p
+ if p.exists():
+ found.append(str(p.resolve()))
+
+ # Deduplicate preserving order
+ dedup: list[str] = []
+ seen = set()
+ for path in found:
+ if path not in seen:
+ dedup.append(path)
+ seen.add(path)
+ return dedup
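+
+# Illustrative (hypothetical) source string: "[Image: images/img_7.jpg - revenue chart]"
+# resolves via index_dict["img_7.jpg"] when that basename was indexed above; otherwise
+# the captured relative path is tried directly.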
+
+
+def add_images_context(
+ current_messages: list[dict[str, Any]], images: list[str]
+) -> list[dict[str, Any]]:
+ """Append images in OpenAI-compatible multi-part format and ensure message structure.
+
+ - Deduplicates image paths.
+ - Ensures a system message exists with a concise CN vision instruction.
+ - Ensures the last user message has multi-part content: [text, image_url...].
+ - Uses base64 data URLs. Limits to 6 images.
+ - In-place modification of `current_messages`.
+ """
+ if not images:
+ return current_messages
+
+ # Deduplicate images while preserving order
+ unique_images: list[str] = []
+ seen_paths: set[str] = set()
+ for p in images:
+ if p not in seen_paths:
+ unique_images.append(p)
+ seen_paths.add(p)
+
+    # Locate the last user message; if none exists, there is nothing to attach to
+    user_idx = None
+    for i in range(len(current_messages) - 1, -1, -1):
+        if current_messages[i].get("role") == "user":
+            user_idx = i
+            break
+    if user_idx is None:
+        return current_messages
+
+ user_msg = current_messages[user_idx]
+ orig_content = user_msg.get("content", "")
+
+ # Normalize user content to multi-part format using original query as text (no fallback)
+ content_parts: list[dict[str, Any]]
+ if isinstance(orig_content, str):
+ content_parts = [{"type": "text", "text": orig_content}]
+ elif isinstance(orig_content, list):
+ content_parts = orig_content
+ else:
+ content_parts = [{"type": "text", "text": str(orig_content)}]
+
+    # Append at most `limit` (6) images as data URLs
+ limit = 6
+ count = 0
+
+ for img_path in unique_images:
+ if count >= limit:
+ break
+ data_url = _encode_image_to_data_url(img_path)
+ if data_url:
+ content_parts.append({"type": "image_url", "image_url": {"url": data_url}})
+ count += 1
+ user_msg["content"] = content_parts
+ current_messages[user_idx] = user_msg
+ logger.info(
+ f"Attached {count} images to user message (deduplicated from {len(images)}), {json.dumps(current_messages, ensure_ascii=False, indent=2)}"
+ )
+ return current_messages
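+
+# After add_images_context, the last user message content is a multi-part list, e.g.
+#     [{"type": "text", "text": "<question>"},
+#      {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}]
+# which is the multi-part format accepted by OpenAI-compatible chat endpoints.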
+
+
+def multimodal_answer(
+ oai_client,
+ memories: list[str],
+ question: str,
+ top_k: int = 15,
+ sources: list | None = None,
+) -> tuple[str, int | None]:
+ sources_texts: list[str] = []
+    for source in (sources or [])[:top_k]:
+ source = source[0]
+ content = source.get("content") if isinstance(source, dict) else str(source)
+ if content:
+ sources_texts.append(content)
+
+ image_paths = get_images(sources_texts)
+ system_prompt = MMLONGBENCH_ANSWER_PROMPT.format(
+ memories="\n\n".join(memories[:top_k]), question=question
+ )
+ messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": question}]
+ messages = add_images_context(messages, image_paths)
+    for msg in messages:
+ if msg.get("role") == "user" and isinstance(msg.get("content"), list):
+ img_count = sum(1 for p in msg["content"] if p.get("type") == "image_url")
+ print(
+ f"DEBUG: user message has {len(memories[:top_k])} memories, {img_count} images attached"
+ )
+
+ resp = oai_client.chat.completions.create(
+ model=args.chat_model, messages=messages, temperature=0
+ )
+ return resp.choices[0].message.content or "", resp.usage.prompt_tokens
+
+
+def process_single_item(item: dict, index: int, top_k: int = 20) -> dict:
+ """Process a single evaluation item"""
+ question = item["question"]
+ memories = item.get("memories", [])
+ sources = item.get("sources", [])
+    if not memories:
+        # No retrieved memories: skip the LLM call and record a zero-score placeholder
+        result = {
+            "response": None,
+            "extracted_res": None,
+            "pred": None,
+            "score": 0,
+            "eval_success": False,
+            "eval_error": "",
+        }
+        return {"index": index, "result": result}
+ try:
+ # Get model response
+ response, prompt_tokens = multimodal_answer(oai_client, memories, question, top_k, sources)
+
+ # Extract answer
+ extracted_res = extract_answer(question, response)
+
+ # Parse extracted answer
+ try:
+ pred_ans = (
+ extracted_res.split("Answer format:")[0].split("Extracted answer:")[1].strip()
+ )
+ except Exception as e:
+ print("extract_answer error**********", e)
+ pred_ans = response.strip()
+
+ # Calculate score
+ score = eval_score(item.get("answer"), pred_ans, item.get("answer_format", "Str"))
+
+ # Build result
+ result = {
+ "response": response,
+ "extracted_res": extracted_res,
+ "pred": pred_ans,
+ "score": score,
+ "prompt_tokens": prompt_tokens,
+ "eval_success": True,
+ "eval_error": None,
+ }
+
+ except Exception as e:
+ traceback.print_exc()
+ result = {
+ "response": None,
+ "extracted_res": None,
+ "pred": None,
+ "score": 0,
+ "eval_success": False,
+ "eval_error": str(e),
+ }
+
+ return {"index": index, "result": result}
+
+
+def run_eval(
+ questions_file: str | Path,
+ output_file: str | Path | None = None,
+ version_dir: str | Path | None = None,
+ max_workers: int = 10,
+ top_k: int = 20,
+) -> None:
+ """
+ Run evaluation
+
+ Args:
+        questions_file: Input questions file path
+        output_file: Output file path; overwrites the input file if None
+        version_dir: Version directory used for the report and metrics files
+        max_workers: Number of concurrent workers
+        top_k: Number of retrieved memories/sources used per question
+ """
+ questions_file = Path(questions_file)
+ output_file = questions_file if output_file is None else Path(output_file)
+
+ # Read input data
+ with open(questions_file, encoding="utf-8") as f:
+ data = json.load(f)
+
+ items = data["results"]
+ total = len(items)
+ print(f"[Info] Starting evaluation, total {total} items, concurrency: {max_workers}")
+
+ # Concurrent processing
+ results_map = {}
+ start_time = time.time()
+
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ futures = {
+ executor.submit(process_single_item, item, i, top_k): i for i, item in enumerate(items)
+ }
+
+ # Use tqdm to show progress bar
+ with tqdm(
+ total=total,
+ desc="Evaluation Progress",
+ unit="items",
+ bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
+ ) as pbar:
+ for future in as_completed(futures):
+ result_data = future.result()
+ idx = result_data["index"]
+ result = result_data["result"]
+ results_map[idx] = result
+
+ # Update progress bar, show current score
+ score = result.get("score", 0)
+ success = result.get("eval_success", False)
+ status = f"score={score:.2f}" if success else "ERROR"
+ pbar.set_postfix_str(status)
+ pbar.update(1)
+
+ # Write results to each item in original data
+ for i, item in enumerate(items):
+ if i in results_map:
+ item.update(results_map[i])
+
+ # Update summary info
+ eval_duration = time.time() - start_time
+
+ # Calculate evaluation statistics
+ success_count = sum(1 for r in results_map.values() if r.get("eval_success", False))
+ failed_count = total - success_count
+ scores = [r.get("score", 0) for r in results_map.values() if r.get("eval_success", False)]
+ prompt_tokens_list = [
+ r.get("prompt_tokens")
+ for r in results_map.values()
+ if r.get("eval_success", False) and isinstance(r.get("prompt_tokens"), int)
+ ]
+ avg_prompt_tokens = (
+ (sum(prompt_tokens_list) / len(prompt_tokens_list)) if prompt_tokens_list else 0
+ )
+
+ # Calculate acc and f1
+ eval_results = [{**items[i], **results_map[i]} for i in range(len(items)) if i in results_map]
+ acc, f1 = eval_acc_and_f1(eval_results)
+
+ # Update data summary
+ if "eval_summary" not in data:
+ data["eval_summary"] = {}
+
+ data["eval_summary"] = {
+ "eval_duration_seconds": eval_duration,
+ "total_samples": total,
+ "success_count": success_count,
+ "failed_count": failed_count,
+ "accuracy": acc,
+ "f1_score": f1,
+ "avg_score": sum(scores) / len(scores) if scores else 0,
+ "avg_prompt_tokens": avg_prompt_tokens,
+ "max_workers": max_workers,
+ "eval_timestamp": datetime.now().isoformat(),
+ }
+
+ # Save results to output file
+ output_file.parent.mkdir(parents=True, exist_ok=True)
+ with open(output_file, "w", encoding="utf-8") as f:
+ json.dump(data, f, ensure_ascii=False, indent=2)
+
+ print(f"\n{'=' * 60}")
+ print("[Evaluation Finished]")
+ print(f" Total samples: {total}")
+ print(f" Success: {success_count}, Failed: {failed_count}")
+ print(f" Accuracy: {acc:.4f}")
+ print(f" F1 Score: {f1:.4f}")
+ print(f" Average Score: {data['eval_summary']['avg_score']:.4f}")
+ print(f" Average prompt_tokens: {data['eval_summary']['avg_prompt_tokens']:.2f}")
+ print(f" Duration: {eval_duration:.2f}s")
+ print(f" Results saved to: {output_file}")
+ print(f"{'=' * 60}")
+
+ # Generate detailed report
+ report_path = version_dir / f"{args.lib}_eval_results.txt"
+ show_results(eval_results, show_path=str(report_path))
+ print(f"[Report] Detailed report saved to: {report_path}")
+
+ # Save concise metrics file
+ metrics_path = report_path.with_name(report_path.stem + "_metrics.json")
+
+ metrics = {
+ "accuracy": acc,
+ "f1_score": f1,
+ "avg_score": data["eval_summary"]["avg_score"],
+ "avg_prompt_tokens": data["eval_summary"]["avg_prompt_tokens"],
+ "total_samples": total,
+ "success_count": success_count,
+ "failed_count": failed_count,
+ "eval_duration_seconds": eval_duration,
+ "eval_timestamp": data["eval_summary"]["eval_timestamp"],
+ }
+ with open(metrics_path, "w", encoding="utf-8") as f:
+ json.dump(metrics, f, ensure_ascii=False, indent=2)
+ print(f"[Metrics] Metrics saved to: {metrics_path}")
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(description="MMlongbench Evaluation Script")
+ parser.add_argument("--lib", "-b", required=True, help="Product name to evaluate")
+ parser.add_argument("--workers", "-w", type=int, default=20, help="Concurrent workers")
+ parser.add_argument(
+ "--top-k", "-k", type=int, default=20, help="Top K results to use (default: 20)"
+ )
+ parser.add_argument("--version-dir", "-v", default=None, help="Version directory name")
+ parser.add_argument("--chat-model", "-m", default=None, help="chat model name")
+
+ args = parser.parse_args()
+
+ print("=" * 60)
+ print("MMLongBench-doc Product Eval Tool")
+ print("=" * 60)
+
+ print("[Response model]: ", os.getenv("CHAT_MODEL"))
+
+ base_dir = Path("evaluation/data/mmlongbench")
+ version_dir = base_dir / args.version_dir
+ input_filename = f"{args.lib}_search_results.json"
+ input_path = version_dir / input_filename
+
+ if not input_path.exists():
+ print(f"Error: Input file not found: {input_path}")
+ sys.exit(1)
+
+ output_path = input_path
+
+ print(f"[Info] Input file: {input_path}")
+ print(f"[Info] Output file: {output_path}")
+ print(f"[Response Model]: {args.chat_model}")
+
+ run_eval(
+ questions_file=input_path,
+ output_file=output_path,
+ version_dir=version_dir,
+ max_workers=args.workers,
+ top_k=args.top_k,
+ )
diff --git a/evaluation/scripts/mmlongbench/mmlongbench_ingestion.py b/evaluation/scripts/mmlongbench/mmlongbench_ingestion.py
new file mode 100644
index 000000000..618caf09e
--- /dev/null
+++ b/evaluation/scripts/mmlongbench/mmlongbench_ingestion.py
@@ -0,0 +1,341 @@
+#!/usr/bin/env python3
+
+import argparse
+import json
+import threading
+import time
+
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+
+from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
+
+from evaluation.scripts.utils.metrics import Metrics
+
+
+def read_filenames(filepath: str) -> list[str]:
+ """
+ Read filename list from file
+ Supports one filename per line, automatically removes empty lines and whitespace
+ """
+ filenames = []
+ with open(filepath, encoding="utf-8") as f:
+ for line in f:
+ line = line.strip()
+ if line: # Skip empty lines
+ filenames.append(line)
+ return filenames
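+
+# Expected list format: one markdown filename per line, blank lines ignored, e.g.
+# (hypothetical names):
+#     report_2021.md
+#     annual_review.md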
+
+
+def run_concurrent_add(
+ lib: str,
+ filenames: list[str],
+ url_prefix: str,
+ user_prefix: str,
+ workers: int,
+ source_type: str = "extreme_multimodal",
+ mode: str = "fine",
+ async_mode: str = "sync",
+) -> dict:
+ """
+ Execute concurrent add operations
+
+ Args:
+ lib: Client name
+ filenames: List of filenames
+ url_prefix: URL prefix
+ user_prefix: User ID prefix
+ workers: Concurrency
+ source_type: Source type
+ mode: Mode
+ async_mode: Async mode
+
+ Returns:
+ Statistics result
+ """
+
+ client = _get_lib_client(lib)
+ metrics = Metrics()
+ total_files = len(filenames)
+ completed = 0
+ completed_lock = threading.Lock()
+ user_id = user_prefix
+
+ def add_single_file(filename: str, doc_id: str = ""):
+ nonlocal completed
+
+        file_id = filename  # filename serves as the file_id
+        file_data = f"{url_prefix.rstrip('/')}/{filename}"  # URL prefix + filename
+
+ base_dir = Path("ppt_test_result")
+ all_md_files = list(base_dir.rglob("*.md"))
+ stem = Path(file_id).stem.lower()
+ name = file_id.lower()
+ md_path = ""
+ for md in all_md_files:
+ pstr = str(md).lower()
+ if (stem and stem in pstr) or (name and name in pstr):
+ md_path = md
+ text = md_path.read_text(encoding="utf-8", errors="ignore")
+
+ start_time = time.perf_counter()
+ user_id = user_prefix + "_" + doc_id
+ writable_cube_ids = [user_id]
+ chat_time = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
+
+ result = None
+ try:
+ if lib == "memos":
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "file",
+ "file": {
+ "file_id": file_id,
+ "filename": file_data,
+ "file_data": file_data,
+ },
+ }
+ ],
+ "chat_time": chat_time,
+ }
+ ]
+ result = client.add(
+ messages=messages,
+ user_id=user_id,
+ writable_cube_ids=writable_cube_ids,
+ source_type=source_type,
+ mode=mode,
+ async_mode=async_mode,
+ )
+ elif lib == "supermemory":
+ result = client.add(content=text, user_id=user_id)
+ elif lib == "mem0":
+ chunker = RecursiveCharacterTextSplitter.from_language(
+ language=Language.PYTHON, chunk_size=5120, chunk_overlap=128
+ )
+ paragraphs = [p for p in chunker.split_text(text) if p.strip()]
+ messages = [{"role": "user", "content": p} for p in paragraphs]
+ ts = int(time.time())
+
+                result = client.add(messages=messages, user_id=user_id, timestamp=ts, batch_size=10)
+
+ duration = time.perf_counter() - start_time
+ metrics.record(duration, True)
+
+ with completed_lock:
+ completed += 1
+ print(
+ f"[{completed}/{total_files}] ✓ Success: {filename} ({duration * 1000:.0f}ms)"
+ )
+
+ return True, result
+
+ except Exception as e:
+ duration = time.perf_counter() - start_time
+ error_msg = str(e)
+ metrics.record(duration, False, error_msg)
+
+ with completed_lock:
+ completed += 1
+ print(f"[{completed}/{total_files}] ✗ Failed: {filename} - {error_msg[:100]}")
+
+ return False, error_msg
+
+ print(f"\nStarting concurrent add for {total_files} files...")
+ print(f"Concurrency: {workers}")
+ print(f"User ID: {user_id}")
+ print(f"URL prefix: {url_prefix}")
+ print("-" * 60)
+
+ start_time = time.time()
+
+ with ThreadPoolExecutor(max_workers=workers) as executor:
+ futures = []
+ for _, filename in enumerate(filenames):
+ doc_id = filename[:-3] + ".pdf"
+ future = executor.submit(add_single_file, filename, doc_id)
+ futures.append((filename, future))
+
+ # Wait for all tasks to complete
+ results = []
+ for filename, future in futures:
+ try:
+ success, result = future.result()
+ results.append({"filename": filename, "success": success, "result": result})
+ except Exception as e:
+ results.append({"filename": filename, "success": False, "result": str(e)})
+
+ end_time = time.time()
+ total_duration = end_time - start_time
+
+ # Print statistics
+ summary = metrics.summary()
+
+ print("\n" + "=" * 60)
+ print("Ingestion finished! Statistics:")
+ print("=" * 60)
+ print(f"Total duration: {total_duration:.2f}s")
+ print(f"Success: {summary['counts']['success']} / Failed: {summary['counts']['failed']}")
+
+ if summary["stats"]:
+ stats = summary["stats"]
+ qps = stats["count"] / total_duration if total_duration > 0 else 0
+ print(f"QPS: {qps:.2f}")
+ print("Latency stats (ms):")
+ print(f" Mean: {stats['mean']:.2f}")
+ print(f" Median: {stats['median']:.2f}")
+ print(f" Min: {stats['min']:.2f}")
+ print(f" Max: {stats['max']:.2f}")
+ print(f" P95: {stats['p95']:.2f}")
+ print(f" P99: {stats['p99']:.2f}")
+
+ if summary["errors"]:
+ print("\nError statistics:")
+ for error, count in sorted(summary["errors"].items(), key=lambda x: x[1], reverse=True)[:5]:
+ print(f" [{count} times] {error[:100]}...")
+
+ return {"summary": summary, "total_duration": total_duration, "results": results}
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="MMLongbench-doc Product Add Concurrent Script",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+ parser.add_argument("--lib", "-b", required=True, help="Product name to evaluate")
+ parser.add_argument(
+ "--filenames-file",
+ "-f",
+ default="evaluation/data/mmlongbench/md_file_list.txt",
+ help="Path to text file containing list of filenames (one per line)",
+ )
+
+ parser.add_argument(
+ "--url-prefix",
+ "-u",
+ default="https://memos-knowledge-base-file-pre.oss-cn-shanghai.aliyuncs.com/ppt_md_files/",
+ help="URL prefix to be prepended to filenames",
+ )
+
+ parser.add_argument(
+ "--api-url",
+ default="http://127.0.0.1:8001",
+ help="MemOS API address (default: http://127.0.0.1:8001)",
+ )
+
+ parser.add_argument("--workers", "-w", type=int, default=5, help="Concurrency (default: 5)")
+
+ parser.add_argument(
+ "--timeout", type=float, default=1200, help="Request timeout in seconds (default: 120)"
+ )
+
+ parser.add_argument(
+ "--source-type",
+ default="extreme_multimodal",
+ help="Source type (default: extreme_multimodal)",
+ )
+
+ parser.add_argument(
+ "--mode", default="fine", choices=["fine", "fast"], help="Processing mode (default: fine)"
+ )
+
+ parser.add_argument(
+ "--async-mode", default="sync", choices=["sync", "async"], help="Async mode (default: sync)"
+ )
+
+ parser.add_argument("--version-dir", "-v", default=None, help="Version directory name")
+
+ return parser.parse_args()
+
+
+def _get_lib_client(lib: str):
+ if lib == "memos":
+ from evaluation.scripts.utils.client import MemosApiClient
+
+ return MemosApiClient()
+ if lib == "mem0":
+ from evaluation.scripts.utils.client import Mem0Client
+
+ return Mem0Client(enable_graph=False)
+ if lib == "supermemory":
+ from evaluation.scripts.utils.client import SupermemoryClient
+
+ return SupermemoryClient()
+
+
+def main():
+ args = parse_args()
+
+ print("=" * 60)
+ print("MMLongbench-doc Product Add Concurrent Tool")
+ print("=" * 60)
+
+ # Read filename list
+ print(f"\nReading filename list: {args.filenames_file}")
+ try:
+ filenames = read_filenames(args.filenames_file)
+ print(f"Read {len(filenames)} filenames")
+ if len(filenames) == 0:
+ print("Error: Filename list is empty!")
+ return
+
+ # Show first few filenames
+ print("First 5 filenames:")
+ for fn in filenames[:5]:
+ print(f" - {fn}")
+ if len(filenames) > 5:
+ print(f" ... and {len(filenames) - 5} more files")
+
+ except FileNotFoundError:
+ print(f"Error: File not found {args.filenames_file}")
+ return
+ except Exception as e:
+ print(f"Error: Failed to read file - {e}")
+ return
+
+ # Execute concurrent add
+ result = run_concurrent_add(
+ lib=args.lib,
+ filenames=filenames,
+ url_prefix=args.url_prefix,
+ user_prefix=args.version_dir,
+ workers=args.workers,
+ source_type=args.source_type,
+ mode=args.mode,
+ async_mode=args.async_mode,
+ )
+
+ # Determine output file path
+ import os
+
+ version_output_dir = os.path.join("evaluation/data/mmlongbench", args.version_dir)
+ os.makedirs(version_output_dir, exist_ok=True)
+ output_path = os.path.join(version_output_dir, f"{args.lib}_add_results.json")
+
+ # Save results to file
+ if output_path:
+ with open(output_path, "w", encoding="utf-8") as f:
+ # Remove non-serializable content
+ output_data = {
+ "summary": result["summary"],
+ "total_duration": result["total_duration"],
+ "config": {
+ "filenames_file": args.filenames_file,
+ "url_prefix": args.url_prefix,
+ "api_url": args.api_url,
+ "concurrency": args.workers,
+ "source_type": args.source_type,
+ "mode": args.mode,
+ "async_mode": args.async_mode,
+ "version_dir": args.version_dir,
+ },
+ }
+ json.dump(output_data, f, ensure_ascii=False, indent=2)
+ print(f"\nResults saved to: {output_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/evaluation/scripts/mmlongbench/mmlongbench_old.py b/evaluation/scripts/mmlongbench/mmlongbench_old.py
new file mode 100644
index 000000000..9a887e987
--- /dev/null
+++ b/evaluation/scripts/mmlongbench/mmlongbench_old.py
@@ -0,0 +1,403 @@
+import argparse
+import json
+import os
+import sys
+import time
+
+from collections.abc import Iterator
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+
+from dotenv import load_dotenv
+from eval.eval_score import eval_acc_and_f1, eval_score, show_results # type: ignore
+from eval.extract_answer import extract_answer # type: ignore
+from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
+from openai import OpenAI
+
+
+# Ensure project paths for imports
+PROJECT_ROOT = Path(__file__).resolve().parents[3]
+SCRIPTS_ROOT = Path(__file__).resolve().parents[1]
+SRC_ROOT = PROJECT_ROOT / "src"
+sys.path.append(str(SCRIPTS_ROOT))
+sys.path.append(str(SRC_ROOT))
+load_dotenv()
+max_retries = 5
+
+
+def iter_markdown_files(base_dir: str | Path) -> Iterator[Path]:
+ base = Path(base_dir)
+ if not base.exists():
+ return
+ # glob all 'auto/*.md'
+ for md in base.rglob("auto/*.md"):
+ if md.is_file():
+ yield md
+
+
+def _get_clients():
+ from utils.client import MemosApiClient # type: ignore
+ from utils.prompts import MEMOS_CONTEXT_TEMPLATE # type: ignore
+
+ from memos.mem_os.core import add_images_context, get_images # type: ignore
+
+ memos_client = MemosApiClient()
+ openai_client = OpenAI(
+ api_key=os.getenv("CHAT_MODEL_API_KEY"), base_url=os.getenv("CHAT_MODEL_BASE_URL")
+ )
+ return memos_client, openai_client, MEMOS_CONTEXT_TEMPLATE, add_images_context, get_images
+
+
+def _get_lib_client(lib: str):
+ if lib == "mem0":
+ from utils.client import Mem0Client # type: ignore
+
+ return Mem0Client(enable_graph=False)
+ if lib == "supermemory":
+ from utils.client import SupermemoryClient # type: ignore
+
+ return SupermemoryClient()
+ from utils.client import MemosApiClient # type: ignore
+
+ return MemosApiClient()
+
+
+def _load_existing_results(output_path):
+ completed_pairs: set[tuple[str, str]] = set()
+ completed_docs: set[str] = set()
+ existing: list[dict] = []
+ p = Path(output_path)
+ if p.exists():
+ try:
+ existing = json.loads(p.read_text(encoding="utf-8"))
+ for r in existing:
+ did = str(r.get("doc_id") or "").strip()
+ q = str(r.get("question") or "").strip()
+ # normalize whitespace for robust resume
+ did_norm = did
+ q_norm = " ".join(q.split())
+ if did:
+ completed_docs.add(did)
+ if did_norm and q_norm:
+ completed_pairs.add((did_norm, q_norm))
+ except Exception:
+ existing = []
+ return existing, completed_pairs, completed_docs
+
+
+def add_context(client, doc_id: str, md_path: Path, lib) -> None:
+ text = md_path.read_text(encoding="utf-8", errors="ignore")
+ print(f"[Add context] doc_id={doc_id} path={md_path}")
+
+ if lib == "memos":
+ messages = [
+ {
+ "role": "user",
+ "content": {
+ "type": "file",
+ "file": {"file_id": doc_id, "filename": doc_id, "file_data": text},
+ },
+ }
+ ]
+ try:
+ client.add(messages=messages, user_id=doc_id, conv_id=doc_id)
+ print(f"[Add-memos] user={doc_id} total={len(messages)}")
+ except Exception as e:
+ print(f"[Add-memos] failed: {e}")
+
+ elif lib == "mem0":
+ chunker = RecursiveCharacterTextSplitter.from_language(
+ language=Language.PYTHON, chunk_size=512, chunk_overlap=128
+ )
+ paragraphs = [p for p in chunker.split_text(text) if p.strip()]
+
+ messages = [{"role": "user", "content": p} for p in paragraphs]
+ ts = int(time.time())
+ try:
+ client.add(messages=messages, user_id=doc_id, timestamp=ts, batch_size=10)
+ print(f"[Add-mem0] user={doc_id} total={len(messages)}")
+ except Exception as e:
+ print(f"[Add-mem0] failed: {e}")
+ elif lib == "supermemory":
+ try:
+ doc_id = "stx_" + doc_id.replace(".pdf", "")
+ client.add(content=text, user_id=doc_id)
+ print(f"[Add-supermemory] user={doc_id}")
+ except Exception as e:
+ print(f"[Add-supermemory] failed: {e}")
+
+
+def memos_search(
+ client, get_images, user_id: str, query: str, top_k: int = 15
+) -> tuple[list[str], list[str]]:
+ results = client.search(query=query, user_id=user_id, top_k=top_k)
+ memories = results["text_mem"][0]["memories"]
+ mem_texts = [m["memory"] for m in memories]
+
+ # Collect possible image paths from memory texts (and any source content if present)
+ sources_texts: list[str] = []
+ for m in memories:
+ srcs = (m.get("metadata", {}) or {}).get("sources") or []
+ for s in srcs:
+ content = s.get("content") if isinstance(s, dict) else str(s)
+ if content:
+ sources_texts.append(content)
+
+ image_paths = get_images(sources_texts)
+ print(
+ f"[Search] user={user_id} top_k={top_k} memories={len(memories), len(mem_texts)} images={len(image_paths)}"
+ )
+ return mem_texts, image_paths
+
+
+def mem0_search(
+ client, get_images, user_id: str, query: str, top_k: int = 15
+) -> tuple[list[str], list[str]]:
+ res = client.search(query, user_id, top_k)
+ results = res.get("results", [])
+ mem_texts = [m.get("memory", "") for m in results if m.get("memory")]
+ image_paths = get_images(mem_texts)
+ print(
+ f"[Search] user={user_id} top_k={top_k} memories={len(results)} images={len(image_paths)}"
+ )
+ return mem_texts, image_paths
+
+
+def supermemory_search(
+ client, get_images, user_id: str, query: str, top_k: int = 15
+) -> tuple[list[str], list[str]]:
+ chunk_list = client.search(query, user_id, top_k)
+ image_paths = get_images(chunk_list)
+ print(
+ f"[Search] user={user_id} top_k={top_k} memories={len(chunk_list)} images={len(image_paths)}"
+ )
+ return chunk_list, image_paths
+
+
+def multimodal_answer(
+ add_images_context, oai_client, memories: list[str], question: str, image_paths: list[str]
+) -> tuple[str, int]:
+ from memos.mem_os.core import MOSCore # type: ignore
+
+ system_prompt = MOSCore._build_system_prompt(MOSCore.__new__(MOSCore), memories)
+
+ messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": question}]
+ add_images_context(messages, image_paths)
+ print("[Response model]:", os.getenv("CHAT_MODEL"))
+ resp = oai_client.chat.completions.create(
+ model=os.getenv("CHAT_MODEL"), messages=messages, temperature=0
+ )
+ return resp.choices[0].message.content or "", resp.usage.prompt_tokens
+
+
+def main():
+ parser = argparse.ArgumentParser(description="MMLongBench Evaluation Script")
+ parser.add_argument(
+ "--lib",
+ type=str,
+ default="supermemory",
+ help="Backend library to use (memos, mem0, supermemory)",
+ )
+ args = parser.parse_args()
+
+ # Hardcoded parameters
+ ppt_root = "ppt_test_result"
+ questions_file = "evaluation/data/mmlongbench/samples.json"
+ top_k = 15
+ workers = 4
+ lib = args.lib
+
+ print("[Memory util]:", lib)
+
+ start_time = time.time()
+ client, oai_client, memos_context_template, add_images_context, get_images = _get_clients()
+ if questions_file and Path(questions_file).exists():
+ data = json.loads(Path(questions_file).read_text(encoding="utf-8"))
+ print(f"[Load] samples={len(data)} file={questions_file}")
+
+ # Build allowed doc list from documents directory (align with eval_docs grouping by doc_id)
+ docs_dir = Path("evaluation/data/mmlongbench/documents")
+ allowed_docs: set[str] = set()
+ if docs_dir.exists():
+ for f in docs_dir.iterdir():
+ if f.is_file():
+ allowed_docs.add(f.name)
+ if allowed_docs:
+ print(f"[Filter] allowed_docs={len(allowed_docs)} from {docs_dir}")
+
+ # Determine doc_ids present in samples and apply allowed filter
+ doc_ids_in_samples = {str(s.get("doc_id") or "").strip() for s in data if s.get("doc_id")}
+ doc_list = [d for d in doc_ids_in_samples if (not allowed_docs or d in allowed_docs)]
+ doc_list.sort()
+
+ output_path = f"evaluation/data/mmlongbench/test/{lib}_add_fine_search_fine_results.json"
+ report_path = Path(
+ f"evaluation/data/mmlongbench/test/{lib}_add_fine_search_fine_report.txt"
+ )
+
+ # Resume state
+ existing, completed_pairs, completed_docs = _load_existing_results(output_path)
+ print(f"[Resume] loaded_results={len(existing)} completed_docs={len(completed_docs)}")
+ results: list[dict] = list(existing)
+ ingested_doc_ids: set[str] = set(completed_docs)
+
+ base_dir = Path(ppt_root)
+ all_md_files: list[Path] = []
+ if base_dir.exists():
+ all_md_files = list(base_dir.rglob("*.md"))
+
+ def _find_md_for_doc(doc_id_val: str) -> Path | None:
+ stem = Path(doc_id_val).stem.lower()
+ name = doc_id_val.lower()
+ for md in all_md_files:
+ pstr = str(md).lower()
+ if (stem and stem in pstr) or (name and name in pstr):
+ return md
+ return None
+
+ to_ingest: list[tuple[str, Path]] = []
+ for did in doc_list:
+ if did and did not in ingested_doc_ids:
+ mdp = _find_md_for_doc(did)
+ if mdp is not None:
+ to_ingest.append((did, mdp))
+ else:
+ print(f"[Skip] markdown not found for doc_id={did}")
+
+ # Phase 1: Ingestion
+ print("Phase 1: Ingesting context...")
+ if to_ingest:
+ print(f"[Ingest-Concurrent] tasks={len(to_ingest)} from {ppt_root}")
+
+ def _ingest_one(doc_id_local: str, md_path_local: Path, lib_local: str = lib) -> str:
+ user_id_local = doc_id_local
+ c_local = client if lib_local == "memos" else _get_lib_client(lib_local)
+ add_context(c_local, user_id_local, md_path_local, lib=lib_local)
+ return doc_id_local
+
+ with ThreadPoolExecutor(max_workers=workers) as executor:
+ futures = [executor.submit(_ingest_one, did, mdp) for did, mdp in to_ingest]
+ for f in as_completed(futures):
+ try:
+ done_id = f.result()
+ ingested_doc_ids.add(done_id)
+ except Exception as e:
+ print(f"[Add-Error] {e}")
+
+ # Phase 2: Evaluation
+ print("Phase 2: Evaluating...")
+ for doc_id in doc_list:
+ if not doc_id:
+ continue
+ print(f"\n===== [Doc] {doc_id} =====")
+ user_id = doc_id
+
+ doc_samples = [s for s in data if str(s.get("doc_id") or "").strip() == doc_id]
+
+ def _process_item(
+ item: dict,
+ doc_id_local: str = doc_id,
+ user_id_local: str = user_id,
+ lib_local: str = lib,
+ ) -> dict:
+ q = item["question"]
+ q_norm_local = " ".join(str(q).split())
+ if (doc_id_local, q_norm_local) in completed_pairs:
+ return {"skip": True}
+ if lib_local == "memos":
+ memories, images = memos_search(
+ client, get_images, user_id_local, q, top_k=top_k
+ )
+ elif lib_local == "mem0":
+ c_local = _get_lib_client(lib_local)
+ memories, images = mem0_search(
+ c_local, get_images, user_id_local, q, top_k=top_k
+ )
+ elif lib_local == "supermemory":
+ c_local = _get_lib_client(lib_local)
+ memories, images = supermemory_search(
+ c_local, get_images, user_id_local, q, top_k=top_k
+ )
+ else:
+ memories, images = [], []
+ resp, prompt_tokens = multimodal_answer(
+ add_images_context, oai_client, memories, q, images
+ )
+ with open(
+ Path("evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md"),
+ encoding="utf-8",
+ ) as f:
+ prompt_local = f.read()
+ extracted_res_local = extract_answer(q, resp, prompt_local)
+ try:
+ pred_ans_local = (
+ extracted_res_local.split("Answer format:")[0]
+ .split("Extracted answer:")[1]
+ .strip()
+ )
+ except Exception:
+ pred_ans_local = resp.strip()
+ score_local = eval_score(
+ item.get("answer"), pred_ans_local, item.get("answer_format", "Str")
+ )
+ sr = dict(item)
+ sr["response"] = resp
+ sr["extracted_res"] = extracted_res_local
+ sr["pred"] = pred_ans_local
+ sr["score"] = score_local
+ sr["q_norm"] = q_norm_local
+ sr["images"] = images
+ sr["prompt_tokens"] = prompt_tokens
+ sr["memory_used"] = memories
+ return sr
+
+ with ThreadPoolExecutor(max_workers=workers) as executor:
+ futures = [executor.submit(_process_item, it) for it in doc_samples]
+ for f in as_completed(futures):
+ try:
+ res = f.result()
+ if res.get("skip"):
+ continue
+ print("[images used]:", res.get("images") or [])
+ print(
+ f"[QA] doc_id={doc_id} images={len(res.get('images') or [])} score={res.get('score')}"
+ )
+ results.append(
+ {k: v for k, v in res.items() if k not in ("q_norm", "images")}
+ )
+ completed_pairs.add((doc_id, res.get("q_norm") or ""))
+
+ print("[Question]:", res.get("question"))
+ print("[Answer]:", res.get("pred"))
+ print("[Ground truth]:", res.get("answer"))
+ print("[Score]:", res.get("score"))
+ print("[Prompt Tokens]:", res.get("prompt_tokens"))
+ out_json = Path(output_path)
+ out_json.parent.mkdir(parents=True, exist_ok=True)
+ out_json.write_text(
+ json.dumps(results, ensure_ascii=False, indent=2), encoding="utf-8"
+ )
+ acc, f1 = eval_acc_and_f1(results)
+ total_target = sum(
+ 1 for s in data if str(s.get("doc_id") or "") in doc_list
+ )
+ total_tokens = sum(r.get("prompt_tokens", 0) for r in results)
+ avg_tokens = round(total_tokens / len(results), 2) if results else 0
+ print(
+ f"[Metric] acc={acc} f1={f1} avg_tokens={avg_tokens} processed={len(results)}/{total_target}"
+ )
+ except Exception as e:
+ print(f"[Error] processing item: {e}")
+
+ show_results(results, show_path=str(report_path))
+ print(f"[Report] saved to {report_path}")
+
+ end_time = time.time()
+ total_duration = end_time - start_time
+ with open(report_path, "a", encoding="utf-8") as f:
+ f.write(f"\nTotal Evaluation Time: {total_duration:.2f} seconds\n")
+ print(f"[Time] Total duration: {total_duration:.2f}s")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/evaluation/scripts/mmlongbench/mmlongbench_search.py b/evaluation/scripts/mmlongbench/mmlongbench_search.py
new file mode 100644
index 000000000..59c8e86ea
--- /dev/null
+++ b/evaluation/scripts/mmlongbench/mmlongbench_search.py
@@ -0,0 +1,383 @@
+#!/usr/bin/env python3
+
+import argparse
+import json
+import threading
+import time
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+from evaluation.scripts.utils.metrics import Metrics
+
+
+def load_samples(filepath: str) -> list[dict]:
+ """
+ Read sample list from JSON file
+ """
+ with open(filepath, encoding="utf-8") as f:
+ samples = json.load(f)
+ return samples
+
+
+def memos_search(
+ client, user_id: str, query: str, top_k: int, mode: str, readable_cube_ids: list[str]
+) -> tuple[list[str], list[str]]:
+ results = client.search(
+ query=query, user_id=user_id, readable_cube_ids=readable_cube_ids, top_k=top_k, mode=mode
+ )
+ sources = []
+ if "text_mem" in results["data"] and results["data"]["text_mem"]:
+ memories = results["data"]["text_mem"][0].get("memories", [])
+ sources.extend(
+ m["metadata"].get("sources", []) for m in memories if m["metadata"].get("sources", [])
+ )
+ return [m.get("memory", "") for m in memories], sources
+ return [], []
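+
+# memos_search assumes a MemOS API response shaped like
+#     {"data": {"text_mem": [{"memories": [{"memory": "...",
+#                                           "metadata": {"sources": [...]}}]}]}}
+# and returns the memory texts together with the per-memory source lists, which the
+# eval script later uses to resolve image references.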
+
+
+def mem0_search(client, user_id: str, query: str, top_k: int = 15) -> tuple[list[str], list[str]]:
+ res = client.search(query, user_id, top_k)
+ results = res.get("results", [])
+ mem_texts = [m.get("memory", "") for m in results if m.get("memory")]
+ return mem_texts, mem_texts
+
+
+def supermemory_search(
+ client, user_id: str, query: str, top_k: int = 15
+) -> tuple[list[str], list[str]]:
+ chunk_list = client.search(query, user_id, top_k)
+ print(chunk_list)
+
+ return chunk_list, chunk_list
+
+
+def _get_lib_client(lib: str):
+ if lib == "memos":
+ from evaluation.scripts.utils.client import MemosApiClient
+
+ return MemosApiClient()
+ if lib == "mem0":
+ from evaluation.scripts.utils.client import Mem0Client
+
+ return Mem0Client(enable_graph=False)
+ if lib == "supermemory":
+ from evaluation.scripts.utils.client import SupermemoryClient
+
+ return SupermemoryClient()
+
+
+def run_concurrent_search(
+ lib: str, samples: list[dict], user_prefix: str, concurrency: int, top_k: int, mode: str
+) -> dict:
+ """
+ Execute concurrent search operations
+
+ Args:
+ lib: Client name
+ samples: Sample list, each containing doc_id and question
+ user_prefix: User ID prefix
+ concurrency: Concurrency
+ top_k: Number of results to return
+ mode: Query mode ['fast', 'fine']
+
+ Returns:
+ Search results
+ """
+
+ client = _get_lib_client(lib)
+ metrics = Metrics()
+ total_samples = len(samples)
+ completed = 0
+ completed_lock = threading.Lock()
+
+    # Store all search results
+ all_results = []
+ results_lock = threading.Lock()
+
+ user_id = user_prefix
+
+ def search_single(sample: dict, index: int):
+ nonlocal completed
+
+ doc_id = sample.get("doc_id", "")
+ question = sample.get("question", "")
+
+ user_id = user_prefix + "_" + doc_id
+ readable_cube_ids = [user_id]
+ start_time = time.perf_counter()
+ try:
+ memories, sources = [], []
+ if lib == "memos":
+ memories, sources = memos_search(
+ client=client,
+ query=question,
+ user_id=user_id,
+ readable_cube_ids=readable_cube_ids,
+ top_k=top_k,
+ mode=mode,
+ )
+ elif lib == "mem0":
+ memories, sources = mem0_search(client, user_id, question, top_k=top_k)
+ elif lib == "supermemory":
+ memories, sources = supermemory_search(client, user_id, question, top_k=top_k)
+
+ duration = time.perf_counter() - start_time
+ metrics.record(duration, True)
+
+ result = {
+ "index": index,
+ "doc_id": doc_id,
+ "question": question,
+ "answer": sample.get("answer", ""),
+ "evidence_pages": sample.get("evidence_pages", ""),
+ "evidence_sources": sample.get("evidence_sources", ""),
+ "answer_format": sample.get("answer_format", ""),
+ "doc_type": sample.get("doc_type", ""),
+ "memories": memories,
+ "sources": sources,
+ "memory_count": len(memories),
+ "success": True,
+ "duration_ms": duration * 1000,
+ "mode": mode,
+ }
+
+ with results_lock:
+ all_results.append(result)
+
+ with completed_lock:
+ completed += 1
+ print(
+ f"[{completed}/{total_samples}] ✓ Success: {doc_id[:30]}... ({duration * 1000:.0f}ms, {len(memories)} memories)"
+ )
+
+ return True, result
+
+ except Exception as e:
+ duration = time.perf_counter() - start_time
+ error_msg = str(e)
+ metrics.record(duration, False, error_msg)
+
+ result = {
+ "index": index,
+ "doc_id": doc_id,
+ "question": question,
+ "answer": sample.get("answer", ""),
+ "evidence_pages": sample.get("evidence_pages", ""),
+ "evidence_sources": sample.get("evidence_sources", ""),
+ "answer_format": sample.get("answer_format", ""),
+ "doc_type": sample.get("doc_type", ""),
+ "memories": [],
+ "memory_count": 0,
+ "success": False,
+ "error": error_msg,
+ "duration_ms": duration * 1000,
+ }
+
+ with results_lock:
+ all_results.append(result)
+
+ with completed_lock:
+ completed += 1
+ print(
+ f"[{completed}/{total_samples}] ✗ Failed: {doc_id[:30]}... - {error_msg[:80]}"
+ )
+
+ return False, result
+
+ print(f"\nStarting concurrent search for {total_samples} questions...")
+ print(f"Concurrency: {concurrency}")
+ print(f"User ID: {user_id}")
+ print(f"Top-K: {top_k}")
+ print("-" * 60)
+
+ start_time = time.time()
+
+ with ThreadPoolExecutor(max_workers=concurrency) as executor:
+ futures = []
+ for i, sample in enumerate(samples):
+ future = executor.submit(search_single, sample, i)
+ futures.append(future)
+
+ # Wait for all tasks to complete
+ for future in as_completed(futures):
+ try:
+ future.result()
+ except Exception as e:
+ print(f"Task execution exception: {e}")
+
+ end_time = time.time()
+ total_duration = end_time - start_time
+
+ # Sort results by original index
+ all_results.sort(key=lambda x: x["index"])
+
+ # Print statistics
+ summary = metrics.summary()
+
+ print("\n" + "=" * 60)
+ print("Search finished! Statistics:")
+ print("=" * 60)
+ print(f"Total duration: {total_duration:.2f}s")
+ print(f"Success: {summary['counts']['success']} / Failed: {summary['counts']['failed']}")
+
+ if summary["stats"]:
+ stats = summary["stats"]
+ qps = stats["count"] / total_duration if total_duration > 0 else 0
+ print(f"QPS: {qps:.2f}")
+ print("Latency stats (ms):")
+ print(f" Mean: {stats['mean']:.2f}")
+ print(f" Median: {stats['median']:.2f}")
+ print(f" Min: {stats['min']:.2f}")
+ print(f" Max: {stats['max']:.2f}")
+ print(f" P95: {stats['p95']:.2f}")
+ print(f" P99: {stats['p99']:.2f}")
+
+ if summary["errors"]:
+ print("\nError statistics:")
+ for error, count in sorted(summary["errors"].items(), key=lambda x: x[1], reverse=True)[:5]:
+ print(f" [{count} times] {error[:100]}...")
+
+ return {"summary": summary, "total_duration": total_duration, "results": all_results}
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="MMLongbench-doc Product Search Concurrent Script",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+
+ parser.add_argument("--lib", "-b", required=True, help="Product name to evaluate")
+
+ parser.add_argument(
+ "--samples-file",
+ "-s",
+ default="evaluation/data/mmlongbench/samples.json",
+ help="Path to JSON file containing samples",
+ )
+
+ parser.add_argument(
+ "--api-url",
+ default="http://127.0.0.1:8001",
+ help="API service address (default: http://127.0.0.1:8001)",
+ )
+
+ parser.add_argument("--api-key", default="", help="API key (optional)")
+
+ parser.add_argument("--workers", "-c", type=int, default=5, help="Concurrency (default: 5)")
+
+ parser.add_argument(
+ "--timeout", type=float, default=120.0, help="Request timeout in seconds (default: 120)"
+ )
+
+ parser.add_argument(
+ "--top-k",
+ "-k",
+ type=int,
+ default=20,
+ help="Number of results to return per search (default: 20)",
+ )
+
+ parser.add_argument("--version-dir", "-v", default=None, help="Version directory name")
+
+ parser.add_argument(
+ "--limit",
+ "-l",
+ type=int,
+ default=None,
+ help="Limit number of samples to process (for testing, default all)",
+ )
+
+ parser.add_argument(
+ "--mode", "-m", type=str, default="fast", help="Search mode (default: fast)"
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = parse_args()
+
+ print("=" * 60)
+ print("MMLongbench-doc Product Search Concurrent Tool")
+ print("=" * 60)
+
+ # Read sample data
+ samples_path = "evaluation/data/mmlongbench/samples.json"
+ print(f"\nReading sample file: {samples_path}")
+ try:
+ samples = load_samples(samples_path)
+ print(f"Total {len(samples)} samples read")
+
+ # Limit number of samples
+ if args.limit and args.limit > 0:
+ samples = samples[: args.limit]
+ print(f"Limiting to first {len(samples)} samples")
+
+ if len(samples) == 0:
+ print("Error: Sample list is empty!")
+ return
+
+ # Show first few samples
+ print("First 3 samples:")
+ for sample in samples[:3]:
+ doc_id = sample.get("doc_id", "N/A")
+ question = sample.get("question", "N/A")[:50]
+ print(f" - {doc_id}: {question}...")
+ if len(samples) > 3:
+ print(f" ... and {len(samples) - 3} more samples")
+
+ except FileNotFoundError:
+ print(f"Error: File not found {args.samples_file}")
+ return
+ except json.JSONDecodeError as e:
+ print(f"Error: JSON parse failed - {e}")
+ return
+ except Exception as e:
+ print(f"Error: Failed to read file - {e}")
+ return
+
+ # Execute concurrent search
+ result = run_concurrent_search(
+ lib=args.lib,
+ samples=samples,
+ user_prefix=args.version_dir,
+ concurrency=args.workers,
+ top_k=args.top_k,
+ mode=args.mode,
+ )
+
+ # Determine output file path
+ import os
+
+ output_dir = os.path.join("evaluation/data/mmlongbench", args.version_dir)
+ os.makedirs(output_dir, exist_ok=True)
+ output_filename = f"{args.lib}_search_results.json"
+ output_path = os.path.join(output_dir, output_filename)
+
+ # Save results
+ output_data = {
+ "summary": result["summary"],
+ "total_duration": result["total_duration"],
+ "config": {
+ "samples_file": args.samples_file,
+ "api_url": args.api_url,
+ "workers": args.workers,
+ "top_k": args.top_k,
+ },
+ "results": result["results"],
+ }
+
+ with open(output_path, "w", encoding="utf-8") as f:
+ json.dump(output_data, f, ensure_ascii=False, indent=2)
+
+ print(f"\nResults saved to: {output_path}")
+
+ # Calculate valid results
+ success_results = [r for r in result["results"] if r["success"]]
+ total_memories = sum(r["memory_count"] for r in success_results)
+ avg_memories = total_memories / len(success_results) if success_results else 0
+ print(f"Average {avg_memories:.1f} memories returned per question")
+
+
+if __name__ == "__main__":
+ main()
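
For reviewers, a minimal sketch (not from a real run) of how the artifact written by this script can be inspected. The path below is a placeholder that assumes `--lib memos --version-dir test_1230`; only keys actually written by `output_data` and the per-question result dicts are used.

```python
# Hypothetical inspection of <version_dir>/<lib>_search_results.json written by this script.
import json

# Placeholder path: assumes --lib memos and --version-dir test_1230 were used.
with open("evaluation/data/mmlongbench/test_1230/memos_search_results.json", encoding="utf-8") as f:
    data = json.load(f)

print(data["summary"]["counts"])   # e.g. {"success": ..., "failed": ...}
print(data["config"]["top_k"])     # top-k recorded for this run
for item in data["results"][:3]:   # per-question entries, sorted by original index
    print(item["index"], item["success"], item["memory_count"])
```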
diff --git a/evaluation/scripts/run_hotpot_eval.sh b/evaluation/scripts/run_hotpot_eval.sh
new file mode 100755
index 000000000..3ee5177c6
--- /dev/null
+++ b/evaluation/scripts/run_hotpot_eval.sh
@@ -0,0 +1,48 @@
+#!/bin/bash
+set -e
+
+ROOT_DIR=$(cd "$(dirname "$0")/../.." && pwd)
+cd "$ROOT_DIR"
+export PYTHONPATH="$ROOT_DIR"
+
+# Common parameters
+LIB="mem0"
+WORKERS=20
+TOPK=7
+ADD_MODE="fine"
+SEARCH_MODE="fine"
+VERSION_DIR="test_0101_07"
+ASYNC_MODE="sync"
+CHAT_MODEL="gpt-4o-mini"
+LIMIT=100
+
+# Add / Ingestion
+#echo "Running hotpot_ingestion.py..."
+#python -m evaluation.scripts.hotpot.hotpot_ingestion \
+# --lib "$LIB" \
+# --workers "$WORKERS" \
+# --version-dir "$VERSION_DIR" \
+# --mode "$ADD_MODE" \
+# --async-mode "$ASYNC_MODE" \
+# --limit "$LIMIT"
+
+# Search
+#echo "Running hotpot_search.py..."
+#python -m evaluation.scripts.hotpot.hotpot_search \
+# --lib "$LIB" \
+# --workers "$WORKERS" \
+# --version-dir "$VERSION_DIR" \
+# --top-k "$TOPK" \
+# --search-mode "$SEARCH_MODE" \
+# --limit "$LIMIT"
+
+# Eval
+echo "Running hotpot_eval.py..."
+python -m evaluation.scripts.hotpot.hotpot_eval \
+ --lib "$LIB" \
+ --version-dir "$VERSION_DIR" \
+ --workers "$WORKERS" \
+ --search-mode "$SEARCH_MODE" \
+ --chat-model "$CHAT_MODEL"
+
+echo "All scripts completed successfully!"
diff --git a/evaluation/scripts/run_longbench_v2_eval.sh b/evaluation/scripts/run_longbench_v2_eval.sh
new file mode 100755
index 000000000..9af4572fc
--- /dev/null
+++ b/evaluation/scripts/run_longbench_v2_eval.sh
@@ -0,0 +1,45 @@
+#!/bin/bash
+set -e
+
+ROOT_DIR=$(cd "$(dirname "$0")/../.." && pwd)
+cd "$ROOT_DIR"
+export PYTHONPATH="$ROOT_DIR"
+
+# Common parameters
+LIB="supermemory"
+WORKERS=5
+TOPK=20
+ADD_MODE="fine"
+SEARCH_MODE="fine"
+VERSION_DIR="test_1231"
+ASYNC_MODE="sync"
+CHAT_MODEL="gpt-4o-mini"
+
+# Add / Ingestion
+echo "Running longbench_v2_ingestion.py..."
+python -m evaluation.scripts.longbench_v2.longbench_v2_ingestion \
+ --lib "$LIB" \
+ --workers "$WORKERS" \
+ --version-dir "$VERSION_DIR" \
+ --mode "$ADD_MODE" \
+ --async-mode "$ASYNC_MODE"
+
+# Search
+echo "Running longbench_v2_search.py..."
+python -m evaluation.scripts.longbench_v2.longbench_v2_search \
+ --lib "$LIB" \
+ --workers "$WORKERS" \
+ --version-dir "$VERSION_DIR" \
+ --top-k "$TOPK" \
+ --mode "$SEARCH_MODE" \
+ --limit 30
+
+# Eval
+echo "Running longbench_v2_eval.py..."
+python -m evaluation.scripts.longbench_v2.longbench_v2_eval \
+ --lib "$LIB" \
+ --version-dir "$VERSION_DIR" \
+ --workers "$WORKERS" \
+ --chat-model "$CHAT_MODEL"
+
+echo "All scripts completed successfully!"
diff --git a/evaluation/scripts/run_mmlongbench_eval.sh b/evaluation/scripts/run_mmlongbench_eval.sh
new file mode 100755
index 000000000..ff357f404
--- /dev/null
+++ b/evaluation/scripts/run_mmlongbench_eval.sh
@@ -0,0 +1,44 @@
+#!/bin/bash
+set -e
+
+ROOT_DIR=$(cd "$(dirname "$0")/../.." && pwd)
+cd "$ROOT_DIR"
+export PYTHONPATH="$ROOT_DIR"
+
+# Common parameters
+LIB="supermemory"
+WORKERS=10
+TOPK=20
+ADD_MODE="fine"
+SEARCH_MODE="fine"
+VERSION_DIR="test_1230"
+ASYNC_MODE="sync"
+CHAT_MODEL="gpt-4o-mini"
+
+## Add / Ingestion
+#echo "Running mmlongbench_ingestion.py..."
+#python -m evaluation.scripts.mmlongbench.mmlongbench_ingestion \
+# --lib "$LIB" \
+# --workers "$WORKERS" \
+# --version-dir "$VERSION_DIR" \
+# --mode "$ADD_MODE" \
+# --async-mode "$ASYNC_MODE"
+#
+## Search
+#echo "Running mmlongbench_search.py..."
+#python -m evaluation.scripts.mmlongbench.mmlongbench_search \
+# --lib "$LIB" \
+# --workers "$WORKERS" \
+# --version-dir "$VERSION_DIR" \
+# --top-k "$TOPK" \
+# --mode "$SEARCH_MODE"
+
+# Eval
+echo "Running mmlongbench_eval.py..."
+python -m evaluation.scripts.mmlongbench.mmlongbench_eval \
+ --lib "$LIB" \
+ --version-dir "$VERSION_DIR" \
+ --workers "$WORKERS" \
+ --chat-model "$CHAT_MODEL"
+
+echo "All scripts completed successfully!"
diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py
index 157c3f8ea..a7168f343 100644
--- a/evaluation/scripts/utils/client.py
+++ b/evaluation/scripts/utils/client.py
@@ -1,5 +1,6 @@
import json
import os
+import re
import sys
import time
import uuid
@@ -56,30 +57,22 @@ def __init__(self, enable_graph=False):
self.enable_graph = enable_graph
def add(self, messages, user_id, timestamp, batch_size=2):
- max_retries = 5
for i in range(0, len(messages), batch_size):
batch_messages = messages[i : i + batch_size]
- for attempt in range(max_retries):
- try:
- if self.enable_graph:
- self.client.add(
- messages=batch_messages,
- timestamp=timestamp,
- user_id=user_id,
- enable_graph=True,
- )
- else:
- self.client.add(
- messages=batch_messages,
- timestamp=timestamp,
- user_id=user_id,
- )
- break
- except Exception as e:
- if attempt < max_retries - 1:
- time.sleep(2**attempt)
- else:
- raise e
+ if self.enable_graph:
+ self.client.add(
+ messages=batch_messages,
+ timestamp=timestamp,
+ user_id=user_id,
+ enable_graph=True,
+ )
+ else:
+ self.client.add(
+ messages=batch_messages,
+ timestamp=timestamp,
+ user_id=user_id,
+ infer=False,
+ )
def search(self, query, user_id, top_k):
res = self.client.search(
@@ -143,56 +136,95 @@ def string_to_uuid(self, s: str, salt="memobase_client"):
class MemosApiClient:
- def __init__(self):
- self.memos_url = os.getenv("MEMOS_URL")
- self.headers = {"Content-Type": "application/json", "Authorization": os.getenv("MEMOS_KEY")}
+ """Product Add API 封装"""
+
+ def __init__(self, timeout: float = 600.0):
+ self.base_url = os.getenv("MEMOS_URL")
+ self.headers = {"Content-Type": "application/json"}
+ self.timeout = timeout
+
+ def add(
+ self,
+ messages,
+ user_id,
+ writable_cube_ids: list[str],
+ source_type: str,
+ mode: str,
+ async_mode: str,
+ ):
+ """
+ Call the /product/add endpoint.
+
+ Args:
+ messages: memory messages to add
+ user_id: user ID
+ writable_cube_ids: list of writable cube IDs
+ source_type: source type of the memories
+ mode: add mode (fine/coarse)
+ async_mode: execution mode (sync/async)
+ """
+ url = f"{self.base_url}/product/add"
+
+ payload = {
+ "user_id": user_id,
+ "writable_cube_ids": writable_cube_ids,
+ "messages": messages,
+ "info": {"source_type": source_type},
+ "mode": mode,
+ "async_mode": async_mode,
+ }
+
+ response = requests.post(
+ url,
+ data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
+ headers=self.headers,
+ timeout=self.timeout,
+ )
+
+ if response.status_code != 200:
+ raise RuntimeError(f"HTTP {response.status_code}: {response.text}")
- def add(self, messages, user_id, conv_id, batch_size: int = 9999):
+ body = response.json()
+ if body.get("code") is not None and body.get("code") != 200:
+ raise RuntimeError(f"BUSINESS ERROR {body.get('code')}: {response.text}")
+
+ return body
+
+ def search(self, query, user_id, readable_cube_ids: list[str] | None, top_k: int, mode: str):
"""
- messages = [{"role": "assistant", "content": data, "chat_time": date_str}]
+ Call the /product/search endpoint.
+
+ Args:
+ query: search query
+ user_id: user ID
+ readable_cube_ids: list of readable cube IDs; defaults to [user_id] when None
+ top_k: number of results to return
"""
- url = f"{self.memos_url}/product/add"
- added_memories = []
- for i in range(0, len(messages), batch_size):
- batch_messages = messages[i : i + batch_size]
- payload = json.dumps(
- {
- "messages": batch_messages,
- "user_id": user_id,
- "mem_cube_id": user_id,
- "conversation_id": conv_id,
- }
- )
- response = requests.request("POST", url, data=payload, headers=self.headers)
- assert response.status_code == 200, response.text
- assert json.loads(response.text)["message"] == "Memory added successfully", (
- response.text
- )
- added_memories += json.loads(response.text)["data"]
- return added_memories
- def search(self, query, user_id, top_k):
- """Search memories."""
- url = f"{self.memos_url}/product/search"
- payload = json.dumps(
- {
- "query": query,
- "user_id": user_id,
- "mem_cube_id": user_id,
- "conversation_id": "",
- "top_k": top_k,
- "mode": os.getenv("SEARCH_MODE", "fast"),
- "include_preference": True,
- "pref_top_k": 6,
- },
- ensure_ascii=False,
- )
- response = requests.request("POST", url, data=payload, headers=self.headers)
- assert response.status_code == 200, response.text
- assert json.loads(response.text)["message"] == "Search completed successfully", (
- response.text
+ url = f"{self.base_url}/product/search"
+
+ if readable_cube_ids is None:
+ readable_cube_ids = [user_id]
+
+ payload = {
+ "query": query,
+ "user_id": user_id,
+ "readable_cube_ids": readable_cube_ids,
+ "top_k": top_k,
+ "mode": mode,
+ }
+
+ response = requests.post(
+ url,
+ data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
+ headers=self.headers,
+ timeout=self.timeout,
)
- return json.loads(response.text)["data"]
+
+ if response.status_code != 200:
+ raise RuntimeError(f"HTTP {response.status_code}: {response.text}")
+
+ return response.json()
class MemosApiOnlineClient:
@@ -276,44 +308,65 @@ def search(self, query, user_id, top_k):
class SupermemoryClient:
def __init__(self):
- from supermemory import Supermemory
-
- self.client = Supermemory(api_key=os.getenv("SUPERMEMORY_API_KEY"))
-
- def add(self, messages, user_id):
- content = "\n".join(
- [f"{msg['chat_time']} {msg['role']}: {msg['content']}" for msg in messages]
- )
- max_retries = 5
- for attempt in range(max_retries):
- try:
- self.client.memories.add(content=content, container_tag=user_id)
- break
- except Exception as e:
- if attempt < max_retries - 1:
- time.sleep(2**attempt)
- else:
- raise e
-
- def search(self, query, user_id, top_k):
- max_retries = 10
- for attempt in range(max_retries):
- try:
- results = self.client.search.memories(
- q=query,
- container_tag=user_id,
- threshold=0,
- rerank=True,
- rewrite_query=True,
- limit=top_k,
- )
- context = "\n\n".join([r.memory for r in results.results])
- return context
- except Exception as e:
- if attempt < max_retries - 1:
- time.sleep(2**attempt)
- else:
- raise e
+ self.api_key = os.getenv("SUPERMEMORY_API_KEY")
+ if not self.api_key:
+ raise ValueError(
+ "SUPERMEMORY_API_KEY environment variable is not set. Please set it in your .env file or environment."
+ )
+ self.add_url = "https://api.supermemory.ai/v3/documents"
+ self.search_url = "https://api.supermemory.ai/v3/search"
+
+ def _sanitize_tag(self, s: str) -> str:
+ t = str(s).strip()
+ t = os.path.splitext(t)[0]
+ t = t.replace(" ", "_")
+ t = re.sub(r"[^A-Za-z0-9_-]", "_", t)
+ t = re.sub(r"[_-]+", "_", t)
+ t = t.strip("_")
+ t = t.lower()
+ if not re.match(r"^[a-z0-9]", t or ""):
+ t = f"tag_{t}" if t else "tag_default"
+ return t
+
+ def add(self, content: str, user_id: str):
+ payload = {
+ "content": content,
+ "raw": content,
+ "containerTag": self._sanitize_tag(user_id),
+ }
+
+ headers = {
+ "Authorization": f"Bearer {self.api_key}",
+ "Content-Type": "application/json",
+ }
+
+ resp = requests.post(self.add_url, json=payload, headers=headers)
+ resp.raise_for_status()
+ return resp.json()
+
+ def search(self, query: str, user_id: str, top_k: int):
+ payload = {
+ "q": query,
+ "limit": top_k,
+ "containerTags": [self._sanitize_tag(user_id)],
+ "rerank": True,
+ }
+
+ headers = {
+ "Authorization": f"Bearer {self.api_key}",
+ "Content-Type": "application/json",
+ }
+ resp = requests.post(self.search_url, json=payload, headers=headers)
+ resp.raise_for_status()
+ data = resp.json()
+
+ chunk_list = []
+ res = [entry.get("chunks") for entry in data.get("results", [])]
+ for chunks in res:
+ for chunk in chunks:
+ chunk_list.append(chunk["content"])
+
+ return chunk_list
class MemuClient:
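
For context, a minimal sketch of how the reworked `MemosApiClient` is expected to be driven (not part of the patch). It assumes `MEMOS_URL` points at a running MemOS product API and that the repository root is on `PYTHONPATH`, as the run_*.sh scripts arrange; the user and message strings are made up.

```python
# Hypothetical driver for the new /product/add and /product/search wrappers.
from evaluation.scripts.utils.client import MemosApiClient

client = MemosApiClient(timeout=600.0)

# Ingest one message into the user's own cube, synchronously, with fine-grained extraction.
client.add(
    messages=[{"role": "user", "content": "MemOS stores document chunks as tree-structured memories."}],
    user_id="demo_user",
    writable_cube_ids=["demo_user"],
    source_type="doc",
    mode="fine",
    async_mode="sync",
)

# Retrieve the top 5 memories for a query from the same cube.
res = client.search(
    query="How are document chunks stored?",
    user_id="demo_user",
    readable_cube_ids=["demo_user"],
    top_k=5,
    mode="fast",
)
print(res)  # raw JSON body returned by the service
```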
diff --git a/evaluation/scripts/utils/eval_score.py b/evaluation/scripts/utils/eval_score.py
new file mode 100644
index 000000000..02ef6eb53
--- /dev/null
+++ b/evaluation/scripts/utils/eval_score.py
@@ -0,0 +1,246 @@
+import re
+
+from collections import defaultdict
+from math import isclose
+
+
+def levenshtein_distance(s1, s2):
+ if len(s1) > len(s2):
+ s1, s2 = s2, s1
+
+ distances = range(len(s1) + 1)
+ for i2, c2 in enumerate(s2):
+ distances_ = [i2 + 1]
+ for i1, c1 in enumerate(s1):
+ if c1 == c2:
+ distances_.append(distances[i1])
+ else:
+ distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
+ distances = distances_
+ return distances[-1]
+
+
+def anls_compute(groundtruth, prediction, threshold=0.5):
+ dist = levenshtein_distance(groundtruth, prediction)
+ length = max(len(groundtruth.upper()), len(prediction.upper()))
+ value = 0.0 if length == 0 else float(dist) / float(length)
+ anls = 1.0 - value
+ if anls <= threshold:
+ anls = 0.0
+ return anls
+
+
+def is_float_equal(
+ reference, prediction, include_percentage: bool = False, is_close: bool = False
+) -> bool:
+ def get_precision(gt_ans: float) -> int:
+ precision = 3
+ if "." in str(gt_ans):
+ precision = len(str(gt_ans).split(".")[-1])
+ return precision
+
+ reference = float(str(reference).strip().rstrip("%").strip())
+ try:
+ prediction = float(str(prediction).strip().rstrip("%").strip())
+ except Exception:
+ return False
+
+ gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference]
+ for item in gt_result:
+ try:
+ if is_close and isclose(item, prediction, rel_tol=0.01):
+ return True
+ precision = max(min(get_precision(prediction), get_precision(item)), 2)
+ if round(prediction, precision) == round(item, precision):
+ return True
+ except Exception:
+ continue
+ return False
+
+
+def get_clean_string(s):
+ s = str(s).lower().strip()
+
+ for suffix in ["mile", "miles", "million"]:
+ if s.endswith(suffix):
+ s = s[: -len(suffix)].strip()
+
+ s = re.sub(r"\s*\([^)]*\)", "", s).strip()
+ s = re.sub(r"^['\"]|['\"]$", "", s).strip()
+ s = s.lstrip("$").rstrip("%").strip()
+
+ return s
+
+
+def is_exact_match(s):
+ flag = False
+ # Website
+ if "https://" in s:
+ flag = True
+ # code file
+ if s.endswith((".py", ".ipynb")) or s.startswith("page"):
+ flag = True
+ # telephone number
+ if re.fullmatch(r"\b\d+(-\d+|\s\d+)?\b", s):
+ flag = True
+ # time
+ if "a.m." in s or "p.m." in s:
+ flag = True
+ # YYYY-MM-DD
+ if re.fullmatch(r"\b\d{4}[-\s]\d{2}[-\s]\d{2}\b", s):
+ flag = True
+ # YYYY-MM
+ if re.fullmatch(r"\b\d{4}[-\s]\d{2}\b", s):
+ flag = True
+ # Email address
+ if re.fullmatch(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", s):
+ flag = True
+ return flag
+
+
+def isfloat(num):
+ try:
+ float(num)
+ return True
+ except ValueError:
+ return False
+
+
+def eval_score(gt, pred, answer_type):
+ if answer_type == "Int":
+ try:
+ gt, pred = int(gt), int(float(pred))
+ except Exception:
+ pred = ""
+ score = gt == pred
+ elif answer_type == "Float":
+ try:
+ gt = float(get_clean_string(str(gt)))
+ pred = float(get_clean_string(str(pred)))
+ except Exception:
+ pred = ""
+ score = is_float_equal(gt, pred, include_percentage=True, is_close=True)
+ elif answer_type in ["Str", "None"]:
+ gt = get_clean_string(gt)
+ pred = get_clean_string(pred)
+ score = gt == pred if is_exact_match(gt) else anls_compute(gt, pred)
+ else:
+ if isinstance(gt, str) and gt.startswith("["):
+ gt = eval(gt)
+ if not isinstance(gt, list):
+ gt = [gt]
+ if isinstance(pred, str) and pred.startswith("["):
+ pred = eval(pred)
+ if not isinstance(pred, list):
+ pred = [pred]
+ print(len(gt), len(pred))
+ if len(gt) != len(pred):
+ score = 0.0
+ else:
+ gt = sorted([get_clean_string(a) for a in gt])
+ pred = sorted([get_clean_string(a) for a in pred])
+ print(gt, pred)
+ if isfloat(gt[0]) or is_exact_match(gt[0]):
+ score = "-".join(gt) == "-".join(pred)
+ else:
+ score = min(
+ [anls_compute(gt_v, pred_v) for gt_v, pred_v in zip(gt, pred, strict=False)]
+ )
+
+ return float(score)
+
+
+def eval_acc_and_f1(samples):
+ evaluated_samples = [sample for sample in samples if "score" in sample]
+ if not evaluated_samples:
+ return 0.0, 0.0
+
+ acc = sum([sample["score"] for sample in evaluated_samples]) / len(evaluated_samples)
+ try:
+ recall = sum(
+ [
+ sample["score"]
+ for sample in evaluated_samples
+ if sample["answer"] != "Not answerable"
+ ]
+ ) / len([sample for sample in evaluated_samples if sample["answer"] != "Not answerable"])
+ precision = sum(
+ [
+ sample["score"]
+ for sample in evaluated_samples
+ if sample["answer"] != "Not answerable"
+ ]
+ ) / len([sample for sample in evaluated_samples if sample["pred"] != "Not answerable"])
+ f1 = 2 * recall * precision / (recall + precision) if (recall + precision) > 0.0 else 0.0
+ except Exception:
+ f1 = 0.0
+
+ return acc, f1
+
+
+def show_results(samples, show_path=None):
+ for sample in samples:
+ sample["evidence_pages"] = eval(sample["evidence_pages"])
+ sample["evidence_sources"] = eval(sample["evidence_sources"])
+
+ with open(show_path, "w") as f:
+ acc, f1 = eval_acc_and_f1(samples)
+ f.write(f"Overall Acc: {acc} | Question Number: {len(samples)}\n")
+ f.write(f"Overall F1-score: {f1} | Question Number: {len(samples)}\n")
+ f.write("-----------------------\n")
+
+ acc_single_page, _ = eval_acc_and_f1(
+ [sample for sample in samples if len(sample["evidence_pages"]) == 1]
+ )
+ acc_multi_page, _ = eval_acc_and_f1(
+ [
+ sample
+ for sample in samples
+ if len(sample["evidence_pages"]) != 1 and sample["answer"] != "Not answerable"
+ ]
+ )
+ acc_neg, _ = eval_acc_and_f1(
+ [sample for sample in samples if sample["answer"] == "Not answerable"]
+ )
+
+ f.write(
+ "Single-page | Accuracy: {} | Question Number: {}\n".format(
+ acc_single_page,
+ len([sample for sample in samples if len(sample["evidence_pages"]) == 1]),
+ )
+ )
+ f.write(
+ "Cross-page | Accuracy: {} | Question Number: {}\n".format(
+ acc_multi_page,
+ len(
+ [
+ sample
+ for sample in samples
+ if len(sample["evidence_pages"]) != 1
+ and sample["answer"] != "Not answerable"
+ ]
+ ),
+ )
+ )
+ f.write(
+ "Unanswerable | Accuracy: {} | Question Number: {}\n".format(
+ acc_neg, len([sample for sample in samples if sample["answer"] == "Not answerable"])
+ )
+ )
+ f.write("-----------------------\n")
+
+ source_sample_dict, document_type_dict = defaultdict(list), defaultdict(list)
+ for sample in samples:
+ for answer_source in sample["evidence_sources"]:
+ source_sample_dict[answer_source].append(sample)
+ document_type_dict[sample["doc_type"]].append(sample)
+ for type, sub_samples in source_sample_dict.items():
+ f.write(
+ f"Evidence Sources: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n"
+ )
+
+ f.write("-----------------------\n")
+ for type, sub_samples in document_type_dict.items():
+ f.write(
+ f"Document Type: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n"
+ )
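
A few hand-checked cases for the scoring helpers above (illustrative only; assumes the repository root is on `PYTHONPATH` as the run_*.sh scripts set up):

```python
# Hand-checked sanity cases for the MMLongBench-style scoring helpers.
from evaluation.scripts.utils.eval_score import anls_compute, eval_score

print(eval_score("10", "10", "Int"))                   # 1.0 -- exact integer match
print(eval_score("3.14", "3.1415", "Float"))           # 1.0 -- within the 1% relative tolerance
print(eval_score("['a', 'b']", "['b', 'a']", "List"))  # 1.0 -- list comparison is order-insensitive
print(anls_compute("the limes home", "the limes residential home"))  # ~0.54, kept since it clears the 0.5 threshold
```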
diff --git a/evaluation/scripts/utils/extract_answer.py b/evaluation/scripts/utils/extract_answer.py
new file mode 100644
index 000000000..e527d4b97
--- /dev/null
+++ b/evaluation/scripts/utils/extract_answer.py
@@ -0,0 +1,58 @@
+import os
+
+from pathlib import Path
+
+import openai
+
+from dotenv import load_dotenv
+
+
+load_dotenv()
+
+client = openai.Client(
+ api_key=os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
+ base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
+)
+
+PROMPT_PATH = Path("evaluation/scripts/utils/prompt_for_answer_extraction.md")
+with open(PROMPT_PATH, encoding="utf-8") as f:
+ EXTRACTION_PROMPT = f.read()
+
+
+def extract_answer(question: str, output: str, model_name: str = "gpt-4o-mini") -> str:
+ resp = client.chat.completions.create(
+ model=model_name,
+ messages=[
+ {"role": "user", "content": EXTRACTION_PROMPT},
+ {"role": "assistant", "content": f"\n\nQuestion:{question}\nAnalysis:{output}\n"},
+ ],
+ temperature=0.0,
+ max_tokens=256,
+ top_p=1,
+ frequency_penalty=0,
+ presence_penalty=0,
+ )
+ content = resp.choices[0].message.content or ""
+ return content
+
+
+def parse_extracted_answer(extracted_res: str, fallback_output: str) -> str:
+ try:
+ head = extracted_res.split("Answer format:")[0]
+ ans = head.split("Extracted answer:")[1].strip()
+ if ans:
+ return ans
+ except Exception:
+ pass
+ text = (fallback_output or "").strip()
+ low = text.lower()
+ if " yes" in low or low.startswith("yes"):
+ return "yes"
+ if " no" in low or low.startswith("no"):
+ return "no"
+ for sep in ["\n", ". ", ".", "?", "!"]:
+ if sep in text:
+ cand = text.split(sep)[0].strip()
+ if cand:
+ return cand
+ return text
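
A minimal sketch of the parsing fallback above (no LLM call is needed for `parse_extracted_answer`; importing the module does require running from the repository root so the relative prompt path resolves):

```python
# Parsing a well-formed extraction, then falling back on a malformed one.
from evaluation.scripts.utils.extract_answer import parse_extracted_answer

well_formed = "Extracted answer: 10\nAnswer format: Integer"
print(parse_extracted_answer(well_formed, fallback_output="The report lists 10 breaches."))  # "10"

# When the extraction response cannot be parsed, the helper falls back to the raw model output.
print(parse_extracted_answer("malformed response", fallback_output="Yes, it is supported."))  # "yes"
```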
diff --git a/evaluation/scripts/utils/metrics.py b/evaluation/scripts/utils/metrics.py
new file mode 100644
index 000000000..135a60cec
--- /dev/null
+++ b/evaluation/scripts/utils/metrics.py
@@ -0,0 +1,56 @@
+import threading
+
+
+class Metrics:
+ def __init__(self):
+ self.times_ms: list[float] = []
+ self.success_count = 0
+ self.fail_count = 0
+ self.errors = {}
+ self.lock = threading.Lock()
+
+ def record(self, duration_s: float, success: bool, error_msg: str | None = None):
+ ms = duration_s * 1000.0
+ with self.lock:
+ if success:
+ self.times_ms.append(ms)
+ self.success_count += 1
+ else:
+ self.fail_count += 1
+ if error_msg:
+ short_err = error_msg[:200] if len(error_msg) > 200 else error_msg
+ self.errors[short_err] = self.errors.get(short_err, 0) + 1
+
+ def summary(self) -> dict:
+ with self.lock:
+ if not self.times_ms:
+ return {
+ "stats": {},
+ "counts": {"success": self.success_count, "failed": self.fail_count},
+ "errors": dict(self.errors),
+ }
+ sorted_times = sorted(self.times_ms)
+ n = len(sorted_times)
+
+ def percentile(p: int):
+ if n == 1:
+ return sorted_times[0]
+ k = max(0, min(n - 1, round((p / 100) * (n - 1))))
+ return sorted_times[k]
+
+ mean = sum(sorted_times) / n
+ variance = sum((x - mean) ** 2 for x in sorted_times) / (n - 1) if n > 1 else 0.0
+ return {
+ "stats": {
+ "count": n,
+ "mean": mean,
+ "median": percentile(50),
+ "min": sorted_times[0],
+ "max": sorted_times[-1],
+ "p95": percentile(95),
+ "p99": percentile(99),
+ "std": variance**0.5,
+ },
+ "counts": {"success": self.success_count, "failed": self.fail_count},
+ "errors": dict(self.errors),
+ }
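
A minimal, self-contained sketch of how `Metrics` is meant to be used by the concurrent scripts (timings here are simulated, not from a benchmark run):

```python
# Simulated use of the thread-safe Metrics helper: record latencies, then summarize.
import random
import time

from evaluation.scripts.utils.metrics import Metrics

metrics = Metrics()
for _ in range(20):
    start = time.time()
    time.sleep(random.uniform(0.001, 0.01))  # stand-in for a real add/search call
    metrics.record(time.time() - start, success=True)
metrics.record(0.5, success=False, error_msg="HTTP 500: upstream timeout")

summary = metrics.summary()
print(summary["counts"])         # {'success': 20, 'failed': 1}
print(summary["stats"]["p95"])   # 95th-percentile latency in milliseconds
print(summary["errors"])         # {'HTTP 500: upstream timeout': 1}
```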
diff --git a/evaluation/scripts/utils/prompt_for_answer_extraction.md b/evaluation/scripts/utils/prompt_for_answer_extraction.md
new file mode 100644
index 000000000..a309c0935
--- /dev/null
+++ b/evaluation/scripts/utils/prompt_for_answer_extraction.md
@@ -0,0 +1,35 @@
+Given the question and analysis, you are tasked to extract answers with required formats from the free-form analysis.
+- Your extracted answers should be one of the following formats: (1) Integer, (2) Float, (3) String and (4) List. If you find the analysis the question can not be answered from the given documents, type "Not answerable". Exception: If the analysis only tells you that it can not read/understand the images or documents, type "Fail to answer".
+- Please make your response as concise as possible. Also note that your response should be formatted as below:
+```
+Extracted answer: [answer]
+Answer format: [answer format]
+```
+
+Please read the following example, then extract the answer from the model response and type it at the end of the prompt.
+
+---
+Question: List the primary questions asked about the services in this report.
+Analysis: The primary questions asked about the services in the report for The Limes Residential Home are:\n\n1. Is the service safe?\n2. Is the service effective?\n3. Is the service caring?\n4. Is the service responsive?\n5. Is the service well-led?
+Extracted answer: ['Is the service safe?', 'Is the service effective?', 'Is the service caring?', 'Is the service responsive?', 'Is the service well-led?']
+Answer format: List
+
+---
+Question: How many regulations of the HSCA 2008 are breached in all according to this report?
+Analysis: According to the report, the provider breached 10 Health and Social Care Act 2008 (Regulated Activities) Regulations in total. Here are the specifics:\n\n1. Regulation 13: Safeguarding service users from abuse and improper treatment\n2. Regulation 12: Safe care and treatment\n3. Regulation 18: Staffing\n4. Regulation 11: Need for consent\n5. Regulation 10: Dignity and respect\n6. Regulation 9: Person-centred care\n7. Regulation 17: Good governance\n8. Regulation 18 (CQC Registration Regulations 2009): Notification of other incidents\n9. Regulation 18: Failure to maintain an accurate and up-to-date care plan\n10. Regulation 11: Failure to implement the Mental Capacity Act 2005 code of practice effectively\n\nThese breaches involve issues concerning staffing, safeguarding, medicines management, dignity and respect, consent, care planning, governance, and failure to notify the CQC of incidents.
+Extracted answer: 10
+Answer format: Integer
+
+---
+Question: According to the survey, what is the percentage of Chinese who are paying more or about the same attention to politics after Trump's election?
+Analysis: The survey provided does not specify the percentage of Chinese individuals specifically who are paying more or about the same attention to politics after Trump's election. The report focuses primarily on American demographics and does not include specific details about the Chinese population in relation to this question. If you need information about a different demographic or a summary of the findings from the American demographic, I can certainly help with that!
+Extracted answer: Not answerable
+Answer format: String
+
+---
+Question: How many quotations from male respondent over 50 years old are included in this report?
+Analysis: The image you've provided appears to be a screenshot of a document with multiple charts. However, the text is too small and blurry to read accurately. If you can provide a clearer image or more context, I might be able to help you with your question.
+Extracted answer: Fail to answer
+Answer format: String
+
+---
diff --git a/evaluation/scripts/utils/prompts.py b/evaluation/scripts/utils/prompts.py
index 32e6d6729..ba5f5db8f 100644
--- a/evaluation/scripts/utils/prompts.py
+++ b/evaluation/scripts/utils/prompts.py
@@ -65,6 +65,55 @@
{context}
"""
+MMLONGBENCH_ANSWER_PROMPT = """
+ You are a helpful assistant that can answer questions based on the provided memories and images.
+
+ {memories}
+
+ Read the above memories and answer this question.
+ Please make your answer as concise as possible.
+"""
+
+LONGBENCH_V2_ANSWER_PROMPT = """
+Please read the following retrieved text chunks and answer the question below.
+
+
+$DOC$
+
+
+What is the correct answer to this question: $Q$
+Choices:
+(A) $C_A$
+(B) $C_B$
+(C) $C_C$
+(D) $C_D$
+
+Format your response as follows: "The correct answer is (insert answer here)".
+"""
+
+
+HOTPOT_ANSWER_PROMPT = """
+You are answering a question from the HotpotQA dataset.
+
+The question may require multi-hop reasoning across multiple supporting facts.
+Carefully read the provided context and identify the relevant evidence.
+Reason step by step to connect the facts and determine the correct answer.
+
+Important instructions:
+- Use only the information provided in the context.
+- Perform multi-step reasoning internally if needed.
+- The final answer must be a short factual answer (e.g., a name, place, date, or entity).
+- Do NOT include explanations, reasoning steps, or citations in the final output.
+
+Question:
+{question}
+
+Context:
+{context}
+
+Final Answer:
+
+"""
ZEP_CONTEXT_TEMPLATE = """
FACTS and ENTITIES represent relevant context to the current conversation.
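
For reference, a sketch of how the new prompt templates are expected to be filled (not taken from the evaluation scripts themselves): `HOTPOT_ANSWER_PROMPT` and `MMLONGBENCH_ANSWER_PROMPT` use `str.format` placeholders, while `LONGBENCH_V2_ANSWER_PROMPT` uses `$`-style markers that are substituted directly. The question and context strings below are made up.

```python
# Illustrative filling of the new templates.
from evaluation.scripts.utils.prompts import HOTPOT_ANSWER_PROMPT, LONGBENCH_V2_ANSWER_PROMPT

hotpot_prompt = HOTPOT_ANSWER_PROMPT.format(
    question="Which magazine was started first, Arthur's Magazine or First for Women?",
    context="Arthur's Magazine (1844) ... First for Women (1989) ...",
)

longbench_prompt = (
    LONGBENCH_V2_ANSWER_PROMPT.replace("$DOC$", "<retrieved chunks>")
    .replace("$Q$", "<question text>")
    .replace("$C_A$", "<choice A>")
    .replace("$C_B$", "<choice B>")
    .replace("$C_C$", "<choice C>")
    .replace("$C_D$", "<choice D>")
)
print(hotpot_prompt[:200])
```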
diff --git a/evaluation/scripts/xinyu/eval_docs.py b/evaluation/scripts/xinyu/eval_docs.py
deleted file mode 100644
index 03a333201..000000000
--- a/evaluation/scripts/xinyu/eval_docs.py
+++ /dev/null
@@ -1,228 +0,0 @@
-import csv
-import json
-import os
-import re
-import traceback
-
-from concurrent.futures import ThreadPoolExecutor, as_completed
-
-from dotenv import load_dotenv
-
-from evaluation.scripts.mmlongbench.eval.extract_answer import extract_answer
-from evaluation.scripts.xinyu.eval.eval_score_llm import eval_acc_and_f1, eval_score, show_results
-from memos.configs.mem_cube import GeneralMemCubeConfig
-from memos.configs.mem_os import MOSConfig
-from memos.mem_cube.general import GeneralMemCube
-from memos.mem_os.main import MOS
-
-
-load_dotenv()
-openapi_config = {
- "model_name_or_path": "gpt-4o",
- "top_k": 50,
- "remove_think_prefix": True,
- "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
- "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
-}
-neo4j_uri = os.getenv("NEO4J_URI", "bolt://47.117.41.207:7687")
-db_name = "stx-mmlongbench-003"
-doc_paths = [
- f
- for f in os.listdir("evaluation/data/xinyu/documents")
- if os.path.isfile(os.path.join("evaluation/data/xinyu/documents", f))
-]
-
-with open("evaluation/data/xinyu/all_samples_with_gt.json") as f:
- samples = json.load(f)
-
-
-def get_user_name(doc_file):
- csv_path = "evaluation/data/xinyu/user_doc_map.csv"
- if os.path.exists(csv_path):
- with open(csv_path, newline="", encoding="utf-8") as f:
- reader = csv.reader(f)
- for row in reader:
- uid, path = row[0], row[1]
- base = os.path.basename(path)
- if base == doc_file or os.path.splitext(base)[0] == os.path.splitext(doc_file)[0]:
- return uid
- return ""
-
-
-def process_doc(doc_file):
- user_name = get_user_name(doc_file)
- print(user_name, doc_file)
- config = {
- "user_id": user_name,
- "chat_model": {
- "backend": "openai",
- "config": openapi_config,
- },
- "mem_reader": {
- "backend": "simple_struct",
- "config": {
- "llm": {"backend": "openai", "config": openapi_config},
- "embedder": {
- "backend": "universal_api",
- "config": {
- "provider": "openai",
- "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
- "model_name_or_path": "text-embedding-3-large",
- "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
- },
- },
- "chunker": {
- "backend": "sentence",
- "config": {
- "tokenizer_or_token_counter": "gpt2",
- "chunk_size": 512,
- "chunk_overlap": 128,
- "min_sentences_per_chunk": 1,
- },
- },
- },
- },
- "max_turns_window": 20,
- "top_k": 5,
- "enable_textual_memory": True,
- "enable_activation_memory": False,
- "enable_parametric_memory": False,
- }
- mos_config = MOSConfig(**config)
- mos = MOS(mos_config)
-
- mem_cube_config = GeneralMemCubeConfig.model_validate(
- {
- "user_id": user_name,
- "cube_id": user_name,
- "text_mem": {
- "backend": "tree_text",
- "config": {
- "extractor_llm": {"backend": "openai", "config": openapi_config},
- "dispatcher_llm": {"backend": "openai", "config": openapi_config},
- "graph_db": {
- "backend": "neo4j",
- "config": {
- "uri": neo4j_uri,
- "user": "neo4j",
- "password": "iaarlichunyu",
- "db_name": db_name,
- "user_name": user_name,
- "use_multi_db": False,
- "auto_create": True,
- "embedding_dimension": 3072,
- },
- },
- "embedder": {
- "backend": "universal_api",
- "config": {
- "provider": "openai",
- "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
- "model_name_or_path": "text-embedding-3-large",
- "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
- },
- },
- "reorganize": False,
- },
- },
- "act_mem": {},
- "para_mem": {},
- }
- )
- mem_cube = GeneralMemCube(mem_cube_config)
-
- temp_dir = os.path.join("tmp", doc_file)
-
- if (not os.path.exists(temp_dir)) or (not os.listdir(temp_dir)):
- mem_cube.dump(temp_dir)
-
- mos.register_mem_cube(temp_dir, mem_cube_id=user_name)
-
- with open("evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md") as f:
- prompt = f.read()
-
- samples_res = []
- doc_samples = [s for s in samples if s.get("doc_id") == doc_file]
-
- if len(doc_samples) == 0:
- return []
-
- sample = doc_samples[0]
- question_list = sample["question"]
- answer_list = sample["answer"]
-
- for idx, question in enumerate(question_list):
- gt = answer_list.get(str(idx))
-
- try_cnt, is_success = 0, False
- while True:
- try:
- mos.clear_messages()
- response = mos.chat(question, user_name)
- is_success = True
- except Exception as e:
- print(f"[{doc_file}] Error:", e)
- traceback.print_exc()
- try_cnt += 1
- response = "Failed"
- if is_success or try_cnt > 5:
- break
-
- sample_item = dict(sample)
- sample_item["question"] = question
- sample_item["answer"] = gt
- sample_item["response"] = response
-
- extracted_res = extract_answer(sample_item["question"], response, prompt)
- sample_item["extracted_res"] = extracted_res
-
- print("--------------------------------------")
- pred_ans = extracted_res.split("Answer format:")[0].split("Extracted answer:")[1].strip()
- score = eval_score(question, gt, response)
-
- sample_item["pred"] = pred_ans
- sample_item["score"] = score
- samples_res.append(sample_item)
-
- print(f"Question: {question}")
- print(f"Response: {sample_item['response']}")
- print(f"Ground true: {gt}\tPred: {sample_item['pred']}\tScore: {sample_item['score']}")
-
- print("samples_res length: ", len(samples_res))
- return samples_res
-
-
-if __name__ == "__main__":
- results = []
-
- with ThreadPoolExecutor(max_workers=4) as executor:
- future_to_doc = {executor.submit(process_doc, doc_file): doc_file for doc_file in doc_paths}
-
- for future in as_completed(future_to_doc):
- doc_file = future_to_doc[future]
- try:
- res = future.result()
- results.extend(res)
-
- if len(res) > 0:
- acc, f1 = eval_acc_and_f1(results)
- print()
- print(f"Avg acc: {acc}")
- print(f"Avg f1: {f1}")
-
- with open("evaluation/data/xinyu/test_results.json", "w", encoding="utf-8") as f:
- json.dump(results, f, ensure_ascii=False, indent=2)
-
- except Exception as e:
- print(f"[{doc_file}] failed with {e}")
- traceback.print_exc()
-
- acc, f1 = eval_acc_and_f1(results)
- print("--------------------------------------")
- print(f"Final avg acc: {acc}")
- print(f"Final avg f1: {f1}")
-
- show_results(
- results,
- show_path=re.sub(r"\.json$", ".txt", "evaluation/data/xinyu/test_results_report.txt"),
- )
diff --git a/evaluation/scripts/xinyu/import_docs.py b/evaluation/scripts/xinyu/import_docs.py
deleted file mode 100644
index f9d2619a4..000000000
--- a/evaluation/scripts/xinyu/import_docs.py
+++ /dev/null
@@ -1,85 +0,0 @@
-import asyncio
-import os
-import traceback
-import uuid
-
-from memos import log
-from memos.configs.mem_reader import SimpleStructMemReaderConfig
-from memos.configs.memory import TreeTextMemoryConfig
-from memos.mem_reader.simple_struct import SimpleStructMemReader
-from memos.memories.textual.tree import TreeTextMemory
-
-
-logger = log.get_logger(__name__)
-db_name = "stx-mmlongbench-003"
-# Create a memory reader instance
-reader_config = SimpleStructMemReaderConfig.from_json_file(
- "examples/data/config/simple_struct_reader_config.json"
-)
-reader = SimpleStructMemReader(reader_config)
-
-tree_config = TreeTextMemoryConfig.from_json_file(
- "examples/data/config/tree_config_shared_database.json"
-)
-tree_config.graph_db.config.db_name = db_name
-# Processing Documents
-existing_names = {
- d for d in os.listdir("ppt_test_result") if os.path.isdir(os.path.join("ppt_test_result", d))
-}
-doc_paths = []
-for f in os.listdir("evaluation/data/xinyu/documents"):
- fp = os.path.join("evaluation/data/xinyu/documents", f)
- if os.path.isfile(fp):
- name = os.path.splitext(f)[0]
- if name in existing_names:
- continue
- doc_paths.append(fp)
-
-
-async def process_doc(doc_path):
- print(f"🔄 Processing document: {doc_path}")
- doc_file = doc_path.split("/")[-1].rsplit(".", 1)[0]
-
- # Generate random user id: 'user_' + random short hex
- user_id = "user_" + uuid.uuid4().hex[:8]
- # Persist mapping between user_id and doc_path
- with open("evaluation/data/xinyu/user_doc_map.csv", "a", encoding="utf-8") as f:
- f.write(f"{user_id},{doc_path}\n")
-
- tree_config.graph_db.config.user_name = user_id
- temp_dir = "tmp/" + doc_file
- my_tree_textual_memory = TreeTextMemory(tree_config)
- doc_memory = await reader.get_memory(
- [doc_path], "doc", info={"user_id": user_id, "session_id": "session_" + str(uuid.uuid4())}
- )
-
- count = 0
- for m_list in doc_memory:
- count += len(m_list)
- my_tree_textual_memory.add(m_list)
- print("total memories: ", count)
-
- my_tree_textual_memory.dump(temp_dir)
- return doc_path
-
-
-async def main():
- batch_size = 2
- for i in range(0, len(doc_paths), batch_size):
- batch = doc_paths[i : i + batch_size]
- print(f"🚀 Starting batch {i // batch_size + 1} with {len(batch)} docs")
-
- tasks = [process_doc(p) for p in batch]
- results = await asyncio.gather(*tasks, return_exceptions=True)
-
- for p, result in zip(batch, results, strict=False):
- if isinstance(result, Exception):
- print(f"❌ Error processing {p}: {result}")
- tb_text = "".join(traceback.TracebackException.from_exception(result).format())
- print(tb_text)
- else:
- print(f"✅ Finished {result}")
-
-
-if __name__ == "__main__":
- asyncio.run(main())
From c5f19f2d42cfd8cf4024ca38f273a9ebf4e896e0 Mon Sep 17 00:00:00 2001
From: stx <31013941@qq.com>
Date: Sun, 4 Jan 2026 11:01:32 +0800
Subject: [PATCH 3/4] feat: add evaluation pipeline
---
.../scripts/mmlongbench/eval/__init__.py | 0
.../scripts/mmlongbench/eval/eval_score.py | 246 ----------------
.../mmlongbench/eval/extract_answer.py | 33 ---
.../eval/prompt_for_answer_extraction.md | 35 ---
evaluation/scripts/mmlongbench/eval_docs.py | 265 ------------------
evaluation/scripts/mmlongbench/import_docs.py | 88 ------
.../scripts/mmlongbench/models/__init__.py | 0
.../scripts/mmlongbench/multimodal_test.py | 185 ------------
8 files changed, 852 deletions(-)
delete mode 100644 evaluation/scripts/mmlongbench/eval/__init__.py
delete mode 100644 evaluation/scripts/mmlongbench/eval/eval_score.py
delete mode 100644 evaluation/scripts/mmlongbench/eval/extract_answer.py
delete mode 100644 evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md
delete mode 100644 evaluation/scripts/mmlongbench/eval_docs.py
delete mode 100644 evaluation/scripts/mmlongbench/import_docs.py
delete mode 100644 evaluation/scripts/mmlongbench/models/__init__.py
delete mode 100644 evaluation/scripts/mmlongbench/multimodal_test.py
diff --git a/evaluation/scripts/mmlongbench/eval/__init__.py b/evaluation/scripts/mmlongbench/eval/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/evaluation/scripts/mmlongbench/eval/eval_score.py b/evaluation/scripts/mmlongbench/eval/eval_score.py
deleted file mode 100644
index 02ef6eb53..000000000
--- a/evaluation/scripts/mmlongbench/eval/eval_score.py
+++ /dev/null
@@ -1,246 +0,0 @@
-import re
-
-from collections import defaultdict
-from math import isclose
-
-
-def levenshtein_distance(s1, s2):
- if len(s1) > len(s2):
- s1, s2 = s2, s1
-
- distances = range(len(s1) + 1)
- for i2, c2 in enumerate(s2):
- distances_ = [i2 + 1]
- for i1, c1 in enumerate(s1):
- if c1 == c2:
- distances_.append(distances[i1])
- else:
- distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
- distances = distances_
- return distances[-1]
-
-
-def anls_compute(groundtruth, prediction, threshold=0.5):
- dist = levenshtein_distance(groundtruth, prediction)
- length = max(len(groundtruth.upper()), len(prediction.upper()))
- value = 0.0 if length == 0 else float(dist) / float(length)
- anls = 1.0 - value
- if anls <= threshold:
- anls = 0.0
- return anls
-
-
-def is_float_equal(
- reference, prediction, include_percentage: bool = False, is_close: float = False
-) -> bool:
- def get_precision(gt_ans: float) -> int:
- precision = 3
- if "." in str(gt_ans):
- precision = len(str(gt_ans).split(".")[-1])
- return precision
-
- reference = float(str(reference).strip().rstrip("%").strip())
- try:
- prediction = float(str(prediction).strip().rstrip("%").strip())
- except Exception:
- return False
-
- gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference]
- for item in gt_result:
- try:
- if is_close and isclose(item, prediction, rel_tol=0.01):
- return True
- precision = max(min(get_precision(prediction), get_precision(item)), 2)
- if round(prediction, precision) == round(item, precision):
- return True
- except Exception:
- continue
- return False
-
-
-def get_clean_string(s):
- s = str(s).lower().strip()
-
- for suffix in ["mile", "miles", "million"]:
- if s.endswith(suffix):
- s = s[: -len(suffix)].strip()
-
- s = re.sub(r"\s*\([^)]*\)", "", s).strip()
- s = re.sub(r"^['\"]|['\"]$", "", s).strip()
- s = s.lstrip("$").rstrip("%").strip()
-
- return s
-
-
-def is_exact_match(s):
- flag = False
- # Website
- if "https://" in s:
- flag = True
- # code file
- if s.endswith((".py", ".ipynb")) or s.startswith("page"):
- flag = True
- # telephone number
- if re.fullmatch(r"\b\d+(-\d+|\s\d+)?\b", s):
- flag = True
- # time
- if "a.m." in s or "p.m." in s:
- flag = True
- # YYYY-MM-DD
- if re.fullmatch(r"\b\d{4}[-\s]\d{2}[-\s]\d{2}\b", s):
- flag = True
- # YYYY-MM
- if re.fullmatch(r"\b\d{4}[-\s]\d{2}\b", s):
- flag = True
- # Email address
- if re.fullmatch(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", s):
- flag = True
- return flag
-
-
-def isfloat(num):
- try:
- float(num)
- return True
- except ValueError:
- return False
-
-
-def eval_score(gt, pred, answer_type):
- if answer_type == "Int":
- try:
- gt, pred = int(gt), int(float(pred))
- except Exception:
- pred = ""
- score = gt == pred
- elif answer_type == "Float":
- try:
- gt = float(get_clean_string(str(gt)))
- pred = float(get_clean_string(str(pred)))
- except Exception:
- pred = ""
- score = is_float_equal(gt, pred, include_percentage=True, is_close=True)
- elif answer_type in ["Str", "None"]:
- gt = get_clean_string(gt)
- pred = get_clean_string(pred)
- score = gt == pred if is_exact_match(gt) else anls_compute(gt, pred)
- else:
- if isinstance(gt, str) and gt.startswith("["):
- gt = eval(gt)
- if not isinstance(gt, list):
- gt = [gt]
- if isinstance(pred, str) and pred.startswith("["):
- pred = eval(pred)
- if not isinstance(pred, list):
- pred = [pred]
- print(len(gt), len(pred))
- if len(gt) != len(pred):
- score = 0.0
- else:
- gt = sorted([get_clean_string(a) for a in gt])
- pred = sorted([get_clean_string(a) for a in pred])
- print(gt, pred)
- if isfloat(gt[0]) or is_exact_match(gt[0]):
- score = "-".join(gt) == "-".join(pred)
- else:
- score = min(
- [anls_compute(gt_v, pred_v) for gt_v, pred_v in zip(gt, pred, strict=False)]
- )
-
- return float(score)
-
-
-def eval_acc_and_f1(samples):
- evaluated_samples = [sample for sample in samples if "score" in sample]
- if not evaluated_samples:
- return 0.0, 0.0
-
- acc = sum([sample["score"] for sample in evaluated_samples]) / len(evaluated_samples)
- try:
- recall = sum(
- [
- sample["score"]
- for sample in evaluated_samples
- if sample["answer"] != "Not answerable"
- ]
- ) / len([sample for sample in evaluated_samples if sample["answer"] != "Not answerable"])
- precision = sum(
- [
- sample["score"]
- for sample in evaluated_samples
- if sample["answer"] != "Not answerable"
- ]
- ) / len([sample for sample in evaluated_samples if sample["pred"] != "Not answerable"])
- f1 = 2 * recall * precision / (recall + precision) if (recall + precision) > 0.0 else 0.0
- except Exception:
- f1 = 0.0
-
- return acc, f1
-
-
-def show_results(samples, show_path=None):
- for sample in samples:
- sample["evidence_pages"] = eval(sample["evidence_pages"])
- sample["evidence_sources"] = eval(sample["evidence_sources"])
-
- with open(show_path, "w") as f:
- acc, f1 = eval_acc_and_f1(samples)
- f.write(f"Overall Acc: {acc} | Question Number: {len(samples)}\n")
- f.write(f"Overall F1-score: {f1} | Question Number: {len(samples)}\n")
- f.write("-----------------------\n")
-
- acc_single_page, _ = eval_acc_and_f1(
- [sample for sample in samples if len(sample["evidence_pages"]) == 1]
- )
- acc_multi_page, _ = eval_acc_and_f1(
- [
- sample
- for sample in samples
- if len(sample["evidence_pages"]) != 1 and sample["answer"] != "Not answerable"
- ]
- )
- acc_neg, _ = eval_acc_and_f1(
- [sample for sample in samples if sample["answer"] == "Not answerable"]
- )
-
- f.write(
- "Single-page | Accuracy: {} | Question Number: {}\n".format(
- acc_single_page,
- len([sample for sample in samples if len(sample["evidence_pages"]) == 1]),
- )
- )
- f.write(
- "Cross-page | Accuracy: {} | Question Number: {}\n".format(
- acc_multi_page,
- len(
- [
- sample
- for sample in samples
- if len(sample["evidence_pages"]) != 1
- and sample["answer"] != "Not answerable"
- ]
- ),
- )
- )
- f.write(
- "Unanswerable | Accuracy: {} | Question Number: {}\n".format(
- acc_neg, len([sample for sample in samples if sample["answer"] == "Not answerable"])
- )
- )
- f.write("-----------------------\n")
-
- source_sample_dict, document_type_dict = defaultdict(list), defaultdict(list)
- for sample in samples:
- for answer_source in sample["evidence_sources"]:
- source_sample_dict[answer_source].append(sample)
- document_type_dict[sample["doc_type"]].append(sample)
- for type, sub_samples in source_sample_dict.items():
- f.write(
- f"Evidence Sources: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n"
- )
-
- f.write("-----------------------\n")
- for type, sub_samples in document_type_dict.items():
- f.write(
- f"Document Type: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n"
- )
diff --git a/evaluation/scripts/mmlongbench/eval/extract_answer.py b/evaluation/scripts/mmlongbench/eval/extract_answer.py
deleted file mode 100644
index b7f7e6863..000000000
--- a/evaluation/scripts/mmlongbench/eval/extract_answer.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import os
-
-import openai
-
-from dotenv import load_dotenv
-
-
-load_dotenv()
-client = openai.Client(
- api_key=os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
- base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
-)
-
-
-def extract_answer(question, output, prompt, model_name="gpt-4o"):
- response = client.chat.completions.create(
- model=model_name,
- messages=[
- {
- "role": "user",
- "content": prompt,
- },
- {"role": "assistant", "content": f"\n\nQuestion:{question}\nAnalysis:{output}\n"},
- ],
- temperature=0.0,
- max_tokens=256,
- top_p=1,
- frequency_penalty=0,
- presence_penalty=0,
- )
- response = response.choices[0].message.content
-
- return response
diff --git a/evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md b/evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md
deleted file mode 100644
index a309c0935..000000000
--- a/evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md
+++ /dev/null
@@ -1,35 +0,0 @@
-Given the question and analysis, you are tasked to extract answers with required formats from the free-form analysis.
-- Your extracted answers should be one of the following formats: (1) Integer, (2) Float, (3) String and (4) List. If you find the analysis the question can not be answered from the given documents, type "Not answerable". Exception: If the analysis only tells you that it can not read/understand the images or documents, type "Fail to answer".
-- Please make your response as concise as possible. Also note that your response should be formatted as below:
-```
-Extracted answer: [answer]
-Answer format: [answer format]
-```
-
-Please read the following example, then extract the answer from the model response and type it at the end of the prompt.
-
----
-Question: List the primary questions asked about the services in this report.
-Analysis: The primary questions asked about the services in the report for The Limes Residential Home are:\n\n1. Is the service safe?\n2. Is the service effective?\n3. Is the service caring?\n4. Is the service responsive?\n5. Is the service well-led?
-Extracted answer: ['Is the servife safe?', 'Is the service effective', 'Is the serve caring?', 'Is the service responsive?', 'Is the service well-led?']
-Answer format: List
-
----
-Question: How many regulations of the HSCA 2008 are breached in all according to this report?
-Analysis: According to the report, the provider breached 10 Health and Social Care Act 2008 (Regulated Activities) Regulations in total. Here are the specifics:\n\n1. Regulation 13: Safeguarding service users from abuse and improper treatment\n2. Regulation 12: Safe care and treatment\n3. Regulation 18: Staffing\n4. Regulation 11: Need for consent\n5. Regulation 10: Dignity and respect\n6. Regulation 9: Person-centred care\n7. Regulation 17: Good governance\n8. Regulation 18 (CQC Registration Regulations 2009): Notification of other incidents\n9. Regulation 18: Failure to maintain an accurate and up-to-date care plan\n10. Regulation 11: Failure to implement the Mental Capacity Act 2005 code of practice effectively\n\nThese breaches involve issues concerning staffing, safeguarding, medicines management, dignity and respect, consent, care planning, governance, and failure to notify the CQC of incidents.
-Extracted answer: 10
-Answer format: Integer
-
----
-Question: According to the survey that is the percentage of Chinese who are paying more or about the same attention to politics after Trump's election?
-Analysis: The survey provided does not specify the percentage of Chinese individuals specifically who are paying more or about the same attention to politics after Trump's election. The report focuses primarily on American demographics and does not include specific details about the Chinese population in relation to this question. If you need information about a different demographic or a summary of the findings from the American demographic, I can certainly help with that!
-Extracted answer: Not answerable
-Answer format: String
-
----
-Question: How many quotations from male respondent over 50 years old are included in this report?
-Analysis: The image you've provided appears to be a screenshot of a document with multiple charts. However, the text is too small and blurry to read accurately. If you can provide a clearer image or more context, I might be able to help you with your question.
-Extracted answer: Fail to answer
-Answer format: String
-
----
diff --git a/evaluation/scripts/mmlongbench/eval_docs.py b/evaluation/scripts/mmlongbench/eval_docs.py
deleted file mode 100644
index 510a0b1ed..000000000
--- a/evaluation/scripts/mmlongbench/eval_docs.py
+++ /dev/null
@@ -1,265 +0,0 @@
-import csv
-import json
-import os
-import re
-import traceback
-
-from concurrent.futures import ThreadPoolExecutor, as_completed
-
-from dotenv import load_dotenv
-from eval.eval_score import eval_acc_and_f1, eval_score, show_results
-from eval.extract_answer import extract_answer
-from tqdm import tqdm
-
-from memos.configs.mem_cube import GeneralMemCubeConfig
-from memos.configs.mem_os import MOSConfig
-from memos.mem_cube.general import GeneralMemCube
-from memos.mem_os.main import MOS
-
-
-load_dotenv()
-openapi_config = {
- "model_name_or_path": "gpt-4o",
- "top_k": 50,
- "remove_think_prefix": True,
- "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
- "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
-}
-neo4j_uri = os.getenv("NEO4J_URI", "bolt://47.117.41.207:7687")
-db_name = "stx-mmlongbench-004"
-
-doc_paths = [
- f
- for f in os.listdir("evaluation/data/mmlongbench/documents")
- if os.path.isfile(os.path.join("evaluation/data/mmlongbench/documents", f))
-]
-
-with open("evaluation/data/mmlongbench/samples.json") as f:
- samples = json.load(f)
-
-RESULTS_PATH = "evaluation/data/mmlongbench/test_results.json"
-completed_pairs: set[tuple[str, str]] = set()
-
-
-def _load_existing_results():
- global completed_pairs
- if os.path.exists(RESULTS_PATH):
- try:
- with open(RESULTS_PATH, encoding="utf-8") as f:
- existing = json.load(f)
- for r in existing:
- did = r.get("doc_id")
- q = r.get("question")
- if did and q:
- completed_pairs.add((did, q))
- return existing
- except Exception:
- return []
- return []
-
-
-def _doc_has_pending(doc_file: str) -> bool:
- for s in samples:
- if s.get("doc_id") == doc_file and (doc_file, s.get("question")) not in completed_pairs:
- return True
- return False
-
-
-def get_user_name(doc_file):
- csv_path = "evaluation/data/mmlongbench/user_doc_map.csv"
- if os.path.exists(csv_path):
- with open(csv_path, newline="", encoding="utf-8") as f:
- reader = csv.reader(f)
- for row in reader:
- uid, path = row[0], row[1]
- base = os.path.basename(path)
- if base == doc_file or os.path.splitext(base)[0] == os.path.splitext(doc_file)[0]:
- return uid
- return ""
-
-
-def process_doc(doc_file):
- user_name = get_user_name(doc_file)
- print(user_name, doc_file)
- config = {
- "user_id": user_name,
- "chat_model": {
- "backend": "openai",
- "config": openapi_config,
- },
- "mem_reader": {
- "backend": "simple_struct",
- "config": {
- "llm": {"backend": "openai", "config": openapi_config},
- "embedder": {
- "backend": "universal_api",
- "config": {
- "provider": "openai",
- "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
- "model_name_or_path": "text-embedding-3-large",
- "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
- },
- },
- "chunker": {
- "backend": "sentence",
- "config": {
- "tokenizer_or_token_counter": "gpt2",
- "chunk_size": 512,
- "chunk_overlap": 128,
- "min_sentences_per_chunk": 1,
- },
- },
- },
- },
- "max_turns_window": 20,
- "top_k": 5,
- "enable_textual_memory": True,
- "enable_activation_memory": False,
- "enable_parametric_memory": False,
- }
- mos_config = MOSConfig(**config)
- mos = MOS(mos_config)
-
- mem_cube_config = GeneralMemCubeConfig.model_validate(
- {
- "user_id": user_name,
- "cube_id": user_name,
- "text_mem": {
- "backend": "tree_text",
- "config": {
- "extractor_llm": {"backend": "openai", "config": openapi_config},
- "dispatcher_llm": {"backend": "openai", "config": openapi_config},
- "graph_db": {
- "backend": "neo4j",
- "config": {
- "uri": neo4j_uri,
- "user": "neo4j",
- "password": "iaarlichunyu",
- "db_name": db_name,
- "user_name": user_name,
- "use_multi_db": False,
- "auto_create": True,
- "embedding_dimension": 3072,
- },
- },
- "embedder": {
- "backend": "universal_api",
- "config": {
- "provider": "openai",
- "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
- "model_name_or_path": "text-embedding-3-large",
- "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
- },
- },
- "reorganize": False,
- },
- },
- "act_mem": {},
- "para_mem": {},
- }
- )
- mem_cube = GeneralMemCube(mem_cube_config)
-
- temp_dir = "tmp/" + doc_file
- if not os.path.exists(temp_dir) or not os.listdir(temp_dir):
- mem_cube.dump(temp_dir)
-
- mos.register_mem_cube(temp_dir, mem_cube_id=user_name)
-
- with open("evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md") as f:
- prompt = f.read()
-
- samples_res = []
- doc_samples = [s for s in samples if s.get("doc_id") == doc_file]
- if len(doc_samples) == 0:
- return []
-
- for sample in tqdm(doc_samples, desc=f"Processing {doc_file}"):
- if (doc_file, sample.get("question")) in completed_pairs:
- continue
- messages = sample["question"]
- try_cnt, is_success = 0, False
-
- while True:
- try:
- mos.clear_messages()
- response = mos.chat(messages, user_name)
- is_success = True
- except Exception as e:
- print(f"[{doc_file}] Error:", e)
- traceback.print_exc()
- try_cnt += 1
- response = "Failed"
- if is_success or try_cnt > 5:
- break
-
- sample["response"] = response
- extracted_res = extract_answer(sample["question"], response, prompt)
- sample["extracted_res"] = extracted_res
-
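- # The extraction output should look like "Extracted answer: ...\nAnswer format: ..."; keep only the answer text.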
- pred_ans = extracted_res.split("Answer format:")[0].split("Extracted answer:")[1].strip()
- score = eval_score(sample["answer"], pred_ans, sample["answer_format"])
-
- sample["pred"] = pred_ans
- sample["score"] = score
- samples_res.append(sample)
-
- print("--------------------------------------")
- print(f"Question: {sample['question']}")
- print(f"Response: {sample['response']}")
- print(f"Ground true: {sample['answer']}\tPred: {sample['pred']}\tScore: {sample['score']}")
-
- return samples_res
-
-
-if __name__ == "__main__":
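- # Resume from any saved results and only run documents that still have pending questions.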
- results = _load_existing_results()
- total_samples = len(samples)
- processed_samples = len(completed_pairs)
- pending_samples = total_samples - processed_samples
- sample_doc_ids = [s.get("doc_id") for s in samples if s.get("doc_id")]
- all_docs_in_samples = set(sample_doc_ids)
- processed_docs = {d for d, _ in completed_pairs}
- with ThreadPoolExecutor(max_workers=4) as executor:
- pending_docs = [d for d in doc_paths if _doc_has_pending(d)]
- print("\n" + "=" * 80)
- print("📊 评测进度统计")
- print("=" * 80)
- print(f"✅ 已加载历史结果: {len(results)} 条")
- print(f"📂 数据集总样本: {total_samples}")
- print(f"🧪 已完成样本: {processed_samples}")
- print(f"⏳ 待处理样本: {pending_samples}")
- print(f"📄 数据集中总文档: {len(all_docs_in_samples)}")
- print(f"✔️ 已完成文档: {len(processed_docs)}")
- print(f"➡️ 待处理文档(本次将运行): {len(pending_docs)}")
- print("=" * 80 + "\n")
- future_to_doc = {
- executor.submit(process_doc, doc_file): doc_file for doc_file in pending_docs
- }
-
- for future in as_completed(future_to_doc):
- doc_file = future_to_doc[future]
- try:
- res = future.result()
- results.extend(res)
-
- if len(res) > 0:
- acc, f1 = eval_acc_and_f1(results)
- print()
- print(f"Avg acc: {acc}")
- print(f"Avg f1: {f1}")
-
- with open(RESULTS_PATH, "w", encoding="utf-8") as f:
- json.dump(results, f, ensure_ascii=False, indent=2)
- except Exception as e:
- print(f"[{doc_file}] failed with {e}")
-
- acc, f1 = eval_acc_and_f1(results)
- print("--------------------------------------")
- print(f"Final avg acc: {acc}")
- print(f"Final avg f1: {f1}")
-
- show_results(
- results,
- show_path="evaluation/data/mmlongbench/test_results_report.txt",
- )
diff --git a/evaluation/scripts/mmlongbench/import_docs.py b/evaluation/scripts/mmlongbench/import_docs.py
deleted file mode 100644
index 540c8f960..000000000
--- a/evaluation/scripts/mmlongbench/import_docs.py
+++ /dev/null
@@ -1,88 +0,0 @@
-import asyncio
-import os
-import traceback
-import uuid
-
-from memos import log
-from memos.configs.mem_reader import SimpleStructMemReaderConfig
-from memos.configs.memory import TreeTextMemoryConfig
-from memos.mem_reader.simple_struct import SimpleStructMemReader
-from memos.memories.textual.tree import TreeTextMemory
-
-
-logger = log.get_logger(__name__)
-db_name = "stx-mmlongbench-004"
-# Create a memory reader instance
-reader_config = SimpleStructMemReaderConfig.from_json_file(
- "examples/data/config/simple_struct_reader_config.json"
-)
-reader = SimpleStructMemReader(reader_config)
-
-tree_config = TreeTextMemoryConfig.from_json_file(
- "examples/data/config/tree_config_shared_database.json"
-)
-tree_config.graph_db.config.db_name = db_name
-# Collect pending documents, skipping any whose name already has a result directory under ppt_test_result
-existing_names = {
- d for d in os.listdir("ppt_test_result") if os.path.isdir(os.path.join("ppt_test_result", d))
-}
-doc_paths = []
-for f in os.listdir("evaluation/data/mmlongbench/documents"):
- fp = os.path.join("evaluation/data/mmlongbench/documents", f)
- if os.path.isfile(fp):
- name = os.path.splitext(f)[0]
- if name in existing_names:
- continue
- doc_paths.append(fp)
-
-print("existing_names length:", len(existing_names))
-print("doc_paths length:", len(doc_paths))
-
-
-async def process_doc(doc_path):
- print(f"🔄 Processing document: {doc_path}")
- # Generate random user id: 'user_' + random short hex
- user_id = "user_" + uuid.uuid4().hex[:8]
- # Persist mapping between user_id and doc_path
- try:
- os.makedirs("evaluation/data/mmlongbench", exist_ok=True)
- with open("evaluation/data/mmlongbench/user_doc_map.csv", "a", encoding="utf-8") as f:
- f.write(f"{user_id},{doc_path}\n")
- except Exception as e:
- logger.error(f"Failed to write user-doc mapping: {e}")
-
- tree_config.graph_db.config.user_name = user_id
- my_tree_textual_memory = TreeTextMemory(tree_config)
- doc_memory = await reader.get_memory(
- [doc_path], "doc", info={"user_id": user_id, "session_id": "session_" + str(uuid.uuid4())}
- )
-
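- # Write every extracted memory item into this user's tree memory.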
- count = 0
- for m_list in doc_memory:
- count += len(m_list)
- my_tree_textual_memory.add(m_list)
- print("total memories: ", count)
-
- return doc_path
-
-
-async def main():
- batch_size = 4
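- # Import documents in batches of four concurrent asyncio tasks.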
- for i in range(0, len(doc_paths), batch_size):
- batch = doc_paths[i : i + batch_size]
- print(f"🚀 Starting batch {i // batch_size + 1} with {len(batch)} docs")
-
- tasks = [process_doc(p) for p in batch]
- results = await asyncio.gather(*tasks, return_exceptions=True)
-
- for p, result in zip(batch, results, strict=False):
- if isinstance(result, Exception):
- print(f"❌ Error processing {p}: {result}")
- tb_text = "".join(traceback.TracebackException.from_exception(result).format())
- print(tb_text)
- else:
- print(f"✅ Finished {result}")
-
-
-if __name__ == "__main__":
- asyncio.run(main())
diff --git a/evaluation/scripts/mmlongbench/models/__init__.py b/evaluation/scripts/mmlongbench/models/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/evaluation/scripts/mmlongbench/multimodal_test.py b/evaluation/scripts/mmlongbench/multimodal_test.py
deleted file mode 100644
index 929215229..000000000
--- a/evaluation/scripts/mmlongbench/multimodal_test.py
+++ /dev/null
@@ -1,185 +0,0 @@
-import os
-import shutil
-
-from dotenv import load_dotenv
-
-from memos.configs.mem_cube import GeneralMemCubeConfig
-from memos.configs.mem_os import MOSConfig
-from memos.mem_cube.general import GeneralMemCube
-from memos.mem_os.main import MOS
-
-
-load_dotenv()
-
-db_name = "stx-mmlongbench-002"
-user_id = "user_dc812220"
-
-# 1.1 Set openai config
-openapi_config = {
- "model_name_or_path": "gpt-4o",
- "top_k": 50,
- "remove_think_prefix": True,
- "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
- "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
-}
-# 1.2 Set neo4j config
-neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
-
-# 1.3 Create MOS Config
-config = {
- "user_id": user_id,
- "chat_model": {
- "backend": "openai",
- "config": openapi_config,
- },
- "mem_reader": {
- "backend": "simple_struct",
- "config": {
- "llm": {
- "backend": "openai",
- "config": openapi_config,
- },
- "embedder": {
- "backend": "universal_api",
- "config": {
- "provider": "openai",
- "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
- "model_name_or_path": "text-embedding-3-large",
- "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
- },
- },
- "chunker": {
- "backend": "sentence",
- "config": {
- "tokenizer_or_token_counter": "gpt2",
- "chunk_size": 512,
- "chunk_overlap": 128,
- "min_sentences_per_chunk": 1,
- },
- },
- },
- },
- "max_turns_window": 20,
- "top_k": 5,
- "enable_textual_memory": True,
- "enable_activation_memory": False,
- "enable_parametric_memory": False,
-}
-
-mos_config = MOSConfig(**config)
-mos = MOS(mos_config)
-
-config = GeneralMemCubeConfig.model_validate(
- {
- "user_id": user_id,
- "cube_id": f"{user_id}",
- "text_mem": {
- "backend": "tree_text",
- "config": {
- "extractor_llm": {
- "backend": "openai",
- "config": openapi_config,
- },
- "dispatcher_llm": {
- "backend": "openai",
- "config": openapi_config,
- },
- "graph_db": {
- "backend": "neo4j",
- "config": {
- "uri": neo4j_uri,
- "user": "neo4j",
- "password": "iaarlichunyu",
- "db_name": db_name,
- "user_name": user_id,
- "use_multi_db": False,
- "auto_create": True,
- "embedding_dimension": 3072,
- },
- },
- "embedder": {
- "backend": "universal_api",
- "config": {
- "provider": "openai",
- "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
- "model_name_or_path": "text-embedding-3-large",
- "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
- },
- },
- "reorganize": False,
- },
- },
- "act_mem": {},
- "para_mem": {},
- },
-)
-
-
-# Filter out embedding fields, keeping only necessary fields
-def filter_memory_data(memories_data):
- filtered_data = {}
- for key, value in memories_data.items():
- if key == "text_mem":
- filtered_data[key] = []
- for mem_group in value:
- # Check if it's the new data structure (list of TextualMemoryItem objects)
- if "memories" in mem_group and isinstance(mem_group["memories"], list):
- # New data structure: directly a list of TextualMemoryItem objects
- filtered_memories = []
- for memory_item in mem_group["memories"]:
- # Create filtered dictionary
- filtered_item = {
- "id": memory_item.id,
- "memory": memory_item.memory,
- "metadata": {},
- }
- # Filter metadata, excluding embedding
- if hasattr(memory_item, "metadata") and memory_item.metadata:
- for attr_name in dir(memory_item.metadata):
- if not attr_name.startswith("_") and attr_name != "embedding":
- attr_value = getattr(memory_item.metadata, attr_name)
- if not callable(attr_value):
- filtered_item["metadata"][attr_name] = attr_value
- filtered_memories.append(filtered_item)
-
- filtered_group = {
- "cube_id": mem_group.get("cube_id", ""),
- "memories": filtered_memories,
- }
- filtered_data[key].append(filtered_group)
- else:
- # Old data structure: dictionary with nodes and edges
- filtered_group = {
- "memories": {"nodes": [], "edges": mem_group["memories"].get("edges", [])}
- }
- for node in mem_group["memories"].get("nodes", []):
- filtered_node = {
- "id": node.get("id"),
- "memory": node.get("memory"),
- "metadata": {
- k: v
- for k, v in node.get("metadata", {}).items()
- if k != "embedding"
- },
- }
- filtered_group["memories"]["nodes"].append(filtered_node)
- filtered_data[key].append(filtered_group)
- else:
- filtered_data[key] = value
- return filtered_data
-
-
-mem_cube = GeneralMemCube(config)
-
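-# Dump the cube to a temp dir and register it so MOS can answer queries from its memories.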
-temp_dir = f"/tmp/{user_id}"
-if os.path.exists(temp_dir):
- shutil.rmtree(temp_dir)
-mem_cube.dump(temp_dir)
-mos.register_mem_cube(temp_dir, mem_cube_id=user_id)
-
-
-print("start answering...")
-user_query = "图8美股变化的影响是什么"
-print(f"👤 User query: {user_query}")
-response = mos.chat(user_query)
-print(f"🤖 Response: {response}")
From 4d335db7a61b29d24a6bcce302506e394e47f1b0 Mon Sep 17 00:00:00 2001
From: stx <31013941@qq.com>
Date: Sun, 4 Jan 2026 11:02:50 +0800
Subject: [PATCH 4/4] feat: add evaluation pipeline
---
.../mmlongbench/models/minicpm_llama3.py | 56 -------------------
1 file changed, 56 deletions(-)
delete mode 100644 evaluation/scripts/mmlongbench/models/minicpm_llama3.py
diff --git a/evaluation/scripts/mmlongbench/models/minicpm_llama3.py b/evaluation/scripts/mmlongbench/models/minicpm_llama3.py
deleted file mode 100644
index 7f6d4b743..000000000
--- a/evaluation/scripts/mmlongbench/models/minicpm_llama3.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import torch
-
-from PIL import Image
-from transformers import AutoModel, AutoTokenizer
-
-
-def init_model(cache_path):
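- # Load MiniCPM-Llama3-V-2.5 from the given cache path, falling back to the Hugging Face Hub id.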
- model_path = (
- cache_path
- if (cache_path is not None and cache_path != "None")
- else "openbmb/MiniCPM-Llama3-V-2_5"
- )
- model = AutoModel.from_pretrained(
- model_path,
- torch_dtype=torch.bfloat16,
- low_cpu_mem_usage=True,
- trust_remote_code=True,
- device_map="auto",
- ).eval()
-
- tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
- model.tokenizer = tokenizer
- return model
-
-
-def get_response_concat(model, question, image_path_list, max_new_tokens=1024, temperature=1.0):
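- # Build a single user message that interleaves the system prompt, the page images, and the question.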
- msgs = []
- system_prompt = "Answer in detail."
- if system_prompt:
- msgs.append({"type": "text", "value": system_prompt})
- if isinstance(image_path_list, list):
- msgs.extend([{"type": "image", "value": p} for p in image_path_list])
- else:
- msgs = [{"type": "image", "value": image_path_list}]
- msgs.append({"type": "text", "value": question})
-
- content = []
- for x in msgs:
- if x["type"] == "text":
- content.append(x["value"])
- elif x["type"] == "image":
- image = Image.open(x["value"]).convert("RGB")
- content.append(image)
- msgs = [{"role": "user", "content": content}]
-
- with torch.cuda.amp.autocast():
- res = model.chat(
- msgs=msgs,
- context=None,
- image=None,
- max_new_tokens=max_new_tokens,
- temperature=temperature,
- do_sample=temperature != 0.0,
- tokenizer=model.tokenizer,
- )
- return res