diff --git a/.gitignore b/.gitignore index 8319a4d2f..c6b9130e6 100644 --- a/.gitignore +++ b/.gitignore @@ -11,7 +11,7 @@ tmp/ **settings.json** evaluation/*tmp/ evaluation/results -evaluation/.env +.env !evaluation/configs-example/*.json evaluation/configs/* **tree_textual_memory_locomo** diff --git a/evaluation/data/personamem/.gitkeep b/evaluation/data/personamem/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/evaluation/scripts/hotpot/data_loader.py b/evaluation/scripts/hotpot/data_loader.py new file mode 100644 index 000000000..871981036 --- /dev/null +++ b/evaluation/scripts/hotpot/data_loader.py @@ -0,0 +1,78 @@ +import json + +from pathlib import Path + +from datasets import load_dataset + + +def load_hotpot_data(data_dir: Path | str) -> list[dict]: + """ + Load HotpotQA dataset. + If dev_distractor_gold.json exists in data_dir, load it. + Otherwise, download from Hugging Face, convert to standard format, save, and load. + """ + data_dir = Path(data_dir) + data_dir.mkdir(parents=True, exist_ok=True) + file_path = data_dir / "dev_distractor_gold.json" + + if file_path.exists(): + print(f"Loading local dataset from {file_path}") + try: + with open(file_path, encoding="utf-8") as f: + return json.load(f) + except Exception as e: + print(f"Failed to load local file: {e}. Re-downloading...") + + print("Downloading HotpotQA dataset from Hugging Face...") + try: + dataset = load_dataset( + "hotpotqa/hotpot_qa", "distractor", split="validation", trust_remote_code=True + ) + except Exception as e: + print(f"Failed to download dataset: {e}") + raise + + print(f"Processing and saving dataset to {file_path}...") + items = [] + for item in dataset: + # Convert HF format to Standard format + # ID + qid = item.get("id") or item.get("_id") + + # Supporting Facts + sp = item.get("supporting_facts") + if isinstance(sp, dict): + sp_titles = sp.get("title", []) + sp_sent_ids = sp.get("sent_id", []) + sp_list = list(zip(sp_titles, sp_sent_ids, strict=False)) + else: + sp_list = sp or [] + + # Context + ctx = item.get("context") + if isinstance(ctx, dict): + ctx_titles = ctx.get("title", []) + ctx_sentences = ctx.get("sentences", []) + ctx_list = list(zip(ctx_titles, ctx_sentences, strict=False)) + else: + ctx_list = ctx or [] + + new_item = { + "_id": qid, + "question": item.get("question"), + "answer": item.get("answer"), + "supporting_facts": sp_list, + "context": ctx_list, + "type": item.get("type"), + "level": item.get("level"), + } + items.append(new_item) + + try: + with open(file_path, "w", encoding="utf-8") as f: + json.dump(items, f, ensure_ascii=False, indent=2) + print(f"Saved {len(items)} items to {file_path}") + except Exception as e: + print(f"Failed to save dataset: {e}") + + return items diff --git a/evaluation/scripts/hotpot/hotpot_eval.py b/evaluation/scripts/hotpot/hotpot_eval.py new file mode 100644 index 000000000..80315a65f --- /dev/null +++ b/evaluation/scripts/hotpot/hotpot_eval.py @@ -0,0 +1,231 @@ +import argparse +import importlib.util +import json +import os +import time + +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +from dotenv import load_dotenv +from openai import OpenAI +from tqdm import tqdm + +from evaluation.scripts.hotpot.data_loader import load_hotpot_data +from evaluation.scripts.utils.extract_answer import extract_answer, parse_extracted_answer +from evaluation.scripts.utils.metrics import Metrics +from evaluation.scripts.utils.prompts import HOTPOT_ANSWER_PROMPT + + +load_dotenv() + + +def 
llm_response( + oai_client, chat_model: str, context: str, question: str, question_date: str | None = None +) -> str: + prompt = HOTPOT_ANSWER_PROMPT.format(question=question, context=context) + resp = oai_client.chat.completions.create( + model=chat_model, + messages=[{"role": "system", "content": prompt}], + temperature=0, + ) + return resp.choices[0].message.content or "" + + +def _load_json_list(path: Path) -> list[dict]: + data = json.loads(path.read_text(encoding="utf-8")) + if isinstance(data, list): + return data + if isinstance(data, dict) and isinstance(data.get("results"), list): + return data.get("results") or [] + raise ValueError(f"Invalid json format: {path}") + + +def _save_pred( + pred_path: Path, pred_answers: dict, pred_sp: dict, perf: dict | None = None +) -> None: + pred_path.parent.mkdir(parents=True, exist_ok=True) + tmp = pred_path.with_suffix(pred_path.suffix + ".tmp") + safe_pred_answers = { + k: (v if isinstance(v, str) else ("" if v is None else str(v))) + for k, v in pred_answers.items() + } + obj = {"answer": safe_pred_answers, "sp": pred_sp} + if perf is not None: + obj["perf"] = perf + tmp.write_text(json.dumps(obj, ensure_ascii=False, indent=2), encoding="utf-8") + os.replace(tmp, pred_path) + + +def run_eval(pred_path: Path, gold_path: Path): + spec = importlib.util.spec_from_file_location( + "hotpot_eval_v1", "evaluation/scripts/hotpot/hotpot_evaluate_v1.py" + ) + m = importlib.util.module_from_spec(spec) + spec.loader.exec_module(m) + metrics = m.eval(str(pred_path), str(gold_path)) + + # Save metrics to pred_path (beginning of file) + try: + results_path = pred_path + current_data = {} + if results_path.exists(): + with open(results_path, encoding="utf-8") as f: + current_data = json.load(f) + + if isinstance(current_data, list): + new_data = [metrics, *current_data] + + elif isinstance(current_data, dict): + # Put metrics at the beginning + new_data = metrics.copy() + for k, v in current_data.items(): + if k not in new_data: + new_data[k] = v + else: + new_data = metrics + + with open(results_path, "w", encoding="utf-8") as f: + json.dump(new_data, f, indent=2, ensure_ascii=False) + except Exception as e: + print(f"Failed to save metrics to {results_path}: {e}") + + +def evaluate_one(oai_client, row: dict, chat_model: str) -> tuple[str, str, list]: + qid = str(row.get("_id")) + question = row.get("question") or "" + context = row.get("context") or "" + sp_list = row.get("sp") or [] + + raw_answer = llm_response(oai_client, chat_model, context=context, question=question) + extracted_res = extract_answer(question, raw_answer) + answer = parse_extracted_answer(extracted_res, raw_answer) + return qid, answer, sp_list + + +def main(argv: list[str] | None = None) -> None: + parser = argparse.ArgumentParser( + description="HotpotQA eval (OpenAI only, read search results)." 
+ ) + parser.add_argument( + "--lib", + type=str, + default="memos", + choices=["memos", "mem0", "supermemory"], + ) + parser.add_argument("--workers", type=int, default=8) + parser.add_argument("--max_samples", type=int, default=None) + parser.add_argument("--version-dir", "-v", default=None, help="Version directory name") + parser.add_argument("--chat-model", default=None, help="Chat model name") + parser.add_argument("--search-mode", default="fine", help="Search mode") + + args = parser.parse_args(argv) + + output_dir = Path(f"evaluation/data/hotpot/{args.version_dir}") + output_dir.mkdir(parents=True, exist_ok=True) + + if args.lib == "memos": + search_path = output_dir / f"{args.lib}_{args.search_mode}_search_results.json" + pred_path = output_dir / f"{args.lib}_{args.search_mode}_search_eval_results.json" + else: + search_path = output_dir / f"{args.lib}_search_results.json" + pred_path = output_dir / f"{args.lib}_eval_results.json" + gold_path = Path("evaluation/data/hotpot/dev_distractor_gold.json") + + if not search_path.exists(): + raise FileNotFoundError(f"Search results not found: {search_path}") + + if not gold_path.exists(): + load_hotpot_data("evaluation/data/hotpot") + + pred_answers: dict[str, str] = {} + pred_sp: dict[str, list] = {} + if pred_path.exists(): + try: + prev = json.loads(pred_path.read_text(encoding="utf-8")) + if isinstance(prev, dict) and isinstance(prev.get("answer"), dict): + pred_answers.update(prev["answer"]) + if isinstance(prev, dict) and isinstance(prev.get("sp"), dict): + pred_sp.update(prev["sp"]) + except Exception as e: + print(f"[Eval] failed to load existing pred: {e}") + + rows = _load_json_list(search_path) + if args.max_samples is not None: + rows = rows[: args.max_samples] + + pending = [r for r in rows if str(r.get("_id")) not in pred_answers] + print(f"[Eval] lib={args.lib} total={len(rows)} pending={len(pending)} workers={args.workers}") + if not pending: + run_eval(pred_path, gold_path) + return + + oai_client = OpenAI( + api_key=os.getenv("CHAT_MODEL_API_KEY"), base_url=os.getenv("CHAT_MODEL_BASE_URL") + ) + + processed = len(pred_answers) + + metrics = Metrics() + start_time = time.time() + + print("[Response model]: ", args.chat_model) + with ThreadPoolExecutor(max_workers=args.workers) as executor: + + def do_eval(row): + st = time.perf_counter() + try: + res = evaluate_one(oai_client, row, args.chat_model) + dur = time.perf_counter() - st + metrics.record(dur, True) + return res + except Exception as e: + dur = time.perf_counter() - st + metrics.record(dur, False, str(e)) + raise e + + futures = [executor.submit(do_eval, row) for row in pending] + for idx, f in enumerate( + tqdm(as_completed(futures), total=len(futures), desc="Evaluating"), 1 + ): + try: + qid, answer, sp_list = f.result() + pred_answers[qid] = answer + pred_sp[qid] = sp_list + processed += 1 + if idx % 20 == 0: + _save_pred(pred_path, pred_answers, pred_sp) + except Exception as e: + print(f"[Eval] Error: {e}") + + _save_pred(pred_path, pred_answers, pred_sp) + + # Save performance metrics (merge into pred json) + total_duration = time.time() - start_time + summary = metrics.summary() + perf_obj = { + "summary": summary, + "total_duration": total_duration, + "config": { + "workers": args.workers, + "chat_model": args.chat_model or os.getenv("CHAT_MODEL"), + "lib": args.lib, + }, + } + _save_pred(pred_path, pred_answers, pred_sp, perf=perf_obj) + run_eval(pred_path, gold_path) + + print("\n" + "=" * 60) + print("Evaluation finished! 
Statistics:") + print("=" * 60) + print(f"Total duration: {total_duration:.2f}s") + print(f"Success: {summary['counts']['success']} / Failed: {summary['counts']['failed']}") + + if summary["errors"]: + print("\nError stats:") + for error, count in sorted(summary["errors"].items(), key=lambda x: x[1], reverse=True)[:5]: + print(f" [{count} times] {error[:100]}...") + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/hotpot/hotpot_evaluate_v1.py b/evaluation/scripts/hotpot/hotpot_evaluate_v1.py new file mode 100644 index 000000000..19f09996d --- /dev/null +++ b/evaluation/scripts/hotpot/hotpot_evaluate_v1.py @@ -0,0 +1,156 @@ +import re +import string +import sys + +from collections import Counter + +import ujson as json + + +def normalize_answer(s): + def remove_articles(text): + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def f1_score(prediction, ground_truth): + normalized_prediction = normalize_answer(prediction) + normalized_ground_truth = normalize_answer(ground_truth) + + zero_metric = (0, 0, 0) + + if ( + normalized_prediction in ["yes", "no", "noanswer"] + and normalized_prediction != normalized_ground_truth + ): + return zero_metric + if ( + normalized_ground_truth in ["yes", "no", "noanswer"] + and normalized_prediction != normalized_ground_truth + ): + return zero_metric + + prediction_tokens = normalized_prediction.split() + ground_truth_tokens = normalized_ground_truth.split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return zero_metric + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1, precision, recall + + +def exact_match_score(prediction, ground_truth): + return normalize_answer(prediction) == normalize_answer(ground_truth) + + +def update_answer(metrics, prediction, gold): + em = exact_match_score(prediction, gold) + f1, prec, recall = f1_score(prediction, gold) + metrics["em"] += float(em) + metrics["f1"] += f1 + metrics["prec"] += prec + metrics["recall"] += recall + return em, prec, recall + + +def update_sp(metrics, prediction, gold): + cur_sp_pred = set(map(tuple, prediction)) + gold_sp_pred = set(map(tuple, gold)) + tp, fp, fn = 0, 0, 0 + for e in cur_sp_pred: + if e in gold_sp_pred: + tp += 1 + else: + fp += 1 + for e in gold_sp_pred: + if e not in cur_sp_pred: + fn += 1 + prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0 + recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0 + f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0 + em = 1.0 if fp + fn == 0 else 0.0 + metrics["sp_em"] += em + metrics["sp_f1"] += f1 + metrics["sp_prec"] += prec + metrics["sp_recall"] += recall + return em, prec, recall + + +def eval(prediction_file, gold_file): + with open(prediction_file) as f: + prediction = json.load(f) + with open(gold_file) as f: + gold = json.load(f) + + evaluated_ids = set((prediction.get("answer") or {}).keys()) + gold = [dp for dp in gold if (dp.get("_id") or dp.get("id")) in evaluated_ids] + + metrics = { + "em": 0, + "f1": 0, + "prec": 0, + "recall": 0, + "sp_em": 0, + "sp_f1": 0, + "sp_prec": 0, + "sp_recall": 0, + 
"joint_em": 0, + "joint_f1": 0, + "joint_prec": 0, + "joint_recall": 0, + } + for dp in gold: + cur_id = dp["_id"] + can_eval_joint = True + if cur_id not in prediction["answer"]: + can_eval_joint = False + else: + em, prec, recall = update_answer(metrics, prediction["answer"][cur_id], dp["answer"]) + if cur_id not in prediction["sp"]: + can_eval_joint = False + else: + sp_em, sp_prec, sp_recall = update_sp( + metrics, prediction["sp"][cur_id], dp["supporting_facts"] + ) + + if can_eval_joint: + joint_prec = prec * sp_prec + joint_recall = recall * sp_recall + if joint_prec + joint_recall > 0: + joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall) + else: + joint_f1 = 0.0 + joint_em = em * sp_em + + metrics["joint_em"] += joint_em + metrics["joint_f1"] += joint_f1 + metrics["joint_prec"] += joint_prec + metrics["joint_recall"] += joint_recall + + print("=========Eval Results===========") + n = len(gold) + if n > 0: + for k in metrics: + metrics[k] /= n + print(metrics) + else: + print(metrics) + return metrics + + +if __name__ == "__main__": + eval(sys.argv[1], sys.argv[2]) diff --git a/evaluation/scripts/hotpot/hotpot_ingestion.py b/evaluation/scripts/hotpot/hotpot_ingestion.py new file mode 100644 index 000000000..adbd3471c --- /dev/null +++ b/evaluation/scripts/hotpot/hotpot_ingestion.py @@ -0,0 +1,241 @@ +import argparse +import json +import os +import time + +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +from dotenv import load_dotenv +from tqdm import tqdm + +from evaluation.scripts.hotpot.data_loader import load_hotpot_data +from evaluation.scripts.utils.metrics import Metrics + + +load_dotenv() + + +def retry_operation(func, *args, retries=5, delay=2, **kwargs): + for attempt in range(retries): + try: + return func(*args, **kwargs) + except Exception as e: + if attempt < retries - 1: + func_name = getattr(func, "__name__", "Operation") + print(f"[Retry] {func_name} failed: {e}. 
Retrying in {delay}s...") + time.sleep(delay) + delay *= 2 + else: + raise e + + +def _get_lib_client(lib: str): + if lib == "mem0": + from evaluation.scripts.utils.client import Mem0Client + + return Mem0Client(enable_graph=False) + if lib == "supermemory": + from evaluation.scripts.utils.client import SupermemoryClient + + return SupermemoryClient() + from evaluation.scripts.utils.client import MemosApiClient + + return MemosApiClient() + + +def _load_added_ids(records_path: Path) -> set[str]: + if not records_path.exists(): + return set() + try: + obj = json.loads(records_path.read_text(encoding="utf-8")) + ids = obj.get("added_ids") if isinstance(obj, dict) else None + if isinstance(ids, list): + return {str(x) for x in ids if x} + except Exception: + return set() + return set() + + +def _save_added_ids(records_path: Path, added_ids: set[str], perf: dict | None = None) -> None: + records_path.parent.mkdir(parents=True, exist_ok=True) + tmp = records_path.with_suffix(records_path.suffix + ".tmp") + obj = {"added_ids": sorted(added_ids)} + if perf is not None: + obj["perf"] = perf + tmp.write_text(json.dumps(obj, ensure_ascii=False, indent=2), encoding="utf-8") + os.replace(tmp, records_path) + + +def _build_memory_texts(ctx: dict | list | None) -> list[str]: + tasks: list[str] = [] + for item in ctx: + if not isinstance(item, list | tuple) or len(item) != 2: + continue + title, sentences = item + if not isinstance(sentences, list): + continue + for idx, sentence in enumerate(sentences): + tasks.append( + json.dumps({"idx": idx, "title": title, "sentence": sentence}, ensure_ascii=False) + ) + return tasks + + +def add_context_memories( + client, + lib: str, + user_id: str, + ctx: dict | list | None, + mode: str = "fine", + async_mode: str = "sync", +) -> None: + tasks = _build_memory_texts(ctx) + if not tasks: + return + + if lib == "memos": + messages = [{"type": "text", "text": content} for content in tasks] + writable_cube_ids = [user_id] + retry_operation( + client.add, + messages=messages, + user_id=user_id, + writable_cube_ids=writable_cube_ids, + source_type="batch_import", + mode=mode, + async_mode=async_mode, + ) + return + + if lib == "mem0": + ts = int(time.time()) + messages = [{"role": "user", "content": content} for content in tasks] + retry_operation(client.add, messages=messages, user_id=user_id, timestamp=ts, batch_size=10) + return + + if lib == "supermemory": + for content in tasks: + retry_operation(client.add, content=content, user_id=user_id) + + +def ingest_one(client, lib: str, item: dict, version_dir: str) -> str: + qid = item.get("_id") or item.get("id") + ctx = item.get("context") + + user_id = version_dir + "_" + str(qid) + add_context_memories(client, lib, user_id, ctx) + return str(qid) + + +def main(argv: list[str] | None = None) -> None: + parser = argparse.ArgumentParser(description="HotpotQA ingestion (add only).") + parser.add_argument( + "--lib", + type=str, + default="memos", + choices=["memos", "mem0", "supermemory"], + ) + parser.add_argument("--workers", type=int, default=8) + parser.add_argument("--limit", type=int, default=None) + parser.add_argument("--version-dir", "-v", default=None, help="Version directory name") + parser.add_argument( + "--mode", default="fine", choices=["fine", "fast"], help="Processing mode (default: fine)" + ) + parser.add_argument( + "--async-mode", default="sync", choices=["sync", "async"], help="Async mode (default: sync)" + ) + + args = parser.parse_args(argv) + + print("=" * 60) + print("hotpotQA Product Add 
Concurrent Tool") + print("=" * 60) + + output_dir = Path("evaluation/data/hotpot") + if args.version_dir: + output_dir = output_dir / args.version_dir + output_dir.mkdir(parents=True, exist_ok=True) + + items_list = load_hotpot_data("evaluation/data/hotpot") + if args.limit is not None: + items_list = items_list[: args.limit] + + records_path = output_dir / f"{args.lib}_added_records.json" + added_ids = _load_added_ids(records_path) + pending_items = [] + for it in items_list: + qid = it.get("_id") or it.get("id") + if str(qid) not in added_ids: + pending_items.append(it) + + print(f"[Add] lib={args.lib} total={len(items_list)} pending={len(pending_items)}") + if not pending_items: + return + + client = _get_lib_client(args.lib) + metrics = Metrics() + + def do_ingest(item): + start_time = time.perf_counter() + try: + sid = ingest_one(client, args.lib, item, args.version_dir) + duration = time.perf_counter() - start_time + metrics.record(duration, True) + return sid + except Exception as e: + duration = time.perf_counter() - start_time + metrics.record(duration, False, str(e)) + raise e + + start_time = time.time() + with ThreadPoolExecutor(max_workers=args.workers) as executor: + futures = [executor.submit(do_ingest, it) for it in pending_items] + for idx, f in enumerate(tqdm(as_completed(futures), total=len(futures), desc="Adding"), 1): + try: + sid = f.result() + if sid: + added_ids.add(str(sid)) + if idx % 20 == 0: + _save_added_ids(records_path, added_ids) + except Exception as e: + print(f"[Add] Error: {e}") + + _save_added_ids(records_path, added_ids) + print(f"[Add] saved records to {records_path}") + + total_duration = time.time() - start_time + summary = metrics.summary() + perf_obj = { + "summary": summary, + "total_duration": total_duration, + "config": { + "workers": args.workers, + "mode": args.mode, + "async_mode": args.async_mode, + "lib": args.lib, + }, + } + _save_added_ids(records_path, added_ids, perf=perf_obj) + + print("\n" + "=" * 60) + print("Ingestion finished! 
Statistics:") + print("=" * 60) + print(f"Total duration: {total_duration:.2f}s") + print(f"Success: {summary['counts']['success']} / Failed: {summary['counts']['failed']}") + + if summary["stats"]: + stats = summary["stats"] + qps = stats["count"] / total_duration if total_duration > 0 else 0 + print(f"QPS: {qps:.2f}") + print("Latency stats (ms):") + print(f" Mean: {stats['mean']:.2f}") + print(f" Median: {stats['median']:.2f}") + print(f" Min: {stats['min']:.2f}") + print(f" Max: {stats['max']:.2f}") + print(f" P95: {stats['p95']:.2f}") + print(f" P99: {stats['p99']:.2f}") + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/hotpot/hotpot_old.py b/evaluation/scripts/hotpot/hotpot_old.py new file mode 100644 index 000000000..1e986431d --- /dev/null +++ b/evaluation/scripts/hotpot/hotpot_old.py @@ -0,0 +1,309 @@ +import importlib.util +import json +import os +import sys +import time +import traceback + +from concurrent.futures import ThreadPoolExecutor, as_completed + +from dotenv import load_dotenv +from openai import OpenAI +from tqdm import tqdm + + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from utils.client import MemosApiClient +from utils.prompts import LME_ANSWER_PROMPT, MEMOS_CONTEXT_TEMPLATE + +from evaluation.scripts.hotpot.data_loader import load_hotpot_data +from memos.reranker.strategies.dialogue_common import extract_texts_and_sp_from_sources + + +load_dotenv() +os.environ["SEARCH_MODE"] = os.environ.get("SEARCH_MODE", "fine") +client = MemosApiClient() +oai_client = OpenAI( + api_key=os.getenv("CHAT_MODEL_API_KEY"), base_url=os.getenv("CHAT_MODEL_BASE_URL") +) + +pred_answers = {} +pred_sp = {} +output_dir = "evaluation/data/hotpot/output" +os.makedirs(output_dir, exist_ok=True) +pred_path = os.path.join(output_dir, "dev_distractor_pred.json") +gold_path = "evaluation/data/hotpot/dev_distractor_gold.json" + + +def add_context_memories(user_id: str, ctx: dict | list | None): + tasks = [] + if isinstance(ctx, dict): + titles = ctx.get("title") or [] + sentences_list = ctx.get("sentences") or [] + for title, sentences in zip(titles, sentences_list, strict=False): + for idx, sentence in enumerate(sentences): + memory_content = f"{title}: {sentence} [#{idx}]" + tasks.append(memory_content) + elif isinstance(ctx, list): + for item in ctx: + if isinstance(item, list) and len(item) >= 2: + title = item[0] + sentences = item[1] + for idx, sentence in enumerate(sentences): + memory_content = f"{title}: {sentence} [#{idx}]" + tasks.append(memory_content) + + if not tasks: + return + + iso = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + messages = [{"role": "user", "content": content, "created_at": iso} for content in tasks] + client.add(messages=messages, user_id=user_id, conv_id=user_id) + + +def memos_search(user_id: str, query: str, top_k): + results = client.search(query=query, user_id=user_id, top_k=top_k) + memories = results["text_mem"][0]["memories"] + print("Search memories:", len(memories)) + + context = "\n".join([i["memory"] for i in memories]) + f"\n{results.get('pref_string', '')}" + context = MEMOS_CONTEXT_TEMPLATE.format(user_id=user_id, memories=context) + + # Extract supporting facts (sp) from raw sources + sp_list: list[list[str | int]] = [] + for m in memories: + sources = (m.get("metadata", {}) or {}).get("sources") or [] + texts, sps = extract_texts_and_sp_from_sources(sources) + for t, s in sps: + sp_list.append([t, s]) + + # De-duplicate while preserving order + seen = set() + dedup_sp = [] + for 
t, s in sp_list: + key = (t, s) + if key not in seen: + seen.add(key) + dedup_sp.append([t, s]) + + return context, dedup_sp + + +def llm_response(context: str, question: str, question_date: str | None = None) -> str: + prompt = LME_ANSWER_PROMPT.format( + question=question, question_date=question_date or "", context=context + ) + resp = oai_client.chat.completions.create( + model=os.getenv("CHAT_MODEL"), + messages=[{"role": "system", "content": prompt}], + temperature=0, + ) + return resp.choices[0].message.content or "" + + +def extract_answer(question: str, output: str, model_name: str | None = None) -> str: + try: + response = oai_client.chat.completions.create( + model=model_name or os.getenv("CHAT_MODEL"), + messages=[ + { + "role": "user", + "content": ( + "You are an answer extractor. Given a question and a verbose response, " + "return ONLY the concise final answer suitable for HotpotQA exact match.\n\n" + "Rules:\n" + "- If the question asks yes/no, answer strictly 'yes' or 'no'.\n" + "- Otherwise, output the shortest noun phrase/entity/date/number that answers the question.\n" + "- No explanations, no punctuation beyond what's necessary for the answer.\n\n" + f"Question: {question}\nVerbose response: {output}\nFinal answer:" + ), + } + ], + temperature=0.0, + max_tokens=64, + top_p=1, + ) + ans = (response.choices[0].message.content or "").strip() + return ans + except Exception: + text = (output or "").lower() + if " yes" in text or text.startswith("yes"): + return "yes" + if " no" in text or text.startswith("no"): + return "no" + for sep in ["\n", ". ", ".", "?", "!"]: + if sep in output: + cand = output.split(sep)[0].strip() + if cand: + return cand + return (output or "").strip() + + +def build_context_text(context_list): + parts = [] + for title, sentences in context_list: + text = " ".join(s.strip() for s in sentences if s.strip()) + parts.append(f"{title}: {text}") + return "\n".join(parts) + + +def ingest_context(item): + qid = item.get("_id") or item.get("id") + ctx = item.get("context") + add_context_memories(qid, ctx) + return qid + + +def search_and_ask(item): + qid = item.get("_id") or item.get("id") + question = item["question"] + try: + context, sp_list = memos_search(qid, question, top_k=7) + raw_answer = llm_response(context=context, question=question, question_date="") + answer = extract_answer(question, raw_answer) or "" + print("Question:", question) + print("Answer (raw):", raw_answer) + print("Answer (final):", answer) + pred_sp[qid] = sp_list + return qid, answer + except Exception as e: + print(f"[Question {qid}] Error:", e) + traceback.print_exc() + return qid, "" + + +def write_gold(data, out_path: str | None = None): + split = data.get("validation") + items_list = [split[i] for i in range(len(split))] + out = [] + for it in items_list: + qid = it.get("_id") or it.get("id") + sp = it.get("supporting_facts") + if isinstance(sp, dict): + titles = sp.get("title") or [] + sent_ids = sp.get("sent_id") or [] + sp_list = [[t, s] for t, s in zip(titles, sent_ids, strict=False)] + else: + sp_list = sp or [] + ctx = it.get("context") + if isinstance(ctx, dict): + titles = ctx.get("title") or [] + sentences = ctx.get("sentences") or [] + ctx_list = [[t, s] for t, s in zip(titles, sentences, strict=False)] + else: + ctx_list = ctx or [] + out.append( + { + "_id": qid, + "question": it.get("question"), + "answer": it.get("answer"), + "supporting_facts": sp_list, + "context": ctx_list, + } + ) + target_path = out_path or gold_path + tmp_path = target_path + ".tmp" + 
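    # Write to a temporary file, fsync, then os.replace the target, so an interrupted run never leaves a truncated gold file.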
try: + with open(tmp_path, "w", encoding="utf-8") as f: + json.dump(out, f, ensure_ascii=False) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp_path, target_path) + except Exception as e: + print("Failed to save gold:", e) + + +def run_eval(pred_file: str | None = None, gold_file: str | None = None): + spec = importlib.util.spec_from_file_location( + "hotpot_eval_v1", "evaluation/scripts/hotpot/hotpot_evaluate_v1.py" + ) + m = importlib.util.module_from_spec(spec) + spec.loader.exec_module(m) + m.eval(pred_file or pred_path, gold_file or gold_path) + + +def save_pred(): + tmp_path = pred_path + ".tmp" + try: + safe_pred_answers = { + k: (v if isinstance(v, str) else ("" if v is None else str(v))) + for k, v in pred_answers.items() + } + with open(tmp_path, "w", encoding="utf-8") as f: + json.dump({"answer": safe_pred_answers, "sp": pred_sp}, f, ensure_ascii=False, indent=2) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp_path, pred_path) + except Exception as e: + print("Failed to save:", e) + + +def main(): + interval = 10 + items_list = load_hotpot_data("evaluation/data/hotpot") + + if os.path.exists(pred_path): + try: + with open(pred_path, encoding="utf-8") as f: + prev = json.load(f) + if isinstance(prev, dict) and isinstance(prev.get("answer"), dict): + prev_ans = { + k: (v if isinstance(v, str) else ("" if v is None else str(v))) + for k, v in prev["answer"].items() + } + pred_answers.update(prev_ans) + if isinstance(prev, dict) and isinstance(prev.get("sp"), dict): + pred_sp.update(prev["sp"]) + except Exception as e: + print("Failed to read historical predictions:", e) + + processed = len(pred_answers) + print("Starting evaluation, total samples:", len(items_list)) + print("Existing predictions:", processed) + + pending_items = [] + for it in items_list: + qid = it.get("_id") or it.get("id") + if qid not in pred_answers: + pending_items.append(it) + + if pending_items: + print(f"[Step1: Ingest] start, items={len(pending_items)}") + with ThreadPoolExecutor(max_workers=8) as executor: + futures = { + executor.submit(ingest_context, item): idx for idx, item in enumerate(pending_items) + } + for future in tqdm(as_completed(futures), total=len(futures), desc="Ingest"): + try: + future.result() + except Exception as e: + print("Ingest thread failed:", e) + traceback.print_exc() + + print(f"[Step2: QA] start, items={len(pending_items)}") + with ThreadPoolExecutor(max_workers=8) as executor: + futures = { + executor.submit(search_and_ask, item): idx for idx, item in enumerate(pending_items) + } + for future in tqdm(as_completed(futures), total=len(futures), desc="QA"): + try: + qid, answer = future.result() + except Exception as e: + print("QA thread failed:", e) + traceback.print_exc() + continue + pred_answers[qid] = answer + processed += 1 + if processed % 10 == 0: + print("Completed:", processed, "Remaining:", len(items_list) - processed) + save_pred() + if processed % interval == 0: + print("Stage evaluation, current progress:", processed) + run_eval() + + run_eval() + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/hotpot/hotpot_search.py b/evaluation/scripts/hotpot/hotpot_search.py new file mode 100644 index 000000000..6542d4539 --- /dev/null +++ b/evaluation/scripts/hotpot/hotpot_search.py @@ -0,0 +1,297 @@ +import argparse +import json +import os +import time +import traceback + +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +from tqdm import tqdm + +from evaluation.scripts.hotpot.data_loader import 
load_hotpot_data +from evaluation.scripts.utils.metrics import Metrics + + +def retry_operation(func, *args, retries=5, delay=2, **kwargs): + for attempt in range(retries): + try: + result = func(*args, **kwargs) + if isinstance(result, dict) and "data" in result: + return result["data"] + return result + except Exception as e: + if attempt < retries - 1: + func_name = getattr(func, "__name__", "Operation") + print(f"[Retry] {func_name} failed: {e}. Retrying in {delay}s...") + time.sleep(delay) + delay *= 2 + else: + raise e + + +def _get_lib_client(lib: str): + if lib == "mem0": + from evaluation.scripts.utils.client import Mem0Client + + return Mem0Client(enable_graph=False) + if lib == "supermemory": + from evaluation.scripts.utils.client import SupermemoryClient + + return SupermemoryClient() + from evaluation.scripts.utils.client import MemosApiClient + + return MemosApiClient() + + +def _load_existing_results(output_path: Path) -> tuple[list[dict], set[str]]: + if not output_path.exists(): + return [], set() + try: + data = json.loads(output_path.read_text(encoding="utf-8")) + if isinstance(data, list): + ids = {str(r.get("_id")) for r in data if r.get("_id")} + return data, ids + if isinstance(data, dict) and isinstance(data.get("results"), list): + rows = data.get("results") or [] + ids = {str(r.get("_id")) for r in rows if r.get("_id")} + return rows, ids + except Exception: + return [], set() + return [], set() + + +def _save_json_list(path: Path, rows: list[dict]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.with_suffix(path.suffix + ".tmp") + tmp.write_text(json.dumps({"results": rows}, ensure_ascii=False, indent=2), encoding="utf-8") + os.replace(tmp, path) + + +def get_sources_info(sources): + seen = set() + dedup_sp = [] + mem_texts = [] + + for source in sources: + if isinstance(source, str): + try: + obj = json.loads(source) + except json.JSONDecodeError: + continue + + title = obj.get("title") + idx = obj.get("idx") + sentence = obj.get("sentence") + + if title is None or idx is None: + continue + + key = (title, idx) + if key not in seen: + seen.add(key) + dedup_sp.append([title, idx]) + mem_texts.append(sentence) + + return mem_texts, dedup_sp + + +def memos_search( + client, user_id: str, query: str, top_k: int, search_mode: str +) -> tuple[str, list[list[str | int]]]: + readable_cube_ids = [user_id] + results = retry_operation( + client.search, + query=query, + user_id=user_id, + readable_cube_ids=readable_cube_ids, + top_k=top_k, + mode=search_mode, + ) + memories = results["text_mem"][0]["memories"] + mem_texts = [i["memory"] for i in memories] + + sources = [] + for m in memories: + source = (m.get("metadata", {}) or {}).get("sources") or [] + for s in source: + source_txt = json.loads(s["content"]) + sources.append(json.loads(source_txt)["content"]) + sources.extend(source) + + _, dedup_sp = get_sources_info(sources) + return mem_texts, dedup_sp + + +def mem0_search(client, user_id: str, query: str, top_k: int) -> tuple[str, list[list[str | int]]]: + res = retry_operation(client.search, query, user_id, top_k) + sources = [m.get("memory", "") for m in res.get("results", []) if m.get("memory")] + mem_texts, dedup_sp = get_sources_info(sources) + return mem_texts, dedup_sp + + +def supermemory_search( + client, user_id: str, query: str, top_k: int +) -> tuple[str, list[list[str | int]]]: + sources = retry_operation(client.search, query, user_id, top_k) + mem_texts, dedup_sp = get_sources_info(sources) + return mem_texts, dedup_sp + + +def 
search_one( + client, lib: str, item: dict, top_k: int, version_dir: str, search_mode: str +) -> dict: + qid = item.get("_id") or item.get("id") + question = item.get("question") or "" + user_id = version_dir + "_" + str(qid) + + if lib == "memos": + memories, sp_list = memos_search(client, user_id, str(question), top_k, search_mode) + elif lib == "mem0": + memories, sp_list = mem0_search(client, user_id, str(question), top_k) + elif lib == "supermemory": + memories, sp_list = supermemory_search(client, user_id, str(question), top_k) + else: + memories, sp_list = [], [] + + return { + "_id": str(qid), + "question": question, + "answer": item.get("answer"), + "memories": memories, + "sp": sp_list, + } + + +def main(argv: list[str] | None = None) -> None: + parser = argparse.ArgumentParser(description="HotpotQA search (search only).") + parser.add_argument( + "--lib", + type=str, + default="memos", + choices=["memos", "mem0", "supermemory"], + ) + parser.add_argument("--workers", type=int, default=8) + parser.add_argument("--top-k", type=int, default=7) + parser.add_argument( + "--limit", type=int, default=None, help="Limit number of samples (was max_samples)" + ) + parser.add_argument("--version-dir", "-v", default=None, help="Version directory name") + parser.add_argument("--search-mode", default="fine", help="Search mode") + + args = parser.parse_args(argv) + + # Handle limit/max_samples compatibility + limit = args.limit if args.limit is not None else args.max_samples + + items_list = load_hotpot_data("evaluation/data/hotpot") + if limit is not None: + items_list = items_list[:limit] + + output_dir = Path(f"evaluation/data/hotpot/{args.version_dir}") + output_dir.mkdir(parents=True, exist_ok=True) + + if args.lib == "memos": + output_path = output_dir / f"{args.lib}_{args.search_mode}_search_results.json" + else: + output_path = output_dir / f"{args.lib}_search_results.json" + + output_path.parent.mkdir(parents=True, exist_ok=True) + + results, processed_ids = _load_existing_results(output_path) + pending_items = [] + for it in items_list: + qid = it.get("_id") or it.get("id") + if str(qid) not in processed_ids: + pending_items.append(it) + + print( + f"[Search] lib={args.lib} total={len(items_list)} pending={len(pending_items)} top_k={args.top_k}" + ) + if not pending_items: + return + + client = _get_lib_client(args.lib) + metrics = Metrics() + start_time = time.time() + + with ThreadPoolExecutor(max_workers=args.workers) as executor: + + def do_search(item): + st = time.perf_counter() + try: + r = search_one( + client, args.lib, item, args.top_k, args.version_dir, args.search_mode + ) + dur = time.perf_counter() - st + metrics.record(dur, True) + return r + except Exception as e: + dur = time.perf_counter() - st + metrics.record(dur, False, str(e)) + raise e + + futures = [executor.submit(do_search, it) for it in pending_items] + for idx, f in enumerate( + tqdm(as_completed(futures), total=len(futures), desc="Searching"), 1 + ): + try: + r = f.result() + results.append(r) + if idx % 20 == 0: + _save_json_list(output_path, results) + except Exception as e: + print(f"[Search] Error: {e}") + traceback.print_exc() + + _save_json_list(output_path, results) + print(f"[Search] saved {len(results)} rows to {output_path}") + + # Save performance metrics + total_duration = time.time() - start_time + summary = metrics.summary() + # Merge perf into results json file + combined_obj = { + "results": results, + "perf": { + "summary": summary, + "total_duration": total_duration, + "config": { + 
"workers": args.workers, + "top_k": args.top_k, + "limit": limit, + "search_mode": args.search_mode, + "lib": args.lib, + }, + }, + } + tmp = output_path.with_suffix(output_path.suffix + ".tmp") + tmp.write_text(json.dumps(combined_obj, ensure_ascii=False, indent=2), encoding="utf-8") + os.replace(tmp, output_path) + + print("\n" + "=" * 60) + print("Search finished! Statistics:") + print("=" * 60) + print(f"Total duration: {total_duration:.2f}s") + print(f"Success: {summary['counts']['success']} / Failed: {summary['counts']['failed']}") + + if summary["stats"]: + stats = summary["stats"] + qps = stats["count"] / total_duration if total_duration > 0 else 0 + print(f"QPS: {qps:.2f}") + print("Latency stats (ms):") + print(f" Mean: {stats['mean']:.2f}") + print(f" Median: {stats['median']:.2f}") + print(f" Min: {stats['min']:.2f}") + print(f" Max: {stats['max']:.2f}") + print(f" P95: {stats['p95']:.2f}") + print(f" P99: {stats['p99']:.2f}") + + if summary["errors"]: + print("\nError stats:") + for error, count in sorted(summary["errors"].items(), key=lambda x: x[1], reverse=True)[:5]: + print(f" [{count} times] {error[:100]}...") + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/long_bench-v2/run_longbench_v2_eval.sh b/evaluation/scripts/long_bench-v2/run_longbench_v2_eval.sh new file mode 100755 index 000000000..917c57bfb --- /dev/null +++ b/evaluation/scripts/long_bench-v2/run_longbench_v2_eval.sh @@ -0,0 +1,110 @@ +#!/bin/bash + +# Common parameters for all scripts +LIB="memos-api" +VERSION="long-bench-v2-1208-1556-async" +WORKERS=10 +TOPK=20 +MAX_SAMPLES="" # Empty means all samples +WAIT_INTERVAL=2 # seconds between polls +WAIT_TIMEOUT=900 # seconds per user + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --lib) + LIB="$2" + shift 2 + ;; + --version) + VERSION="$2" + shift 2 + ;; + --workers) + WORKERS="$2" + shift 2 + ;; + --top_k) + TOPK="$2" + shift 2 + ;; + --max_samples) + MAX_SAMPLES="$2" + shift 2 + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + +# Build max_samples argument +MAX_SAMPLES_ARG="" +if [ -n "$MAX_SAMPLES" ]; then + MAX_SAMPLES_ARG="--max_samples $MAX_SAMPLES" +fi + +echo "Running LongBench v2 evaluation with:" +echo " LIB: $LIB" +echo " VERSION: $VERSION" +echo " WORKERS: $WORKERS" +echo " TOPK: $TOPK" +echo " MAX_SAMPLES: ${MAX_SAMPLES:-all}" +echo "" + +# Step 2: Search +echo "" +echo "==========================================" +echo "Step 2: Running longbench_v2_search.py..." +echo "==========================================" +python scripts/long_bench-v2/longbench_v2_search.py \ + --lib $LIB \ + --version $VERSION \ + --top_k $TOPK \ + --workers $WORKERS \ + $MAX_SAMPLES_ARG + +if [ $? -ne 0 ]; then + echo "Error running longbench_v2_search.py" + exit 1 +fi + +# Step 3: Response Generation +echo "" +echo "==========================================" +echo "Step 3: Running longbench_v2_responses.py..." +echo "==========================================" +python scripts/long_bench-v2/longbench_v2_responses.py \ + --lib $LIB \ + --version $VERSION \ + --workers $WORKERS + +if [ $? -ne 0 ]; then + echo "Error running longbench_v2_responses.py" + exit 1 +fi + +# Step 4: Metrics Calculation +echo "" +echo "==========================================" +echo "Step 4: Running longbench_v2_metric.py..." +echo "==========================================" +python scripts/long_bench-v2/longbench_v2_metric.py \ + --lib $LIB \ + --version $VERSION + +if [ $? 
-ne 0 ]; then + echo "Error running longbench_v2_metric.py" + exit 1 +fi + +echo "" +echo "==========================================" +echo "All steps completed successfully!" +echo "==========================================" +echo "" +echo "Results are saved in: results/long_bench-v2/$LIB-$VERSION/" +echo " - Search results: ${LIB}_longbench_v2_search_results.json" +echo " - Responses: ${LIB}_longbench_v2_responses.json" +echo " - Metrics: ${LIB}_longbench_v2_metrics.json" diff --git a/evaluation/scripts/longbench_v2/longbench_v2_eval.py b/evaluation/scripts/longbench_v2/longbench_v2_eval.py new file mode 100644 index 000000000..ae88c34f9 --- /dev/null +++ b/evaluation/scripts/longbench_v2/longbench_v2_eval.py @@ -0,0 +1,241 @@ +import argparse +import json +import os +import re +import time +import traceback + +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +from dotenv import load_dotenv +from openai import OpenAI +from tqdm import tqdm + +from evaluation.scripts.utils.prompts import LONGBENCH_V2_ANSWER_PROMPT + + +load_dotenv() + + +def retry_operation(func, *args, retries=5, delay=2, **kwargs): + for attempt in range(retries): + try: + return func(*args, **kwargs) + except Exception as e: + traceback.print_exc() + if attempt < retries - 1: + func_name = getattr(func, "__name__", "Operation") + print(f"[Retry] {func_name} failed: {e}. Retrying in {delay}s...") + time.sleep(delay) + delay *= 2 + else: + raise e + + +def extract_answer(response: str) -> str | None: + response = response.replace("*", "") + match = re.search(r"The correct answer is \(([A-D])\)", response) + if match: + return match.group(1) + match = re.search(r"The correct answer is ([A-D])", response) + if match: + return match.group(1) + return None + + +def llm_answer( + oai_client, model_name, memories: list[str], question: str, choices: dict +) -> tuple[str, int]: + doc_content = "\n\n".join([f"Retrieved chunk {idx + 1}: {m}" for idx, m in enumerate(memories)]) + prompt = ( + LONGBENCH_V2_ANSWER_PROMPT.replace("$DOC$", doc_content) + .replace("$Q$", question) + .replace("$C_A$", choices.get("A", "")) + .replace("$C_B$", choices.get("B", "")) + .replace("$C_C$", choices.get("C", "")) + .replace("$C_D$", choices.get("D", "")) + ) + messages = [{"role": "user", "content": prompt}] + resp = retry_operation( + oai_client.chat.completions.create, + model=model_name, + messages=messages, + temperature=0.1, + max_tokens=12800, + ) + return resp.choices[0].message.content or "", resp.usage.prompt_tokens + + +def print_metrics(results: list[dict], duration: float) -> None: + easy, hard, short, medium, long = 0, 0, 0, 0, 0 + easy_acc, hard_acc, short_acc, medium_acc, long_acc = 0, 0, 0, 0, 0 + total_tokens = 0 + + for pred in results: + acc = int(pred.get("judge", False)) + diff = pred.get("difficulty", "easy") + length = pred.get("length", "short") + tokens = pred.get("prompt_tokens", 0) + total_tokens += tokens + + if diff == "easy": + easy += 1 + easy_acc += acc + else: + hard += 1 + hard_acc += acc + + if length == "short": + short += 1 + short_acc += acc + elif length == "medium": + medium += 1 + medium_acc += acc + else: + long += 1 + long_acc += acc + + total = len(results) + if total == 0: + print("No results to calculate metrics.") + return + + o_acc = round(100 * (easy_acc + hard_acc) / total, 2) + e_acc = round(100 * easy_acc / easy, 2) if easy > 0 else 0 + h_acc = round(100 * hard_acc / hard, 2) if hard > 0 else 0 + s_acc = round(100 * short_acc / short, 2) if short > 0 
else 0 + m_acc = round(100 * medium_acc / medium, 2) if medium > 0 else 0 + l_acc = round(100 * long_acc / long, 2) if long > 0 else 0 + avg_tokens = round(total_tokens / total, 2) + + print("\n" + "=" * 60) + print(f"{'Metric':<15} | {'Count':<10} | {'Accuracy (%)':<10}") + print("-" * 60) + print(f"{'Overall':<15} | {total:<10} | {o_acc:<10}") + print(f"{'Easy':<15} | {easy:<10} | {e_acc:<10}") + print(f"{'Hard':<15} | {hard:<10} | {h_acc:<10}") + print(f"{'Short':<15} | {short:<10} | {s_acc:<10}") + print(f"{'Medium':<15} | {medium:<10} | {m_acc:<10}") + print(f"{'Long':<15} | {long:<10} | {l_acc:<10}") + print("-" * 60) + print(f"{'Avg Tokens':<15} | {total:<10} | {avg_tokens:<10}") + print(f"Total Duration: {duration:.2f} seconds") + print("=" * 60 + "\n") + + +def _load_json_list(path: Path) -> list[dict]: + data = json.loads(path.read_text(encoding="utf-8")) + if not isinstance(data, list): + raise ValueError(f"Invalid json format: {path}") + return data + + +def _save_json_list(path: Path, rows: list[dict]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.with_suffix(path.suffix + ".tmp") + tmp.write_text(json.dumps(rows, ensure_ascii=False, indent=2), encoding="utf-8") + os.replace(tmp, path) + + +def evaluate_one(oai_client, model_name, row: dict) -> dict: + question = row.get("question") or "" + choices = row.get("choices") or {} + memories = row.get("memories_used") or [] + response, prompt_tokens = llm_answer( + oai_client, model_name, list(memories), str(question), dict(choices) + ) + pred = extract_answer(response) + judge = pred == row.get("answer") + out = dict(row) + out["response"] = response + out["pred"] = pred + out["judge"] = judge + out["prompt_tokens"] = prompt_tokens + out.pop("memories_used") + return out + + +def main() -> None: + parser = argparse.ArgumentParser(description="LongBench-v2 eval Tool") + parser.add_argument( + "--lib", + "-b", + required=True, + choices=["memos", "mem0", "supermemory"], + help="Product name to evaluate", + ) + parser.add_argument("--workers", "-w", type=int, default=20, help="Number of parallel threads") + parser.add_argument( + "--top-k", "-k", type=int, default=20, help="Top k results to use (default: 20)" + ) + parser.add_argument("--version-dir", "-v", default=None, help="Version directory name") + parser.add_argument("--search_results_path", type=str, default=None) + parser.add_argument("--output_path", type=str, default=None) + parser.add_argument("--chat-model", type=str, default=None, help="Chat model for evaluation") + args = parser.parse_args() + + print("=" * 60) + print("LongBench-v2 Product Eval Tool") + print("=" * 60) + + start_time = time.time() + + output_dir = os.path.join("evaluation/data/longbench_v2", args.version_dir) + search_filename = f"{args.lib}_search_results.json" + search_path = Path(os.path.join(output_dir, search_filename)) + + if not search_path.exists(): + raise FileNotFoundError(f"Search results not found: {search_path}") + + search_rows = _load_json_list(search_path) + output_filename = f"{args.lib}_eval_results.json" + output_path = Path(os.path.join(output_dir, output_filename)) + + results: list[dict] = [] + processed_ids: set[str] = set() + + # Resume from checkpoint + if output_path.exists(): + try: + existing = _load_json_list(output_path) + results = existing + processed_ids = {str(r.get("_id")) for r in results if r.get("_id")} + print(f"Loaded {len(results)} existing results from checkpoint.") + except Exception as e: + print(f"Error loading checkpoint: {e}") + + 
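    # Only rows whose _id is absent from the loaded checkpoint are evaluated on this run; reruns resume where they left off.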
pending = [r for r in search_rows if str(r.get("_id")) not in processed_ids] + print(f"[Eval] total={len(search_rows)} pending={len(pending)} workers={args.workers}") + if not pending: + print_metrics(results, time.time() - start_time) + return + + print("[Response model]: ", args.chat_model) + oai_client = OpenAI( + api_key=os.getenv("CHAT_MODEL_API_KEY"), base_url=os.getenv("CHAT_MODEL_BASE_URL") + ) + + with ThreadPoolExecutor(max_workers=args.workers) as executor: + futures = [ + executor.submit(evaluate_one, oai_client, args.chat_model, row) for row in pending + ] + for idx, f in enumerate( + tqdm(as_completed(futures), total=len(futures), desc="Evaluating"), start=1 + ): + try: + res = f.result() + results.append(res) + if idx % 10 == 0: + _save_json_list(output_path, results) + except Exception as e: + print(f"Evaluation Error: {e}") + traceback.print_exc() + + _save_json_list(output_path, results) + print(f"Saved {len(results)} results to {output_path}") + print_metrics(results, time.time() - start_time) + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/longbench_v2/longbench_v2_ingestion.py b/evaluation/scripts/longbench_v2/longbench_v2_ingestion.py new file mode 100644 index 000000000..ec5257717 --- /dev/null +++ b/evaluation/scripts/longbench_v2/longbench_v2_ingestion.py @@ -0,0 +1,284 @@ +import argparse +import json +import os +import time +import traceback + +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +from datasets import load_dataset +from dotenv import load_dotenv +from langchain_text_splitters import Language, RecursiveCharacterTextSplitter +from tqdm import tqdm + +from evaluation.scripts.utils.metrics import Metrics + + +load_dotenv() + + +def retry_operation(func, *args, retries=5, delay=2, **kwargs): + for attempt in range(retries): + try: + return func(*args, **kwargs) + except Exception as e: + if attempt < retries - 1: + func_name = getattr(func, "__name__", "Operation") + print(f"[Retry] {func_name} failed: {e}. 
Retrying in {delay}s...") + time.sleep(delay) + delay *= 2 + else: + raise e + + +def _get_lib_client(lib: str): + if lib == "mem0": + from evaluation.scripts.utils.client import Mem0Client + + return Mem0Client(enable_graph=False) + if lib == "supermemory": + from evaluation.scripts.utils.client import SupermemoryClient + + return SupermemoryClient() + from evaluation.scripts.utils.client import MemosApiClient + + return MemosApiClient() + + +def _load_dataset_jsonl(dataset_path: Path) -> list[dict]: + if not dataset_path.exists(): + dataset = load_dataset("zai-org/LongBench-v2", split="train") + dataset_path.parent.mkdir(parents=True, exist_ok=True) + with open(dataset_path, "w", encoding="utf-8") as f: + for i in range(len(dataset)): + s = dataset[i] + row = { + "_id": s.get("_id") or s.get("id") or str(i), + "domain": s.get("domain"), + "sub_domain": s.get("sub_domain"), + "difficulty": s.get("difficulty"), + "length": s.get("length"), + "question": s.get("question"), + "choice_A": s.get("choice_A"), + "choice_B": s.get("choice_B"), + "choice_C": s.get("choice_C"), + "choice_D": s.get("choice_D"), + "answer": s.get("answer"), + "context": s.get("context") or s.get("document") or s.get("documents"), + } + f.write(json.dumps(row, ensure_ascii=False) + "\n") + print(f"Successfully saved dataset to {dataset_path}") + + samples: list[dict] = [] + with open(dataset_path, encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + samples.append(json.loads(line)) + return samples + + +def _load_added_ids(records_path: Path) -> set[str]: + if not records_path.exists(): + return set() + try: + obj = json.loads(records_path.read_text(encoding="utf-8")) + ids = obj.get("added_ids") if isinstance(obj, dict) else None + if isinstance(ids, list): + return {str(x) for x in ids if x} + except Exception: + return set() + return set() + + +def _save_added_ids(records_path: Path, added_ids: set[str]) -> None: + records_path.parent.mkdir(parents=True, exist_ok=True) + tmp = records_path.with_suffix(records_path.suffix + ".tmp") + tmp.write_text( + json.dumps({"added_ids": sorted(added_ids)}, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + os.replace(tmp, records_path) + + +def ingest_context( + client, sample: dict, lib: str, mode: str = "fine", async_mode: str = "sync" +) -> str: + sample_id = str(sample.get("_id")) + user_id = sample_id + context = sample.get("context") or "" + chunker = RecursiveCharacterTextSplitter.from_language( + language=Language.PYTHON, chunk_size=5120, chunk_overlap=128 + ) + chunks = [p for p in chunker.split_text(context or "") if p.strip()] + + if lib == "memos": + messages = [{"type": "text", "text": p} for p in chunks] + writable_cube_ids = [user_id] + retry_operation( + client.add, + messages=messages, + user_id=user_id, + writable_cube_ids=writable_cube_ids, + source_type="batch_import", + mode=mode, + async_mode=async_mode, + ) + return + + if lib == "mem0": + messages = [{"role": "user", "content": p} for p in chunks] + ts = int(time.time()) + retry_operation(client.add, messages=messages, user_id=user_id, timestamp=ts, batch_size=10) + return + + if lib == "supermemory": + retry_operation(client.add, content=context, user_id=user_id) + + return sample_id + + +def parse_args(): + parser = argparse.ArgumentParser( + description="LongBench-v2 Product Add Concurrent Script", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--lib", "-b", required=True, help="Product name to evaluate") + + 
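    # Base URL for the MemOS API backend.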
parser.add_argument( + "--api-url", + default="http://127.0.0.1:8001", + help="MemOS API URL (default: http://127.0.0.1:8001)", + ) + + parser.add_argument("--workers", "-w", type=int, default=5, help="Concurrency (default: 10)") + + parser.add_argument( + "--timeout", type=float, default=1200, help="Request timeout in seconds (default: 120)" + ) + + parser.add_argument( + "--mode", default="fine", choices=["fine", "fast"], help="Processing mode (default: fine)" + ) + + parser.add_argument( + "--async-mode", default="sync", choices=["sync", "async"], help="Async mode (default: sync)" + ) + + parser.add_argument("--version-dir", "-v", default=None, help="Version directory name") + + parser.add_argument( + "--dataset_path", + "-p", + default="evaluation/data/longbench_v2/longbenchv2_train.json", + help="Dataset path", + ) + + return parser.parse_args() + + +def main() -> None: + args = parse_args() + print("=" * 60) + print("LongBench-v2 Product Add Concurrent Tool") + print("=" * 60) + + dataset_path = Path(args.dataset_path) + dataset = _load_dataset_jsonl(dataset_path) + + version_output_dir = os.path.join("evaluation/data/longbench_v2", args.version_dir) + os.makedirs(version_output_dir, exist_ok=True) + output_path = os.path.join(version_output_dir, f"{args.lib}_add_results.json") + output_path = Path(output_path) + + added_ids = _load_added_ids(output_path) + pending = [s for s in dataset if str(s.get("_id")) not in added_ids] + print( + f"[Add] lib={args.lib} total={len(dataset)} pending={len(pending)} workers={args.workers}" + ) + if not pending: + return + + client = _get_lib_client(args.lib) + metrics = Metrics() + + def do_ingest(sample): + start_time = time.perf_counter() + try: + sid = ingest_context(client, sample, args.lib, args.mode, args.async_mode) + duration = time.perf_counter() - start_time + metrics.record(duration, True) + return sid + except Exception as e: + duration = time.perf_counter() - start_time + metrics.record(duration, False, str(e)) + raise e + + start_time = time.time() + with ThreadPoolExecutor(max_workers=args.workers) as executor: + futures = [executor.submit(do_ingest, sample) for sample in pending] + for f in tqdm(as_completed(futures), total=len(futures), desc="Adding"): + try: + sid = f.result() + if sid: + added_ids.add(str(sid)) + if len(added_ids) % 10 == 0: + _save_added_ids(output_path, added_ids) + except Exception as e: + print(f"[Add] Error: {e}") + traceback.print_exc() + + _save_added_ids(output_path, added_ids) + print(f"[Add] saved records to {output_path}") + + total_duration = time.time() - start_time + perf_out = Path(version_output_dir) / f"{args.lib}_add_perf.json" + + summary = metrics.summary() + + with open(perf_out, "w", encoding="utf-8") as f: + json.dump( + { + "summary": summary, + "total_duration": total_duration, + "config": { + "workers": args.workers, + "mode": args.mode, + "async_mode": args.async_mode, + "dataset_path": args.dataset_path, + }, + }, + f, + ensure_ascii=False, + indent=2, + ) + print(f"[Add] saved performance metrics to {perf_out}") + + print("\n" + "=" * 60) + print("Ingestion finished! 
Statistics:") + print("=" * 60) + print(f"Total duration: {total_duration:.2f}s") + print(f"Success: {summary['counts']['success']} / Failed: {summary['counts']['failed']}") + + if summary["stats"]: + stats = summary["stats"] + qps = stats["count"] / total_duration if total_duration > 0 else 0 + print(f"QPS: {qps:.2f}") + print("Latency stats (ms):") + print(f" Mean: {stats['mean']:.2f}") + print(f" Median: {stats['median']:.2f}") + print(f" Min: {stats['min']:.2f}") + print(f" Max: {stats['max']:.2f}") + print(f" P95: {stats['p95']:.2f}") + print(f" P99: {stats['p99']:.2f}") + + if summary["errors"]: + print("\nError stats:") + for error, count in sorted(summary["errors"].items(), key=lambda x: x[1], reverse=True)[:5]: + print(f" [{count} times] {error[:100]}...") + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/longbench_v2/longbench_v2_old.py b/evaluation/scripts/longbench_v2/longbench_v2_old.py new file mode 100644 index 000000000..5c6eba616 --- /dev/null +++ b/evaluation/scripts/longbench_v2/longbench_v2_old.py @@ -0,0 +1,415 @@ +import argparse +import json +import os +import re +import sys +import time +import traceback + +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime +from pathlib import Path + +from dotenv import load_dotenv +from langchain_text_splitters import Language, RecursiveCharacterTextSplitter +from openai import OpenAI +from tqdm import tqdm + + +TEMPLATE_RAG = """Please read the following retrieved text chunks and answer the question below. + + +$DOC$ + + +What is the correct answer to this question: $Q$ +Choices: +(A) $C_A$ +(B) $C_B$ +(C) $C_C$ +(D) $C_D$ + +Format your response as follows: "The correct answer is (insert answer here)".""" + + +PROJECT_ROOT = Path(__file__).resolve().parents[3] +SCRIPTS_ROOT = Path(__file__).resolve().parents[1] +SRC_ROOT = PROJECT_ROOT / "src" +sys.path.append(str(SCRIPTS_ROOT)) +sys.path.append(str(SRC_ROOT)) +load_dotenv() + + +def retry_operation(func, *args, retries=5, delay=2, **kwargs): + for attempt in range(retries): + try: + return func(*args, **kwargs) + except Exception as e: + traceback.print_exc() + if attempt < retries - 1: + func_name = getattr(func, "__name__", "Operation") + print(f"[Retry] {func_name} failed: {e}. 
Retrying in {delay}s...") + time.sleep(delay) + delay *= 2 + else: + raise e + + +def _get_lib_client(lib: str): + if lib == "mem0": + from utils.client import Mem0Client # type: ignore + + return Mem0Client(enable_graph=False) + if lib == "supermemory": + from utils.client import SupermemoryClient # type: ignore + + return SupermemoryClient() + from utils.client import MemosApiClient # type: ignore + + return MemosApiClient() + + +def _get_clients(lib: str = "memos"): + client = _get_lib_client(lib) + openai_client = OpenAI( + api_key=os.getenv("CHAT_MODEL_API_KEY"), base_url=os.getenv("CHAT_MODEL_BASE_URL") + ) + return client, openai_client + + +def _dump_dataset_to_local(): + from datasets import load_dataset + + dataset = load_dataset("zai-org/LongBench-v2", split="train") + out_dir = Path("evaluation/data/longbenchV2") + out_dir.mkdir(parents=True, exist_ok=True) + out_path = out_dir / "longbenchv2_train.json" + with open(out_path, "w", encoding="utf-8") as f: + for i in range(len(dataset)): + s = dataset[i] + row = { + "_id": s.get("_id") or s.get("id") or str(i), + "domain": s.get("domain"), + "sub_domain": s.get("sub_domain"), + "difficulty": s.get("difficulty"), + "length": s.get("length"), + "question": s.get("question"), + "choice_A": s.get("choice_A"), + "choice_B": s.get("choice_B"), + "choice_C": s.get("choice_C"), + "choice_D": s.get("choice_D"), + "answer": s.get("answer"), + "context": s.get("context") or s.get("document") or s.get("documents"), + } + f.write(json.dumps(row, ensure_ascii=False) + "\n") + print(f"Saved dataset to {out_path}") + return dataset + + +def add_context(client, user_id: str, context: str, lib: str) -> None: + iso = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + chunker = RecursiveCharacterTextSplitter.from_language( + language=Language.PYTHON, chunk_size=5120, chunk_overlap=128 + ) + paragraphs = [p for p in chunker.split_text(context or "") if p.strip()] + + if lib == "memos": + messages = [{"role": "user", "content": p, "created_at": iso} for p in paragraphs] + try: + retry_operation(client.add, messages=messages, user_id=user_id, conv_id=user_id) + print(f"[Add-memos]: successfully added {len(messages)} chunks to user {user_id}") + except Exception as e: + print(f"[Add-memos] failed: {e}") + + elif lib == "mem0": + messages = [{"role": "user", "content": p} for p in paragraphs] + ts = int(time.time()) + try: + retry_operation( + client.add, messages=messages, user_id=user_id, timestamp=ts, batch_size=10 + ) + print(f"[Add-mem0] user={user_id} total={len(messages)}") + except Exception as e: + print(f"[Add-mem0] failed: {e}") + + elif lib == "supermemory": + iso = datetime.utcnow().isoformat() + "Z" + content = "\n".join([f"{iso} user: {p}" for p in paragraphs]) + try: + retry_operation(client.add, content=content, user_id=user_id) + print(f"[Add-supermemory] user={user_id} total_chars={len(content)}") + except Exception as e: + print(f"[Add-supermemory] failed: {e}") + + +def memos_search(client, user_id: str, query: str, top_k: int = 30) -> list[str]: + results = retry_operation(client.search, query=query, user_id=user_id, top_k=top_k) + memories = results["text_mem"][0]["memories"] + mem_texts = [m["memory"] for m in memories] + print(f"[Search-memos] user={user_id} top_k={top_k} memories={len(memories)}") + return mem_texts + + +def mem0_search(client, user_id: str, query: str, top_k: int = 30) -> list[str]: + res = retry_operation(client.search, query, user_id, top_k) + results = res.get("results", []) + mem_texts = [m.get("memory", 
"") for m in results if m.get("memory")] + print(f"[Search-mem0] user={user_id} top_k={top_k} memories={len(mem_texts)}") + return mem_texts + + +def supermemory_search(client, user_id: str, query: str, top_k: int = 30) -> list[str]: + chunk_list = retry_operation(client.search, query, user_id, top_k) + print(f"[Search-supermemory] user={user_id} top_k={top_k} memories={len(chunk_list)}") + return chunk_list + + +def extract_answer(response: str) -> str | None: + response = response.replace("*", "") + match = re.search(r"The correct answer is \(([A-D])\)", response) + if match: + return match.group(1) + else: + match = re.search(r"The correct answer is ([A-D])", response) + if match: + return match.group(1) + else: + return None + + +def llm_answer(oai_client, memories: list[str], question: str, choices: dict) -> tuple[str, int]: + # Join memories to form the retrieved context document + doc_content = "\n\n".join([f"Retrieved chunk {idx + 1}: {m}" for idx, m in enumerate(memories)]) + + prompt = ( + TEMPLATE_RAG.replace("$DOC$", doc_content) + .replace("$Q$", question) + .replace("$C_A$", choices.get("A", "")) + .replace("$C_B$", choices.get("B", "")) + .replace("$C_C$", choices.get("C", "")) + .replace("$C_D$", choices.get("D", "")) + ) + + messages = [ + {"role": "user", "content": prompt}, + ] + resp = retry_operation( + oai_client.chat.completions.create, + model=os.getenv("CHAT_MODEL"), + messages=messages, + temperature=0.1, + max_tokens=12800, + ) + return resp.choices[0].message.content or "", resp.usage.prompt_tokens + + +def ingest_sample(client, sample: dict, lib: str) -> None: + sample_id = str(sample.get("_id")) + user_id = sample_id + context = sample.get("context") or "" + add_context(client, user_id, str(context), lib) + + +def evaluate_sample(client, oai_client, sample: dict, top_k: int, lib: str) -> dict: + sample_id = str(sample.get("_id")) + user_id = sample_id + question = sample.get("question") or "" + choices = { + "A": sample.get("choice_A") or "", + "B": sample.get("choice_B") or "", + "C": sample.get("choice_C") or "", + "D": sample.get("choice_D") or "", + } + + if lib == "memos": + memories = memos_search(client, user_id, str(question), top_k=top_k) + elif lib == "mem0": + memories = mem0_search(client, user_id, str(question), top_k=top_k) + elif lib == "supermemory": + memories = supermemory_search(client, user_id, str(question), top_k=top_k) + else: + memories = [] + + response, prompt_tokens = llm_answer(oai_client, memories, str(question), choices) + pred = extract_answer(response) + judge = pred == sample.get("answer") + print("[Question]:", question) + print("[Choices]:", choices) + print("[Raw response]:", response) + print("[Answer]:", pred) + print("[Ground truth]:", sample.get("answer")) + print("[Prompt Tokens]:", prompt_tokens) + + out = { + "_id": sample_id, + "domain": sample.get("domain"), + "sub_domain": sample.get("sub_domain"), + "difficulty": sample.get("difficulty"), + "length": sample.get("length"), + "question": question, + "choice_A": choices["A"], + "choice_B": choices["B"], + "choice_C": choices["C"], + "choice_D": choices["D"], + "answer": sample.get("answer"), + "memories_used": memories, + "response": response, + "pred": pred, + "judge": judge, + "prompt_tokens": prompt_tokens, + } + return out + + +def print_metrics(results: list[dict], duration: float) -> None: + easy, hard, short, medium, long = 0, 0, 0, 0, 0 + easy_acc, hard_acc, short_acc, medium_acc, long_acc = 0, 0, 0, 0, 0 + total_tokens = 0 + + for pred in results: + acc = 
int(pred.get("judge", False)) + diff = pred.get("difficulty", "easy") + length = pred.get("length", "short") + tokens = pred.get("prompt_tokens", 0) + total_tokens += tokens + + if diff == "easy": + easy += 1 + easy_acc += acc + else: + hard += 1 + hard_acc += acc + + if length == "short": + short += 1 + short_acc += acc + elif length == "medium": + medium += 1 + medium_acc += acc + else: + long += 1 + long_acc += acc + + total = len(results) + if total == 0: + print("No results to calculate metrics.") + return + + o_acc = round(100 * (easy_acc + hard_acc) / total, 2) + e_acc = round(100 * easy_acc / easy, 2) if easy > 0 else 0 + h_acc = round(100 * hard_acc / hard, 2) if hard > 0 else 0 + s_acc = round(100 * short_acc / short, 2) if short > 0 else 0 + m_acc = round(100 * medium_acc / medium, 2) if medium > 0 else 0 + l_acc = round(100 * long_acc / long, 2) if long > 0 else 0 + avg_tokens = round(total_tokens / total, 2) + + print("\n" + "=" * 60) + print(f"{'Metric':<15} | {'Count':<10} | {'Accuracy (%)':<10}") + print("-" * 60) + print(f"{'Overall':<15} | {total:<10} | {o_acc:<10}") + print(f"{'Easy':<15} | {easy:<10} | {e_acc:<10}") + print(f"{'Hard':<15} | {hard:<10} | {h_acc:<10}") + print(f"{'Short':<15} | {short:<10} | {s_acc:<10}") + print(f"{'Medium':<15} | {medium:<10} | {m_acc:<10}") + print(f"{'Long':<15} | {long:<10} | {l_acc:<10}") + print("-" * 60) + print(f"{'Avg Tokens':<15} | {total:<10} | {avg_tokens:<10}") + print(f"Total Duration: {duration:.2f} seconds") + print("=" * 60 + "\n") + + +def main(): + parser = argparse.ArgumentParser(description="Evaluate LongBench-v2 with different backends.") + parser.add_argument( + "--lib", + type=str, + default="memos", + choices=["memos", "mem0", "supermemory"], + help="Backend library to use (default: memos)", + ) + args = parser.parse_args() + + start_time = time.time() + print("[Response model]: ", os.getenv("CHAT_MODEL")) + + client, oai_client = _get_clients(lib=args.lib) + dataset = _dump_dataset_to_local() + results: list[dict] = [] + os.makedirs("evaluation/data/longbenchV2", exist_ok=True) + out_json = Path(f"evaluation/data/longbenchV2/test/{args.lib}_cot_results.json") + + # Checkpoint loading + processed_ids = set() + if out_json.exists(): + try: + with open(out_json, encoding="utf-8") as f: + existing_results = json.load(f) + if isinstance(existing_results, list): + results = existing_results + processed_ids = {r.get("_id") for r in results if r.get("_id")} + print(f"Loaded {len(results)} existing results from checkpoint.") + except Exception as e: + print(f"Error loading checkpoint: {e}") + + # Filter dataset to skip processed samples + remaining_dataset = [ + s + for s in dataset + if (s.get("_id") or s.get("id") or str(dataset.index(s))) not in processed_ids + ] + + # Concurrency settings + max_workers = 4 + print(f"Starting evaluation with {max_workers} workers using backend: {args.lib}") + print(f"Total dataset size: {len(dataset)}") + print(f"Already processed: {len(processed_ids)}") + print(f"Remaining to process: {len(remaining_dataset)}") + + if not remaining_dataset: + print("All samples have been processed.") + print_metrics(results, time.time() - start_time) + return + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # Phase 1: Ingestion + print("Phase 1: Ingesting context...") + ingest_futures = [ + executor.submit(ingest_sample, client, sample, args.lib) for sample in remaining_dataset + ] + for f in tqdm(as_completed(ingest_futures), total=len(ingest_futures), desc="Ingesting"): + try: + 
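+                # ingest_sample returns None; result() is called here only to
+                # surface ingestion exceptions so they can be reported below.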
f.result() + except Exception as e: + print(f"Ingestion Error: {e}") + + # Phase 2: Evaluation + print("Phase 2: Evaluating...") + futures = [ + executor.submit(evaluate_sample, client, oai_client, sample, 30, args.lib) + for sample in remaining_dataset + ] + + # Use tqdm for progress bar + for f in tqdm(as_completed(futures), total=len(futures), desc="Evaluating"): + try: + res = f.result() + results.append(res) + + # Save intermediate results every 10 samples + if len(results) % 10 == 0: + out_json.write_text( + json.dumps(results, ensure_ascii=False, indent=2), encoding="utf-8" + ) + print_metrics(results, time.time() - start_time) + except Exception as e: + print(f"Evaluation Error: {e}") + traceback.print_exc() + + # Final save + out_json.write_text(json.dumps(results, ensure_ascii=False, indent=2), encoding="utf-8") + print(f"Saved {len(results)} results to {out_json}") + print_metrics(results, time.time() - start_time) + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/longbench_v2/longbench_v2_search.py b/evaluation/scripts/longbench_v2/longbench_v2_search.py new file mode 100644 index 000000000..a224ea3fd --- /dev/null +++ b/evaluation/scripts/longbench_v2/longbench_v2_search.py @@ -0,0 +1,284 @@ +import argparse +import json +import os +import time +import traceback + +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +from dotenv import load_dotenv +from tqdm import tqdm + +from evaluation.scripts.utils.metrics import Metrics + + +load_dotenv() + + +def retry_operation(func, *args, retries=5, delay=2, **kwargs): + for attempt in range(retries): + try: + return func(*args, **kwargs) + except Exception as e: + if attempt < retries - 1: + func_name = getattr(func, "__name__", "Operation") + print(f"[Retry] {func_name} failed: {e}. 
Retrying in {delay}s...") + time.sleep(delay) + delay *= 2 + else: + raise e + + +def _get_lib_client(lib: str): + if lib == "mem0": + from evaluation.scripts.utils.client import Mem0Client + + return Mem0Client(enable_graph=False) + if lib == "supermemory": + from evaluation.scripts.utils.client import SupermemoryClient + + return SupermemoryClient() + from evaluation.scripts.utils.client import MemosApiClient + + return MemosApiClient() + + +def _load_dataset_jsonl(dataset_path: Path) -> list[dict]: + samples: list[dict] = [] + with open(dataset_path, encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + samples.append(json.loads(line)) + return samples + + +def memos_search(client, user_id: str, query: str, top_k: int = 30) -> list[str]: + results = retry_operation(client.search, query=query, user_id=user_id, top_k=top_k) + memories = results["text_mem"][0]["memories"] + return [m["memory"] for m in memories] + + +def mem0_search(client, user_id: str, query: str, top_k: int = 30) -> list[str]: + res = retry_operation(client.search, query, user_id, top_k) + results = res.get("results", []) + return [m.get("memory", "") for m in results if m.get("memory")] + + +def supermemory_search(client, user_id: str, query: str, top_k: int = 30) -> list[str]: + return retry_operation(client.search, query, user_id, top_k) + + +def _load_existing_results(output_path: Path) -> tuple[list[dict], set[str]]: + if not output_path.exists(): + return [], set() + try: + data = json.loads(output_path.read_text(encoding="utf-8")) + if isinstance(data, list): + ids = {str(r.get("_id")) for r in data if r.get("_id")} + return data, ids + if isinstance(data, dict) and isinstance(data.get("results"), list): + rows = data.get("results") or [] + ids = {str(r.get("_id")) for r in rows if r.get("_id")} + return rows, ids + except Exception: + return [], set() + return [], set() + + +def _save_json_list(path: Path, rows: list[dict]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.with_suffix(path.suffix + ".tmp") + tmp.write_text(json.dumps({"results": rows}, ensure_ascii=False, indent=2), encoding="utf-8") + os.replace(tmp, path) + + +def search_one(client, sample: dict, lib: str, top_k: int) -> dict: + sample_id = str(sample.get("_id")) + user_id = sample_id + question = sample.get("question") or "" + choices = { + "A": sample.get("choice_A") or "", + "B": sample.get("choice_B") or "", + "C": sample.get("choice_C") or "", + "D": sample.get("choice_D") or "", + } + + if lib == "memos": + memories = memos_search(client, user_id, str(question), top_k=top_k) + elif lib == "mem0": + memories = mem0_search(client, user_id, str(question), top_k=top_k) + elif lib == "supermemory": + memories = supermemory_search(client, user_id, str(question), top_k=top_k) + else: + memories = [] + print(f"[{lib} Search] sample_id: {sample_id} search memories: {len(memories)}") + + return { + "_id": sample_id, + "domain": sample.get("domain"), + "sub_domain": sample.get("sub_domain"), + "difficulty": sample.get("difficulty"), + "length": sample.get("length"), + "question": question, + "choices": choices, + "answer": sample.get("answer"), + "memories_used": memories, + } + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Longbench-v2 Product Search Concurrent Script", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument("--lib", "-b", required=True, help="Product name to evaluate") + + parser.add_argument( + "--dataset-path", + 
"-s", + default="evaluation/data/longbench_v2/longbenchv2_train.json", + help="Path to JSON file containing samples", + ) + + parser.add_argument( + "--api-url", + default="http://127.0.0.1:8001", + help="API service address (default: http://127.0.0.1:8001)", + ) + + parser.add_argument("--workers", "-c", type=int, default=5, help="Concurrency (default: 5)") + + parser.add_argument( + "--timeout", type=float, default=120.0, help="Request timeout in seconds (default: 120)" + ) + + parser.add_argument( + "--top-k", + "-k", + type=int, + default=20, + help="Number of results to return per search (default: 20)", + ) + + parser.add_argument("--version-dir", "-v", default=None, help="Version directory name") + + parser.add_argument( + "--limit", + "-l", + type=int, + default=None, + help="Limit number of samples to process (for testing, default all)", + ) + + parser.add_argument( + "--mode", "-m", type=str, default="fast", help="Search mode (default: fast)" + ) + + return parser.parse_args() + + +def main() -> None: + args = parse_args() + + print("=" * 60) + print("Longbench-v2 Product Search Concurrent Tool") + print("=" * 60) + + dataset_path = Path(args.dataset_path) + if not dataset_path.exists(): + raise FileNotFoundError(f"Dataset file not found: {dataset_path}") + dataset = _load_dataset_jsonl(dataset_path) + if args.limit is not None: + dataset = dataset[: args.limit] + + output_dir = os.path.join("evaluation/data/longbench_v2", args.version_dir) + os.makedirs(output_dir, exist_ok=True) + output_filename = f"{args.lib}_search_results.json" + output_path = Path(os.path.join(output_dir, output_filename)) + + results, processed_ids = _load_existing_results(output_path) + pending = [s for s in dataset if str(s.get("_id")) not in processed_ids] + if not pending: + return + + client = _get_lib_client(args.lib) + metrics = Metrics() + start_time = time.time() + + with ThreadPoolExecutor(max_workers=args.workers) as executor: + + def do_search(sample: dict) -> dict: + st = time.perf_counter() + r = search_one(client, sample, args.lib, args.top_k) + dur = time.perf_counter() - st + r["duration_ms"] = dur * 1000 + metrics.record(dur, True) + return r + + futures = [executor.submit(do_search, sample) for sample in pending] + for idx, f in enumerate( + tqdm(as_completed(futures), total=len(futures), desc="Searching"), start=1 + ): + try: + r = f.result() + results.append(r) + if idx % 10 == 0: + _save_json_list(output_path, results) + except Exception as e: + metrics.record(0.0, False, str(e)) + print(f"[Search] Error: {e}") + traceback.print_exc() + + _save_json_list(output_path, results) + print(f"[Search] saved {len(results)} rows to {output_path}") + + total_duration = time.time() - start_time + summary = metrics.summary() + combined_obj = { + "results": results, + "perf": { + "summary": summary, + "total_duration": total_duration, + "config": { + "workers": args.workers, + "top_k": args.top_k, + "dataset_path": str(dataset_path), + "limit": args.limit, + "mode": args.mode, + }, + }, + } + tmp = output_path.with_suffix(output_path.suffix + ".tmp") + tmp.write_text(json.dumps(combined_obj, ensure_ascii=False, indent=2), encoding="utf-8") + os.replace(tmp, output_path) + + print("\n" + "=" * 60) + print("Search finished! 
Statistics:") + print("=" * 60) + print(f"Total duration: {total_duration:.2f}s") + print(f"Success: {summary['counts']['success']} / Failed: {summary['counts']['failed']}") + + if summary["stats"]: + stats = summary["stats"] + qps = stats["count"] / total_duration if total_duration > 0 else 0 + print(f"QPS: {qps:.2f}") + print("Latency stats (ms):") + print(f" Mean: {stats['mean']:.2f}") + print(f" Median: {stats['median']:.2f}") + print(f" Min: {stats['min']:.2f}") + print(f" Max: {stats['max']:.2f}") + print(f" P95: {stats['p95']:.2f}") + print(f" P99: {stats['p99']:.2f}") + + if summary["errors"]: + print("\nError stats:") + for error, count in sorted(summary["errors"].items(), key=lambda x: x[1], reverse=True)[:5]: + print(f" [{count} times] {error[:100]}...") + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/mmlongbench/eval_utils/__init__.py b/evaluation/scripts/mmlongbench/eval_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/evaluation/scripts/mmlongbench/eval_utils/eval_score.py b/evaluation/scripts/mmlongbench/eval_utils/eval_score.py new file mode 100644 index 000000000..02ef6eb53 --- /dev/null +++ b/evaluation/scripts/mmlongbench/eval_utils/eval_score.py @@ -0,0 +1,246 @@ +import re + +from collections import defaultdict +from math import isclose + + +def levenshtein_distance(s1, s2): + if len(s1) > len(s2): + s1, s2 = s2, s1 + + distances = range(len(s1) + 1) + for i2, c2 in enumerate(s2): + distances_ = [i2 + 1] + for i1, c1 in enumerate(s1): + if c1 == c2: + distances_.append(distances[i1]) + else: + distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1]))) + distances = distances_ + return distances[-1] + + +def anls_compute(groundtruth, prediction, threshold=0.5): + dist = levenshtein_distance(groundtruth, prediction) + length = max(len(groundtruth.upper()), len(prediction.upper())) + value = 0.0 if length == 0 else float(dist) / float(length) + anls = 1.0 - value + if anls <= threshold: + anls = 0.0 + return anls + + +def is_float_equal( + reference, prediction, include_percentage: bool = False, is_close: float = False +) -> bool: + def get_precision(gt_ans: float) -> int: + precision = 3 + if "." in str(gt_ans): + precision = len(str(gt_ans).split(".")[-1]) + return precision + + reference = float(str(reference).strip().rstrip("%").strip()) + try: + prediction = float(str(prediction).strip().rstrip("%").strip()) + except Exception: + return False + + gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference] + for item in gt_result: + try: + if is_close and isclose(item, prediction, rel_tol=0.01): + return True + precision = max(min(get_precision(prediction), get_precision(item)), 2) + if round(prediction, precision) == round(item, precision): + return True + except Exception: + continue + return False + + +def get_clean_string(s): + s = str(s).lower().strip() + + for suffix in ["mile", "miles", "million"]: + if s.endswith(suffix): + s = s[: -len(suffix)].strip() + + s = re.sub(r"\s*\([^)]*\)", "", s).strip() + s = re.sub(r"^['\"]|['\"]$", "", s).strip() + s = s.lstrip("$").rstrip("%").strip() + + return s + + +def is_exact_match(s): + flag = False + # Website + if "https://" in s: + flag = True + # code file + if s.endswith((".py", ".ipynb")) or s.startswith("page"): + flag = True + # telephone number + if re.fullmatch(r"\b\d+(-\d+|\s\d+)?\b", s): + flag = True + # time + if "a.m." in s or "p.m." 
in s: + flag = True + # YYYY-MM-DD + if re.fullmatch(r"\b\d{4}[-\s]\d{2}[-\s]\d{2}\b", s): + flag = True + # YYYY-MM + if re.fullmatch(r"\b\d{4}[-\s]\d{2}\b", s): + flag = True + # Email address + if re.fullmatch(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", s): + flag = True + return flag + + +def isfloat(num): + try: + float(num) + return True + except ValueError: + return False + + +def eval_score(gt, pred, answer_type): + if answer_type == "Int": + try: + gt, pred = int(gt), int(float(pred)) + except Exception: + pred = "" + score = gt == pred + elif answer_type == "Float": + try: + gt = float(get_clean_string(str(gt))) + pred = float(get_clean_string(str(pred))) + except Exception: + pred = "" + score = is_float_equal(gt, pred, include_percentage=True, is_close=True) + elif answer_type in ["Str", "None"]: + gt = get_clean_string(gt) + pred = get_clean_string(pred) + score = gt == pred if is_exact_match(gt) else anls_compute(gt, pred) + else: + if isinstance(gt, str) and gt.startswith("["): + gt = eval(gt) + if not isinstance(gt, list): + gt = [gt] + if isinstance(pred, str) and pred.startswith("["): + pred = eval(pred) + if not isinstance(pred, list): + pred = [pred] + print(len(gt), len(pred)) + if len(gt) != len(pred): + score = 0.0 + else: + gt = sorted([get_clean_string(a) for a in gt]) + pred = sorted([get_clean_string(a) for a in pred]) + print(gt, pred) + if isfloat(gt[0]) or is_exact_match(gt[0]): + score = "-".join(gt) == "-".join(pred) + else: + score = min( + [anls_compute(gt_v, pred_v) for gt_v, pred_v in zip(gt, pred, strict=False)] + ) + + return float(score) + + +def eval_acc_and_f1(samples): + evaluated_samples = [sample for sample in samples if "score" in sample] + if not evaluated_samples: + return 0.0, 0.0 + + acc = sum([sample["score"] for sample in evaluated_samples]) / len(evaluated_samples) + try: + recall = sum( + [ + sample["score"] + for sample in evaluated_samples + if sample["answer"] != "Not answerable" + ] + ) / len([sample for sample in evaluated_samples if sample["answer"] != "Not answerable"]) + precision = sum( + [ + sample["score"] + for sample in evaluated_samples + if sample["answer"] != "Not answerable" + ] + ) / len([sample for sample in evaluated_samples if sample["pred"] != "Not answerable"]) + f1 = 2 * recall * precision / (recall + precision) if (recall + precision) > 0.0 else 0.0 + except Exception: + f1 = 0.0 + + return acc, f1 + + +def show_results(samples, show_path=None): + for sample in samples: + sample["evidence_pages"] = eval(sample["evidence_pages"]) + sample["evidence_sources"] = eval(sample["evidence_sources"]) + + with open(show_path, "w") as f: + acc, f1 = eval_acc_and_f1(samples) + f.write(f"Overall Acc: {acc} | Question Number: {len(samples)}\n") + f.write(f"Overall F1-score: {f1} | Question Number: {len(samples)}\n") + f.write("-----------------------\n") + + acc_single_page, _ = eval_acc_and_f1( + [sample for sample in samples if len(sample["evidence_pages"]) == 1] + ) + acc_multi_page, _ = eval_acc_and_f1( + [ + sample + for sample in samples + if len(sample["evidence_pages"]) != 1 and sample["answer"] != "Not answerable" + ] + ) + acc_neg, _ = eval_acc_and_f1( + [sample for sample in samples if sample["answer"] == "Not answerable"] + ) + + f.write( + "Single-page | Accuracy: {} | Question Number: {}\n".format( + acc_single_page, + len([sample for sample in samples if len(sample["evidence_pages"]) == 1]), + ) + ) + f.write( + "Cross-page | Accuracy: {} | Question Number: {}\n".format( + acc_multi_page, + len( + [ + 
sample + for sample in samples + if len(sample["evidence_pages"]) != 1 + and sample["answer"] != "Not answerable" + ] + ), + ) + ) + f.write( + "Unanswerable | Accuracy: {} | Question Number: {}\n".format( + acc_neg, len([sample for sample in samples if sample["answer"] == "Not answerable"]) + ) + ) + f.write("-----------------------\n") + + source_sample_dict, document_type_dict = defaultdict(list), defaultdict(list) + for sample in samples: + for answer_source in sample["evidence_sources"]: + source_sample_dict[answer_source].append(sample) + document_type_dict[sample["doc_type"]].append(sample) + for type, sub_samples in source_sample_dict.items(): + f.write( + f"Evidence Sources: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n" + ) + + f.write("-----------------------\n") + for type, sub_samples in document_type_dict.items(): + f.write( + f"Document Type: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n" + ) diff --git a/evaluation/scripts/mmlongbench/mmlongbench_eval.py b/evaluation/scripts/mmlongbench/mmlongbench_eval.py new file mode 100644 index 000000000..f83e39e2d --- /dev/null +++ b/evaluation/scripts/mmlongbench/mmlongbench_eval.py @@ -0,0 +1,470 @@ +import base64 +import json +import mimetypes +import os +import re +import sys +import time +import traceback + +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime +from pathlib import Path +from typing import Any + +import openai + +from dotenv import load_dotenv +from tqdm import tqdm + +from evaluation.scripts.utils.eval_score import eval_acc_and_f1, eval_score, show_results +from evaluation.scripts.utils.extract_answer import extract_answer +from evaluation.scripts.utils.prompts import MMLONGBENCH_ANSWER_PROMPT +from memos.log import get_logger + + +logger = get_logger(__name__) + + +load_dotenv() + +# Initialize OpenAI Client +oai_client = openai.Client( + api_key=os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), +) + + +def _encode_image_to_data_url(image_path: str) -> str | None: + """Encode local image file to base64 data URL for OpenAI-compatible image messages. + + Returns a data URL like: data:image/jpeg;base64,<...> + """ + try: + mime, _ = mimetypes.guess_type(image_path) + if not mime: + # default to jpeg + mime = "image/jpeg" + with open(image_path, "rb") as f: + b64 = base64.b64encode(f.read()).decode("ascii") + return f"data:{mime};base64,{b64}" + except Exception as e: + logger.warning(f"Failed to encode image '{image_path}' to data URL: {e}") + return None + + +# @lru_cache(maxsize=1) +def build_images_index() -> dict[str, str]: + """Scan `./ppt_test_result` recursively and index images by filename. + + New structure example: + ./ppt_test_result//extracted/file_*//auto/images/*.{png,jpg,jpeg,webp,gif} + + Also compatible with previous layouts. Returns mapping: + basename (e.g. 
img_123.jpg) -> absolute path + """ + base_dir = Path("/Users/tianxingshi/Desktop/lcy/ppt_test_result") + index: dict[str, str] = {} + if not base_dir.exists(): + return index + + # Recursively find any `auto/images` directories under ppt_test_result + for images_dir in base_dir.rglob("auto/images"): + if images_dir.is_dir(): + for img_file in images_dir.iterdir(): + if img_file.is_file(): + index[img_file.name] = str(img_file.resolve()) + return index + + +index_dict = build_images_index() + + +def get_images(sources: list) -> list[str]: + """Extract image absolute paths from metadata sources. + + Supports patterns like: ![](images/.jpg) or any 'images/...jpg' substring. + Falls back to scanning the ppt_test_result index to resolve basenames. + """ + if not sources: + return [] + + # Ensure index exists + + found: list[str] = [] + + md_img_pattern = re.compile(r"\[Image: images/\s*(.+?)\s*-") + images_substr_pattern = re.compile(r"images/[^\s)]+\.(?:png|jpg|jpeg|webp)", re.IGNORECASE) + for src in sources: + if not src: + continue + # 1) markdown image syntax + for m in md_img_pattern.findall(src): + candidate = m.strip() + # if it's a relative like 'images/xxx.jpg', resolve via index + basename = os.path.basename(candidate) + if basename in index_dict: + found.append(index_dict[basename]) + else: + # try direct path (absolute or relative) + p = Path(candidate) + if not p.is_absolute(): + p = Path.cwd() / p + if p.exists(): + found.append(str(p.resolve())) + + # 2) any 'images/xxx.jpg' substring + for m in images_substr_pattern.findall(src): + candidate = m.strip() + basename = os.path.basename(candidate) + if basename in index_dict: + found.append(index_dict[basename]) + else: + p = Path(candidate) + if not p.is_absolute(): + p = Path.cwd() / p + if p.exists(): + found.append(str(p.resolve())) + + # Deduplicate preserving order + dedup: list[str] = [] + seen = set() + for path in found: + if path not in seen: + dedup.append(path) + seen.add(path) + return dedup + + +def add_images_context( + current_messages: list[dict[str, Any]], images: list[str] +) -> list[dict[str, Any]]: + """Append images in OpenAI-compatible multi-part format and ensure message structure. + + - Deduplicates image paths. + - Ensures a system message exists with a concise CN vision instruction. + - Ensures the last user message has multi-part content: [text, image_url...]. + - Uses base64 data URLs. Limits to 6 images. + - In-place modification of `current_messages`. 
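+
+    Illustrative shape of the resulting user message (values abbreviated):
+        {"role": "user", "content": [
+            {"type": "text", "text": "<question>"},
+            {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
+        ]}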
+ """ + if not images: + return current_messages + + # Deduplicate images while preserving order + unique_images: list[str] = [] + seen_paths: set[str] = set() + for p in images: + if p not in seen_paths: + unique_images.append(p) + seen_paths.add(p) + + # Locate or create the last user message + user_idx = None + for i in range(len(current_messages) - 1, -1, -1): + if current_messages[i].get("role") == "user": + user_idx = i + break + + user_msg = current_messages[user_idx] + orig_content = user_msg.get("content", "") + + # Normalize user content to multi-part format using original query as text (no fallback) + content_parts: list[dict[str, Any]] + if isinstance(orig_content, str): + content_parts = [{"type": "text", "text": orig_content}] + elif isinstance(orig_content, list): + content_parts = orig_content + else: + content_parts = [{"type": "text", "text": str(orig_content)}] + + # 5) Append up to 3 images as data URLs + limit = 6 + count = 0 + + for img_path in unique_images: + if count >= limit: + break + data_url = _encode_image_to_data_url(img_path) + if data_url: + content_parts.append({"type": "image_url", "image_url": {"url": data_url}}) + count += 1 + user_msg["content"] = content_parts + current_messages[user_idx] = user_msg + logger.info( + f"Attached {count} images to user message (deduplicated from {len(images)}), {json.dumps(current_messages, ensure_ascii=False, indent=2)}" + ) + return current_messages + + +def multimodal_answer( + oai_client, + memories: list[str], + question: str, + top_k: int = 15, + sources: list | None = None, +) -> tuple[str, int | None]: + sources_texts: list[str] = [] + for source in sources[:top_k]: + source = source[0] + content = source.get("content") if isinstance(source, dict) else str(source) + if content: + sources_texts.append(content) + + image_paths = get_images(sources_texts) + system_prompt = MMLONGBENCH_ANSWER_PROMPT.format( + memories="\n\n".join(memories[:top_k]), question=question + ) + messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": question}] + messages = add_images_context(messages, image_paths) + for _, msg in enumerate(messages): + if msg.get("role") == "user" and isinstance(msg.get("content"), list): + img_count = sum(1 for p in msg["content"] if p.get("type") == "image_url") + print( + f"DEBUG: user message has {len(memories[:top_k])} memories, {img_count} images attached" + ) + + resp = oai_client.chat.completions.create( + model=args.chat_model, messages=messages, temperature=0 + ) + return resp.choices[0].message.content or "", resp.usage.prompt_tokens + + +def process_single_item(item: dict, index: int, top_k: int = 20) -> dict: + """Process a single evaluation item""" + question = item["question"] + memories = item.get("memories", []) + sources = item.get("sources", []) + if not memories: + result = { + "response": None, + "extracted_res": None, + "pred": None, + "score": 0, + "eval_success": False, + "eval_error": "", + } + try: + # Get model response + response, prompt_tokens = multimodal_answer(oai_client, memories, question, top_k, sources) + + # Extract answer + extracted_res = extract_answer(question, response) + + # Parse extracted answer + try: + pred_ans = ( + extracted_res.split("Answer format:")[0].split("Extracted answer:")[1].strip() + ) + except Exception as e: + print("extract_answer error**********", e) + pred_ans = response.strip() + + # Calculate score + score = eval_score(item.get("answer"), pred_ans, item.get("answer_format", "Str")) + + # Build result + result = 
{ + "response": response, + "extracted_res": extracted_res, + "pred": pred_ans, + "score": score, + "prompt_tokens": prompt_tokens, + "eval_success": True, + "eval_error": None, + } + + except Exception as e: + traceback.print_exc() + result = { + "response": None, + "extracted_res": None, + "pred": None, + "score": 0, + "eval_success": False, + "eval_error": str(e), + } + + return {"index": index, "result": result} + + +def run_eval( + questions_file: str | Path, + output_file: str | Path | None = None, + version_dir: str | Path | None = None, + max_workers: int = 10, + top_k: int = 20, +) -> None: + """ + Run evaluation + + Args: + version_dir: version directory + questions_file: Input questions file path + output_file: Output file path, overwrites input file if None + max_workers: Number of concurrent workers + """ + questions_file = Path(questions_file) + output_file = questions_file if output_file is None else Path(output_file) + + # Read input data + with open(questions_file, encoding="utf-8") as f: + data = json.load(f) + + items = data["results"] + total = len(items) + print(f"[Info] Starting evaluation, total {total} items, concurrency: {max_workers}") + + # Concurrent processing + results_map = {} + start_time = time.time() + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit(process_single_item, item, i, top_k): i for i, item in enumerate(items) + } + + # Use tqdm to show progress bar + with tqdm( + total=total, + desc="Evaluation Progress", + unit="items", + bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]", + ) as pbar: + for future in as_completed(futures): + result_data = future.result() + idx = result_data["index"] + result = result_data["result"] + results_map[idx] = result + + # Update progress bar, show current score + score = result.get("score", 0) + success = result.get("eval_success", False) + status = f"score={score:.2f}" if success else "ERROR" + pbar.set_postfix_str(status) + pbar.update(1) + + # Write results to each item in original data + for i, item in enumerate(items): + if i in results_map: + item.update(results_map[i]) + + # Update summary info + eval_duration = time.time() - start_time + + # Calculate evaluation statistics + success_count = sum(1 for r in results_map.values() if r.get("eval_success", False)) + failed_count = total - success_count + scores = [r.get("score", 0) for r in results_map.values() if r.get("eval_success", False)] + prompt_tokens_list = [ + r.get("prompt_tokens") + for r in results_map.values() + if r.get("eval_success", False) and isinstance(r.get("prompt_tokens"), int) + ] + avg_prompt_tokens = ( + (sum(prompt_tokens_list) / len(prompt_tokens_list)) if prompt_tokens_list else 0 + ) + + # Calculate acc and f1 + eval_results = [{**items[i], **results_map[i]} for i in range(len(items)) if i in results_map] + acc, f1 = eval_acc_and_f1(eval_results) + + # Update data summary + if "eval_summary" not in data: + data["eval_summary"] = {} + + data["eval_summary"] = { + "eval_duration_seconds": eval_duration, + "total_samples": total, + "success_count": success_count, + "failed_count": failed_count, + "accuracy": acc, + "f1_score": f1, + "avg_score": sum(scores) / len(scores) if scores else 0, + "avg_prompt_tokens": avg_prompt_tokens, + "max_workers": max_workers, + "eval_timestamp": datetime.now().isoformat(), + } + + # Save results to output file + output_file.parent.mkdir(parents=True, exist_ok=True) + with open(output_file, "w", encoding="utf-8") as f: + 
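+        # Persist the augmented structure in place: each item under "results" now
+        # carries the per-sample eval fields (response, pred, score, ...) and the
+        # top-level "eval_summary" holds the aggregate metrics computed above.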
json.dump(data, f, ensure_ascii=False, indent=2) + + print(f"\n{'=' * 60}") + print("[Evaluation Finished]") + print(f" Total samples: {total}") + print(f" Success: {success_count}, Failed: {failed_count}") + print(f" Accuracy: {acc:.4f}") + print(f" F1 Score: {f1:.4f}") + print(f" Average Score: {data['eval_summary']['avg_score']:.4f}") + print(f" Average prompt_tokens: {data['eval_summary']['avg_prompt_tokens']:.2f}") + print(f" Duration: {eval_duration:.2f}s") + print(f" Results saved to: {output_file}") + print(f"{'=' * 60}") + + # Generate detailed report + report_path = version_dir / f"{args.lib}_eval_results.txt" + show_results(eval_results, show_path=str(report_path)) + print(f"[Report] Detailed report saved to: {report_path}") + + # Save concise metrics file + metrics_path = report_path.with_name(report_path.stem + "_metrics.json") + + metrics = { + "accuracy": acc, + "f1_score": f1, + "avg_score": data["eval_summary"]["avg_score"], + "avg_prompt_tokens": data["eval_summary"]["avg_prompt_tokens"], + "total_samples": total, + "success_count": success_count, + "failed_count": failed_count, + "eval_duration_seconds": eval_duration, + "eval_timestamp": data["eval_summary"]["eval_timestamp"], + } + with open(metrics_path, "w", encoding="utf-8") as f: + json.dump(metrics, f, ensure_ascii=False, indent=2) + print(f"[Metrics] Metrics saved to: {metrics_path}") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="MMlongbench Evaluation Script") + parser.add_argument("--lib", "-b", required=True, help="Product name to evaluate") + parser.add_argument("--workers", "-w", type=int, default=20, help="Concurrent workers") + parser.add_argument( + "--top-k", "-k", type=int, default=20, help="Top K results to use (default: 20)" + ) + parser.add_argument("--version-dir", "-v", default=None, help="Version directory name") + parser.add_argument("--chat-model", "-m", default=None, help="chat model name") + + args = parser.parse_args() + + print("=" * 60) + print("MMLongBench-doc Product Eval Tool") + print("=" * 60) + + print("[Response model]: ", os.getenv("CHAT_MODEL")) + + base_dir = Path("evaluation/data/mmlongbench") + version_dir = base_dir / args.version_dir + input_filename = f"{args.lib}_search_results.json" + input_path = version_dir / input_filename + + if not input_path.exists(): + print(f"Error: Input file not found: {input_path}") + sys.exit(1) + + output_path = input_path + + print(f"[Info] Input file: {input_path}") + print(f"[Info] Output file: {output_path}") + print(f"[Response Model]: {args.chat_model}") + + run_eval( + questions_file=input_path, + output_file=output_path, + version_dir=version_dir, + max_workers=args.workers, + top_k=args.top_k, + ) diff --git a/evaluation/scripts/mmlongbench/mmlongbench_ingestion.py b/evaluation/scripts/mmlongbench/mmlongbench_ingestion.py new file mode 100644 index 000000000..618caf09e --- /dev/null +++ b/evaluation/scripts/mmlongbench/mmlongbench_ingestion.py @@ -0,0 +1,341 @@ +#!/usr/bin/env python3 + +import argparse +import json +import threading +import time + +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path + +from langchain_text_splitters import Language, RecursiveCharacterTextSplitter + +from evaluation.scripts.utils.metrics import Metrics + + +def read_filenames(filepath: str) -> list[str]: + """ + Read filename list from file + Supports one filename per line, automatically removes empty lines and whitespace + """ + filenames = [] + with open(filepath, 
encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: # Skip empty lines + filenames.append(line) + return filenames + + +def run_concurrent_add( + lib: str, + filenames: list[str], + url_prefix: str, + user_prefix: str, + workers: int, + source_type: str = "extreme_multimodal", + mode: str = "fine", + async_mode: str = "sync", +) -> dict: + """ + Execute concurrent add operations + + Args: + lib: Client name + filenames: List of filenames + url_prefix: URL prefix + user_prefix: User ID prefix + workers: Concurrency + source_type: Source type + mode: Mode + async_mode: Async mode + + Returns: + Statistics result + """ + + client = _get_lib_client(lib) + metrics = Metrics() + total_files = len(filenames) + completed = 0 + completed_lock = threading.Lock() + user_id = user_prefix + + def add_single_file(filename: str, doc_id: str = ""): + nonlocal completed + + file_id = filename # 文件名作为file_id + file_data = f"{url_prefix.rstrip('/')}/{filename}" # URL前缀 + 文件名 + + base_dir = Path("ppt_test_result") + all_md_files = list(base_dir.rglob("*.md")) + stem = Path(file_id).stem.lower() + name = file_id.lower() + md_path = "" + for md in all_md_files: + pstr = str(md).lower() + if (stem and stem in pstr) or (name and name in pstr): + md_path = md + text = md_path.read_text(encoding="utf-8", errors="ignore") + + start_time = time.perf_counter() + user_id = user_prefix + "_" + doc_id + writable_cube_ids = [user_id] + chat_time = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + + result = None + try: + if lib == "memos": + messages = [ + { + "role": "user", + "content": [ + { + "type": "file", + "file": { + "file_id": file_id, + "filename": file_data, + "file_data": file_data, + }, + } + ], + "chat_time": chat_time, + } + ] + result = client.add( + messages=messages, + user_id=user_id, + writable_cube_ids=writable_cube_ids, + source_type=source_type, + mode=mode, + async_mode=async_mode, + ) + elif lib == "supermemory": + result = client.add(content=text, user_id=user_id) + elif lib == "mem0": + chunker = RecursiveCharacterTextSplitter.from_language( + language=Language.PYTHON, chunk_size=5120, chunk_overlap=128 + ) + paragraphs = [p for p in chunker.split_text(text) if p.strip()] + messages = [{"role": "user", "content": p} for p in paragraphs] + ts = int(time.time()) + + result = client.add(messages=messages, user_id=doc_id, timestamp=ts, batch_size=10) + + duration = time.perf_counter() - start_time + metrics.record(duration, True) + + with completed_lock: + completed += 1 + print( + f"[{completed}/{total_files}] ✓ Success: {filename} ({duration * 1000:.0f}ms)" + ) + + return True, result + + except Exception as e: + duration = time.perf_counter() - start_time + error_msg = str(e) + metrics.record(duration, False, error_msg) + + with completed_lock: + completed += 1 + print(f"[{completed}/{total_files}] ✗ Failed: {filename} - {error_msg[:100]}") + + return False, error_msg + + print(f"\nStarting concurrent add for {total_files} files...") + print(f"Concurrency: {workers}") + print(f"User ID: {user_id}") + print(f"URL prefix: {url_prefix}") + print("-" * 60) + + start_time = time.time() + + with ThreadPoolExecutor(max_workers=workers) as executor: + futures = [] + for _, filename in enumerate(filenames): + doc_id = filename[:-3] + ".pdf" + future = executor.submit(add_single_file, filename, doc_id) + futures.append((filename, future)) + + # Wait for all tasks to complete + results = [] + for filename, future in futures: + try: + success, result = future.result() + 
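+                # add_single_file returns (True, client_response) on success
+                # or (False, error_message) on failure.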
results.append({"filename": filename, "success": success, "result": result}) + except Exception as e: + results.append({"filename": filename, "success": False, "result": str(e)}) + + end_time = time.time() + total_duration = end_time - start_time + + # Print statistics + summary = metrics.summary() + + print("\n" + "=" * 60) + print("Ingestion finished! Statistics:") + print("=" * 60) + print(f"Total duration: {total_duration:.2f}s") + print(f"Success: {summary['counts']['success']} / Failed: {summary['counts']['failed']}") + + if summary["stats"]: + stats = summary["stats"] + qps = stats["count"] / total_duration if total_duration > 0 else 0 + print(f"QPS: {qps:.2f}") + print("Latency stats (ms):") + print(f" Mean: {stats['mean']:.2f}") + print(f" Median: {stats['median']:.2f}") + print(f" Min: {stats['min']:.2f}") + print(f" Max: {stats['max']:.2f}") + print(f" P95: {stats['p95']:.2f}") + print(f" P99: {stats['p99']:.2f}") + + if summary["errors"]: + print("\nError statistics:") + for error, count in sorted(summary["errors"].items(), key=lambda x: x[1], reverse=True)[:5]: + print(f" [{count} times] {error[:100]}...") + + return {"summary": summary, "total_duration": total_duration, "results": results} + + +def parse_args(): + parser = argparse.ArgumentParser( + description="MMLongbench-doc Product Add Concurrent Script", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--lib", "-b", required=True, help="Product name to evaluate") + parser.add_argument( + "--filenames-file", + "-f", + default="evaluation/data/mmlongbench/md_file_list.txt", + help="Path to text file containing list of filenames (one per line)", + ) + + parser.add_argument( + "--url-prefix", + "-u", + default="https://memos-knowledge-base-file-pre.oss-cn-shanghai.aliyuncs.com/ppt_md_files/", + help="URL prefix to be prepended to filenames", + ) + + parser.add_argument( + "--api-url", + default="http://127.0.0.1:8001", + help="MemOS API address (default: http://127.0.0.1:8001)", + ) + + parser.add_argument("--workers", "-w", type=int, default=5, help="Concurrency (default: 5)") + + parser.add_argument( + "--timeout", type=float, default=1200, help="Request timeout in seconds (default: 120)" + ) + + parser.add_argument( + "--source-type", + default="extreme_multimodal", + help="Source type (default: extreme_multimodal)", + ) + + parser.add_argument( + "--mode", default="fine", choices=["fine", "fast"], help="Processing mode (default: fine)" + ) + + parser.add_argument( + "--async-mode", default="sync", choices=["sync", "async"], help="Async mode (default: sync)" + ) + + parser.add_argument("--version-dir", "-v", default=None, help="Version directory name") + + return parser.parse_args() + + +def _get_lib_client(lib: str): + if lib == "memos": + from evaluation.scripts.utils.client import MemosApiClient + + return MemosApiClient() + if lib == "mem0": + from evaluation.scripts.utils.client import Mem0Client + + return Mem0Client(enable_graph=False) + if lib == "supermemory": + from evaluation.scripts.utils.client import SupermemoryClient + + return SupermemoryClient() + + +def main(): + args = parse_args() + + print("=" * 60) + print("MMLongbench-doc Product Add Concurrent Tool") + print("=" * 60) + + # Read filename list + print(f"\nReading filename list: {args.filenames_file}") + try: + filenames = read_filenames(args.filenames_file) + print(f"Read {len(filenames)} filenames") + if len(filenames) == 0: + print("Error: Filename list is empty!") + return + + # Show first few filenames + 
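+        # Entries are expected to be markdown filenames ("<doc>.md"); the ingestion
+        # step derives each doc_id by swapping the extension to ".pdf".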
print("First 5 filenames:") + for fn in filenames[:5]: + print(f" - {fn}") + if len(filenames) > 5: + print(f" ... and {len(filenames) - 5} more files") + + except FileNotFoundError: + print(f"Error: File not found {args.filenames_file}") + return + except Exception as e: + print(f"Error: Failed to read file - {e}") + return + + # Execute concurrent add + result = run_concurrent_add( + lib=args.lib, + filenames=filenames, + url_prefix=args.url_prefix, + user_prefix=args.version_dir, + workers=args.workers, + source_type=args.source_type, + mode=args.mode, + async_mode=args.async_mode, + ) + + # Determine output file path + import os + + version_output_dir = os.path.join("evaluation/data/mmlongbench", args.version_dir) + os.makedirs(version_output_dir, exist_ok=True) + output_path = os.path.join(version_output_dir, f"{args.lib}_add_results.json") + + # Save results to file + if output_path: + with open(output_path, "w", encoding="utf-8") as f: + # Remove non-serializable content + output_data = { + "summary": result["summary"], + "total_duration": result["total_duration"], + "config": { + "filenames_file": args.filenames_file, + "url_prefix": args.url_prefix, + "api_url": args.api_url, + "concurrency": args.workers, + "source_type": args.source_type, + "mode": args.mode, + "async_mode": args.async_mode, + "version_dir": args.version_dir, + }, + } + json.dump(output_data, f, ensure_ascii=False, indent=2) + print(f"\nResults saved to: {output_path}") + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/mmlongbench/mmlongbench_old.py b/evaluation/scripts/mmlongbench/mmlongbench_old.py new file mode 100644 index 000000000..9a887e987 --- /dev/null +++ b/evaluation/scripts/mmlongbench/mmlongbench_old.py @@ -0,0 +1,403 @@ +import argparse +import json +import os +import sys +import time + +from collections.abc import Iterator +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +from dotenv import load_dotenv +from eval.eval_score import eval_acc_and_f1, eval_score, show_results # type: ignore +from eval.extract_answer import extract_answer # type: ignore +from langchain_text_splitters import Language, RecursiveCharacterTextSplitter +from openai import OpenAI + + +# Ensure project paths for imports +PROJECT_ROOT = Path(__file__).resolve().parents[3] +SCRIPTS_ROOT = Path(__file__).resolve().parents[1] +SRC_ROOT = PROJECT_ROOT / "src" +sys.path.append(str(SCRIPTS_ROOT)) +sys.path.append(str(SRC_ROOT)) +load_dotenv() +max_retries = 5 + + +def iter_markdown_files(base_dir: str | Path) -> Iterator[Path]: + base = Path(base_dir) + if not base.exists(): + return + # glob all 'auto/*.md' + for md in base.rglob("auto/*.md"): + if md.is_file(): + yield md + + +def _get_clients(): + from utils.client import MemosApiClient # type: ignore + from utils.prompts import MEMOS_CONTEXT_TEMPLATE # type: ignore + + from memos.mem_os.core import add_images_context, get_images # type: ignore + + memos_client = MemosApiClient() + openai_client = OpenAI( + api_key=os.getenv("CHAT_MODEL_API_KEY"), base_url=os.getenv("CHAT_MODEL_BASE_URL") + ) + return memos_client, openai_client, MEMOS_CONTEXT_TEMPLATE, add_images_context, get_images + + +def _get_lib_client(lib: str): + if lib == "mem0": + from utils.client import Mem0Client # type: ignore + + return Mem0Client(enable_graph=False) + if lib == "supermemory": + from utils.client import SupermemoryClient # type: ignore + + return SupermemoryClient() + from utils.client import MemosApiClient # type: ignore + + return 
MemosApiClient() + + +def _load_existing_results(output_path): + completed_pairs: set[tuple[str, str]] = set() + completed_docs: set[str] = set() + existing: list[dict] = [] + p = Path(output_path) + if p.exists(): + try: + existing = json.loads(p.read_text(encoding="utf-8")) + for r in existing: + did = str(r.get("doc_id") or "").strip() + q = str(r.get("question") or "").strip() + # normalize whitespace for robust resume + did_norm = did + q_norm = " ".join(q.split()) + if did: + completed_docs.add(did) + if did_norm and q_norm: + completed_pairs.add((did_norm, q_norm)) + except Exception: + existing = [] + return existing, completed_pairs, completed_docs + + +def add_context(client, doc_id: str, md_path: Path, lib) -> None: + text = md_path.read_text(encoding="utf-8", errors="ignore") + print(f"[Add context] doc_id={doc_id} path={md_path}") + + if lib == "memos": + messages = [ + { + "role": "user", + "content": { + "type": "file", + "file": {"file_id": doc_id, "filename": doc_id, "file_data": text}, + }, + } + ] + try: + client.add(messages=messages, user_id=doc_id, conv_id=doc_id) + print(f"[Add-memos] user={doc_id} total={len(messages)}") + except Exception as e: + print(f"[Add-memos] failed: {e}") + + elif lib == "mem0": + chunker = RecursiveCharacterTextSplitter.from_language( + language=Language.PYTHON, chunk_size=512, chunk_overlap=128 + ) + paragraphs = [p for p in chunker.split_text(text) if p.strip()] + + messages = [{"role": "user", "content": p} for p in paragraphs] + ts = int(time.time()) + try: + client.add(messages=messages, user_id=doc_id, timestamp=ts, batch_size=10) + print(f"[Add-mem0] user={doc_id} total={len(messages)}") + except Exception as e: + print(f"[Add-mem0] failed: {e}") + elif lib == "supermemory": + try: + doc_id = "stx_" + doc_id.replace(".pdf", "") + client.add(content=text, user_id=doc_id) + print(f"[Add-supermemory] user={doc_id}") + except Exception as e: + print(f"[Add-supermemory] failed: {e}") + + +def memos_search( + client, get_images, user_id: str, query: str, top_k: int = 15 +) -> tuple[list[str], list[str]]: + results = client.search(query=query, user_id=user_id, top_k=top_k) + memories = results["text_mem"][0]["memories"] + mem_texts = [m["memory"] for m in memories] + + # Collect possible image paths from memory texts (and any source content if present) + sources_texts: list[str] = [] + for m in memories: + srcs = (m.get("metadata", {}) or {}).get("sources") or [] + for s in srcs: + content = s.get("content") if isinstance(s, dict) else str(s) + if content: + sources_texts.append(content) + + image_paths = get_images(sources_texts) + print( + f"[Search] user={user_id} top_k={top_k} memories={len(memories), len(mem_texts)} images={len(image_paths)}" + ) + return mem_texts, image_paths + + +def mem0_search( + client, get_images, user_id: str, query: str, top_k: int = 15 +) -> tuple[list[str], list[str]]: + res = client.search(query, user_id, top_k) + results = res.get("results", []) + mem_texts = [m.get("memory", "") for m in results if m.get("memory")] + image_paths = get_images(mem_texts) + print( + f"[Search] user={user_id} top_k={top_k} memories={len(results)} images={len(image_paths)}" + ) + return mem_texts, image_paths + + +def supermemory_search( + client, get_images, user_id: str, query: str, top_k: int = 15 +) -> tuple[list[str], list[str]]: + chunk_list = client.search(query, user_id, top_k) + image_paths = get_images(chunk_list) + print( + f"[Search] user={user_id} top_k={top_k} memories={len(chunk_list)} 
images={len(image_paths)}" + ) + return chunk_list, image_paths + + +def multimodal_answer( + add_images_context, oai_client, memories: list[str], question: str, image_paths: list[str] +) -> tuple[str, int]: + from memos.mem_os.core import MOSCore # type: ignore + + system_prompt = MOSCore._build_system_prompt(MOSCore.__new__(MOSCore), memories) + + messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": question}] + add_images_context(messages, image_paths) + print("[Response model]:", os.getenv("CHAT_MODEL")) + resp = oai_client.chat.completions.create( + model=os.getenv("CHAT_MODEL"), messages=messages, temperature=0 + ) + return resp.choices[0].message.content or "", resp.usage.prompt_tokens + + +def main(): + parser = argparse.ArgumentParser(description="MMLongBench Evaluation Script") + parser.add_argument( + "--lib", + type=str, + default="supermemory", + help="Backend library to use (memos, mem0, supermemory)", + ) + args = parser.parse_args() + + # Hardcoded parameters + ppt_root = "ppt_test_result" + questions_file = "evaluation/data/mmlongbench/samples.json" + top_k = 15 + workers = 4 + lib = args.lib + + print("[Memory util]:", lib) + + start_time = time.time() + client, oai_client, memos_context_template, add_images_context, get_images = _get_clients() + if questions_file and Path(questions_file).exists(): + data = json.loads(Path(questions_file).read_text(encoding="utf-8")) + print(f"[Load] samples={len(data)} file={questions_file}") + + # Build allowed doc list from documents directory (align with eval_docs grouping by doc_id) + docs_dir = Path("evaluation/data/mmlongbench/documents") + allowed_docs: set[str] = set() + if docs_dir.exists(): + for f in docs_dir.iterdir(): + if f.is_file(): + allowed_docs.add(f.name) + if allowed_docs: + print(f"[Filter] allowed_docs={len(allowed_docs)} from {docs_dir}") + + # Determine doc_ids present in samples and apply allowed filter + doc_ids_in_samples = {str(s.get("doc_id") or "").strip() for s in data if s.get("doc_id")} + doc_list = [d for d in doc_ids_in_samples if (not allowed_docs or d in allowed_docs)] + doc_list.sort() + + output_path = f"evaluation/data/mmlongbench/test/{lib}_add_fine_search_fine_results.json" + report_path = Path( + f"evaluation/data/mmlongbench/test/{lib}_add_fine_search_fine_report.txt" + ) + + # Resume state + existing, completed_pairs, completed_docs = _load_existing_results(output_path) + print(f"[Resume] loaded_results={len(existing)} completed_docs={len(completed_docs)}") + results: list[dict] = list(existing) + ingested_doc_ids: set[str] = set(completed_docs) + + base_dir = Path(ppt_root) + all_md_files: list[Path] = [] + if base_dir.exists(): + all_md_files = list(base_dir.rglob("*.md")) + + def _find_md_for_doc(doc_id_val: str) -> Path | None: + stem = Path(doc_id_val).stem.lower() + name = doc_id_val.lower() + for md in all_md_files: + pstr = str(md).lower() + if (stem and stem in pstr) or (name and name in pstr): + return md + return None + + to_ingest: list[tuple[str, Path]] = [] + for did in doc_list: + if did and did not in ingested_doc_ids: + mdp = _find_md_for_doc(did) + if mdp is not None: + to_ingest.append((did, mdp)) + else: + print(f"[Skip] markdown not found for doc_id={did}") + + # Phase 1: Ingestion + print("Phase 1: Ingesting context...") + if to_ingest: + print(f"[Ingest-Concurrent] tasks={len(to_ingest)} from {ppt_root}") + + def _ingest_one(doc_id_local: str, md_path_local: Path, lib_local: str = lib) -> str: + user_id_local = doc_id_local + c_local = 
client if lib_local == "memos" else _get_lib_client(lib_local) + add_context(c_local, user_id_local, md_path_local, lib=lib_local) + return doc_id_local + + with ThreadPoolExecutor(max_workers=workers) as executor: + futures = [executor.submit(_ingest_one, did, mdp) for did, mdp in to_ingest] + for f in as_completed(futures): + try: + done_id = f.result() + ingested_doc_ids.add(done_id) + except Exception as e: + print(f"[Add-Error] {e}") + + # Phase 2: Evaluation + print("Phase 2: Evaluating...") + for doc_id in doc_list: + if not doc_id: + continue + print(f"\n===== [Doc] {doc_id} =====") + user_id = doc_id + + doc_samples = [s for s in data if str(s.get("doc_id") or "").strip() == doc_id] + + def _process_item( + item: dict, + doc_id_local: str = doc_id, + user_id_local: str = user_id, + lib_local: str = lib, + ) -> dict: + q = item["question"] + q_norm_local = " ".join(str(q).split()) + if (doc_id_local, q_norm_local) in completed_pairs: + return {"skip": True} + if lib_local == "memos": + memories, images = memos_search( + client, get_images, user_id_local, q, top_k=top_k + ) + elif lib_local == "mem0": + c_local = _get_lib_client(lib_local) + memories, images = mem0_search( + c_local, get_images, user_id_local, q, top_k=top_k + ) + elif lib_local == "supermemory": + c_local = _get_lib_client(lib_local) + memories, images = supermemory_search( + c_local, get_images, user_id_local, q, top_k=top_k + ) + else: + memories, images = [], [] + resp, prompt_tokens = multimodal_answer( + add_images_context, oai_client, memories, q, images + ) + with open( + Path("evaluation/scripts/mmlongbench/eval/prompt_for_answer_extraction.md"), + encoding="utf-8", + ) as f: + prompt_local = f.read() + extracted_res_local = extract_answer(q, resp, prompt_local) + try: + pred_ans_local = ( + extracted_res_local.split("Answer format:")[0] + .split("Extracted answer:")[1] + .strip() + ) + except Exception: + pred_ans_local = resp.strip() + score_local = eval_score( + item.get("answer"), pred_ans_local, item.get("answer_format", "Str") + ) + sr = dict(item) + sr["response"] = resp + sr["extracted_res"] = extracted_res_local + sr["pred"] = pred_ans_local + sr["score"] = score_local + sr["q_norm"] = q_norm_local + sr["images"] = images + sr["prompt_tokens"] = prompt_tokens + sr["memory_used"] = memories + return sr + + with ThreadPoolExecutor(max_workers=workers) as executor: + futures = [executor.submit(_process_item, it) for it in doc_samples] + for f in as_completed(futures): + try: + res = f.result() + if res.get("skip"): + continue + print("[images used]:", res.get("images") or []) + print( + f"[QA] doc_id={doc_id} images={len(res.get('images') or [])} score={res.get('score')}" + ) + results.append( + {k: v for k, v in res.items() if k not in ("q_norm", "images")} + ) + completed_pairs.add((doc_id, res.get("q_norm") or "")) + + print("[Question]:", res.get("question")) + print("[Answer]:", res.get("pred")) + print("[Ground truth]:", res.get("answer")) + print("[Score]:", res.get("score")) + print("[Prompt Tokens]:", res.get("prompt_tokens")) + out_json = Path(output_path) + out_json.parent.mkdir(parents=True, exist_ok=True) + out_json.write_text( + json.dumps(results, ensure_ascii=False, indent=2), encoding="utf-8" + ) + acc, f1 = eval_acc_and_f1(results) + total_target = sum( + 1 for s in data if str(s.get("doc_id") or "") in doc_list + ) + total_tokens = sum(r.get("prompt_tokens", 0) for r in results) + avg_tokens = round(total_tokens / len(results), 2) if results else 0 + print( + f"[Metric] acc={acc} 
f1={f1} avg_tokens={avg_tokens} processed={len(results)}/{total_target}" + ) + except Exception as e: + print(f"[Error] processing item: {e}") + + show_results(results, show_path=str(report_path)) + print(f"[Report] saved to {report_path}") + + end_time = time.time() + total_duration = end_time - start_time + with open(report_path, "a", encoding="utf-8") as f: + f.write(f"\nTotal Evaluation Time: {total_duration:.2f} seconds\n") + print(f"[Time] Total duration: {total_duration:.2f}s") + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/mmlongbench/mmlongbench_search.py b/evaluation/scripts/mmlongbench/mmlongbench_search.py new file mode 100644 index 000000000..59c8e86ea --- /dev/null +++ b/evaluation/scripts/mmlongbench/mmlongbench_search.py @@ -0,0 +1,383 @@ +#!/usr/bin/env python3 + +import argparse +import json +import threading +import time + +from concurrent.futures import ThreadPoolExecutor, as_completed + +from evaluation.scripts.utils.metrics import Metrics + + +def load_samples(filepath: str) -> list[dict]: + """ + Read sample list from JSON file + """ + with open(filepath, encoding="utf-8") as f: + samples = json.load(f) + return samples + + +def memos_search( + client, user_id: str, query: str, top_k: int, mode: str, readable_cube_ids: list[str] +) -> tuple[list[str], list[str]]: + results = client.search( + query=query, user_id=user_id, readable_cube_ids=readable_cube_ids, top_k=top_k, mode=mode + ) + sources = [] + if "text_mem" in results["data"] and results["data"]["text_mem"]: + memories = results["data"]["text_mem"][0].get("memories", []) + sources.extend( + m["metadata"].get("sources", []) for m in memories if m["metadata"].get("sources", []) + ) + return [m.get("memory", "") for m in memories], sources + return [], [] + + +def mem0_search(client, user_id: str, query: str, top_k: int = 15) -> tuple[list[str], list[str]]: + res = client.search(query, user_id, top_k) + results = res.get("results", []) + mem_texts = [m.get("memory", "") for m in results if m.get("memory")] + return mem_texts, mem_texts + + +def supermemory_search( + client, user_id: str, query: str, top_k: int = 15 +) -> tuple[list[str], list[str]]: + chunk_list = client.search(query, user_id, top_k) + print(chunk_list) + + return chunk_list, chunk_list + + +def _get_lib_client(lib: str): + if lib == "memos": + from evaluation.scripts.utils.client import MemosApiClient + + return MemosApiClient() + if lib == "mem0": + from evaluation.scripts.utils.client import Mem0Client + + return Mem0Client(enable_graph=False) + if lib == "supermemory": + from evaluation.scripts.utils.client import SupermemoryClient + + return SupermemoryClient() + + +def run_concurrent_search( + lib: str, samples: list[dict], user_prefix: str, concurrency: int, top_k: int, mode: str +) -> dict: + """ + Execute concurrent search operations + + Args: + lib: Client name + samples: Sample list, each containing doc_id and question + user_prefix: User ID prefix + concurrency: Concurrency + top_k: Number of results to return + mode: Query mode ['fast', 'fine'] + + Returns: + Search results + """ + + client = _get_lib_client(lib) + metrics = Metrics() + total_samples = len(samples) + completed = 0 + completed_lock = threading.Lock() + + # 用于存储所有搜索结果 + all_results = [] + results_lock = threading.Lock() + + user_id = user_prefix + + def search_single(sample: dict, index: int): + nonlocal completed + + doc_id = sample.get("doc_id", "") + question = sample.get("question", "") + + user_id = user_prefix + "_" + doc_id + 
readable_cube_ids = [user_id] + start_time = time.perf_counter() + try: + memories, sources = [], [] + if lib == "memos": + memories, sources = memos_search( + client=client, + query=question, + user_id=user_id, + readable_cube_ids=readable_cube_ids, + top_k=top_k, + mode=mode, + ) + elif lib == "mem0": + memories, sources = mem0_search(client, user_id, question, top_k=top_k) + elif lib == "supermemory": + memories, sources = supermemory_search(client, user_id, question, top_k=top_k) + + duration = time.perf_counter() - start_time + metrics.record(duration, True) + + result = { + "index": index, + "doc_id": doc_id, + "question": question, + "answer": sample.get("answer", ""), + "evidence_pages": sample.get("evidence_pages", ""), + "evidence_sources": sample.get("evidence_sources", ""), + "answer_format": sample.get("answer_format", ""), + "doc_type": sample.get("doc_type", ""), + "memories": memories, + "sources": sources, + "memory_count": len(memories), + "success": True, + "duration_ms": duration * 1000, + "mode": mode, + } + + with results_lock: + all_results.append(result) + + with completed_lock: + completed += 1 + print( + f"[{completed}/{total_samples}] ✓ Success: {doc_id[:30]}... ({duration * 1000:.0f}ms, {len(memories)} memories)" + ) + + return True, result + + except Exception as e: + duration = time.perf_counter() - start_time + error_msg = str(e) + metrics.record(duration, False, error_msg) + + result = { + "index": index, + "doc_id": doc_id, + "question": question, + "answer": sample.get("answer", ""), + "evidence_pages": sample.get("evidence_pages", ""), + "evidence_sources": sample.get("evidence_sources", ""), + "answer_format": sample.get("answer_format", ""), + "doc_type": sample.get("doc_type", ""), + "memories": [], + "memory_count": 0, + "success": False, + "error": error_msg, + "duration_ms": duration * 1000, + } + + with results_lock: + all_results.append(result) + + with completed_lock: + completed += 1 + print( + f"[{completed}/{total_samples}] ✗ Failed: {doc_id[:30]}... - {error_msg[:80]}" + ) + + return False, result + + print(f"\nStarting concurrent search for {total_samples} questions...") + print(f"Concurrency: {concurrency}") + print(f"User ID: {user_id}") + print(f"Top-K: {top_k}") + print("-" * 60) + + start_time = time.time() + + with ThreadPoolExecutor(max_workers=concurrency) as executor: + futures = [] + for i, sample in enumerate(samples): + future = executor.submit(search_single, sample, i) + futures.append(future) + + # Wait for all tasks to complete + for future in as_completed(futures): + try: + future.result() + except Exception as e: + print(f"Task execution exception: {e}") + + end_time = time.time() + total_duration = end_time - start_time + + # Sort results by original index + all_results.sort(key=lambda x: x["index"]) + + # Print statistics + summary = metrics.summary() + + print("\n" + "=" * 60) + print("Search finished! 
Statistics:") + print("=" * 60) + print(f"Total duration: {total_duration:.2f}s") + print(f"Success: {summary['counts']['success']} / Failed: {summary['counts']['failed']}") + + if summary["stats"]: + stats = summary["stats"] + qps = stats["count"] / total_duration if total_duration > 0 else 0 + print(f"QPS: {qps:.2f}") + print("Latency stats (ms):") + print(f" Mean: {stats['mean']:.2f}") + print(f" Median: {stats['median']:.2f}") + print(f" Min: {stats['min']:.2f}") + print(f" Max: {stats['max']:.2f}") + print(f" P95: {stats['p95']:.2f}") + print(f" P99: {stats['p99']:.2f}") + + if summary["errors"]: + print("\nError statistics:") + for error, count in sorted(summary["errors"].items(), key=lambda x: x[1], reverse=True)[:5]: + print(f" [{count} times] {error[:100]}...") + + return {"summary": summary, "total_duration": total_duration, "results": all_results} + + +def parse_args(): + parser = argparse.ArgumentParser( + description="MMLongbench-doc Product Search Concurrent Script", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument("--lib", "-b", required=True, help="Product name to evaluate") + + parser.add_argument( + "--samples-file", + "-s", + default="evaluation/data/mmlongbench/samples.json", + help="Path to JSON file containing samples", + ) + + parser.add_argument( + "--api-url", + default="http://127.0.0.1:8001", + help="API service address (default: http://127.0.0.1:8001)", + ) + + parser.add_argument("--api-key", default="", help="API key (optional)") + + parser.add_argument("--workers", "-c", type=int, default=5, help="Concurrency (default: 5)") + + parser.add_argument( + "--timeout", type=float, default=120.0, help="Request timeout in seconds (default: 120)" + ) + + parser.add_argument( + "--top-k", + "-k", + type=int, + default=20, + help="Number of results to return per search (default: 20)", + ) + + parser.add_argument("--version-dir", "-v", default=None, help="Version directory name") + + parser.add_argument( + "--limit", + "-l", + type=int, + default=None, + help="Limit number of samples to process (for testing, default all)", + ) + + parser.add_argument( + "--mode", "-m", type=str, default="fast", help="Search mode (default: fast)" + ) + + return parser.parse_args() + + +def main(): + args = parse_args() + + print("=" * 60) + print("MMLongbench-doc Product Search Concurrent Tool") + print("=" * 60) + + # Read sample data + samples_path = "evaluation/data/mmlongbench/samples.json" + print(f"\nReading sample file: {samples_path}") + try: + samples = load_samples(samples_path) + print(f"Total {len(samples)} samples read") + + # Limit number of samples + if args.limit and args.limit > 0: + samples = samples[: args.limit] + print(f"Limiting to first {len(samples)} samples") + + if len(samples) == 0: + print("Error: Sample list is empty!") + return + + # Show first few samples + print("First 3 samples:") + for sample in samples[:3]: + doc_id = sample.get("doc_id", "N/A") + question = sample.get("question", "N/A")[:50] + print(f" - {doc_id}: {question}...") + if len(samples) > 3: + print(f" ... 
and {len(samples) - 3} more samples") + + except FileNotFoundError: + print(f"Error: File not found {args.samples_file}") + return + except json.JSONDecodeError as e: + print(f"Error: JSON parse failed - {e}") + return + except Exception as e: + print(f"Error: Failed to read file - {e}") + return + + # Execute concurrent search + result = run_concurrent_search( + lib=args.lib, + samples=samples, + user_prefix=args.version_dir, + concurrency=args.workers, + top_k=args.top_k, + mode=args.mode, + ) + + # Determine output file path + import os + + output_dir = os.path.join("evaluation/data/mmlongbench", args.version_dir) + os.makedirs(output_dir, exist_ok=True) + output_filename = f"{args.lib}_search_results.json" + output_path = os.path.join(output_dir, output_filename) + + # Save results + output_data = { + "summary": result["summary"], + "total_duration": result["total_duration"], + "config": { + "samples_file": args.samples_file, + "api_url": args.api_url, + "workers": args.workers, + "top_k": args.top_k, + }, + "results": result["results"], + } + + with open(output_path, "w", encoding="utf-8") as f: + json.dump(output_data, f, ensure_ascii=False, indent=2) + + print(f"\nResults saved to: {output_path}") + + # Calculate valid results + success_results = [r for r in result["results"] if r["success"]] + total_memories = sum(r["memory_count"] for r in success_results) + avg_memories = total_memories / len(success_results) if success_results else 0 + print(f"Average {avg_memories:.1f} memories returned per question") + + +if __name__ == "__main__": + main() diff --git a/evaluation/scripts/run_hotpot_eval.sh b/evaluation/scripts/run_hotpot_eval.sh new file mode 100755 index 000000000..3ee5177c6 --- /dev/null +++ b/evaluation/scripts/run_hotpot_eval.sh @@ -0,0 +1,48 @@ +#!/bin/bash +set -e + +ROOT_DIR=$(cd "$(dirname "$0")/../.." && pwd) +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR" + +# Common parameters +LIB="mem0" +WORKERS=20 +TOPK=7 +ADD_MODE="fine" +SEARCH_MODE="fine" +VERSION_DIR="test_0101_07" +ASYNC_MODE="sync" +CHAT_MODEL="gpt-4o-mini" +LIMIT=100 + +# Add / Ingestion +#echo "Running hotpot_ingestion.py..." +#python -m evaluation.scripts.hotpot.hotpot_ingestion \ +# --lib "$LIB" \ +# --workers "$WORKERS" \ +# --version-dir "$VERSION_DIR" \ +# --mode "$ADD_MODE" \ +# --async-mode "$ASYNC_MODE" \ +# --limit "$LIMIT" + +# Search +#echo "Running hotpot_search.py..." +#python -m evaluation.scripts.hotpot.hotpot_search \ +# --lib "$LIB" \ +# --workers "$WORKERS" \ +# --version-dir "$VERSION_DIR" \ +# --top-k "$TOPK" \ +# --search-mode "$SEARCH_MODE" \ +# --limit "$LIMIT" + +# Eval +echo "Running hotpot_eval.py..." +python -m evaluation.scripts.hotpot.hotpot_eval \ + --lib "$LIB" \ + --version-dir "$VERSION_DIR" \ + --workers "$WORKERS" \ + --search-mode "$SEARCH_MODE" \ + --chat-model "$CHAT_MODEL" + +echo "All scripts completed successfully!" 
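Note on the search output: the `*_search_results.json` file written by `mmlongbench_search.py` above is what the downstream eval step consumes. Below is a minimal sketch of how such a file could be inspected offline, assuming `--lib supermemory` and a `--version-dir` of `test_1230` (the values used in run_mmlongbench_eval.sh); the path and field names follow the `output_data` dict assembled in `main()` and the per-sample dicts built in `run_concurrent_search()`, so adjust them to your own run.

import json
from pathlib import Path

# Hypothetical run directory: evaluation/data/mmlongbench/<version_dir>/<lib>_search_results.json
results_path = Path("evaluation/data/mmlongbench/test_1230/supermemory_search_results.json")
data = json.loads(results_path.read_text(encoding="utf-8"))

# "summary" comes from Metrics.summary(); "counts" holds success/failure totals.
print("counts:", data["summary"]["counts"], "total_duration:", data["total_duration"])

# Each entry in "results" mirrors the per-sample dict built in run_concurrent_search().
for row in data["results"][:3]:
    status = "ok" if row["success"] else f"failed: {row.get('error', '')[:60]}"
    print(f"{row['doc_id']}: {row['memory_count']} memories ({status})")
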
diff --git a/evaluation/scripts/run_longbench_v2_eval.sh b/evaluation/scripts/run_longbench_v2_eval.sh index 917c57bfb..9af4572fc 100755 --- a/evaluation/scripts/run_longbench_v2_eval.sh +++ b/evaluation/scripts/run_longbench_v2_eval.sh @@ -1,110 +1,45 @@ #!/bin/bash +set -e -# Common parameters for all scripts -LIB="memos-api" -VERSION="long-bench-v2-1208-1556-async" -WORKERS=10 -TOPK=20 -MAX_SAMPLES="" # Empty means all samples -WAIT_INTERVAL=2 # seconds between polls -WAIT_TIMEOUT=900 # seconds per user - -# Parse command line arguments -while [[ $# -gt 0 ]]; do - case $1 in - --lib) - LIB="$2" - shift 2 - ;; - --version) - VERSION="$2" - shift 2 - ;; - --workers) - WORKERS="$2" - shift 2 - ;; - --top_k) - TOPK="$2" - shift 2 - ;; - --max_samples) - MAX_SAMPLES="$2" - shift 2 - ;; - *) - echo "Unknown option: $1" - exit 1 - ;; - esac -done - -# Build max_samples argument -MAX_SAMPLES_ARG="" -if [ -n "$MAX_SAMPLES" ]; then - MAX_SAMPLES_ARG="--max_samples $MAX_SAMPLES" -fi - -echo "Running LongBench v2 evaluation with:" -echo " LIB: $LIB" -echo " VERSION: $VERSION" -echo " WORKERS: $WORKERS" -echo " TOPK: $TOPK" -echo " MAX_SAMPLES: ${MAX_SAMPLES:-all}" -echo "" - -# Step 2: Search -echo "" -echo "==========================================" -echo "Step 2: Running longbench_v2_search.py..." -echo "==========================================" -python scripts/long_bench-v2/longbench_v2_search.py \ - --lib $LIB \ - --version $VERSION \ - --top_k $TOPK \ - --workers $WORKERS \ - $MAX_SAMPLES_ARG - -if [ $? -ne 0 ]; then - echo "Error running longbench_v2_search.py" - exit 1 -fi +ROOT_DIR=$(cd "$(dirname "$0")/../.." && pwd) +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR" -# Step 3: Response Generation -echo "" -echo "==========================================" -echo "Step 3: Running longbench_v2_responses.py..." -echo "==========================================" -python scripts/long_bench-v2/longbench_v2_responses.py \ - --lib $LIB \ - --version $VERSION \ - --workers $WORKERS - -if [ $? -ne 0 ]; then - echo "Error running longbench_v2_responses.py" - exit 1 -fi - -# Step 4: Metrics Calculation -echo "" -echo "==========================================" -echo "Step 4: Running longbench_v2_metric.py..." -echo "==========================================" -python scripts/long_bench-v2/longbench_v2_metric.py \ - --lib $LIB \ - --version $VERSION - -if [ $? -ne 0 ]; then - echo "Error running longbench_v2_metric.py" - exit 1 -fi - -echo "" -echo "==========================================" -echo "All steps completed successfully!" -echo "==========================================" -echo "" -echo "Results are saved in: results/long_bench-v2/$LIB-$VERSION/" -echo " - Search results: ${LIB}_longbench_v2_search_results.json" -echo " - Responses: ${LIB}_longbench_v2_responses.json" -echo " - Metrics: ${LIB}_longbench_v2_metrics.json" +# Common parameters +LIB="supermemory" +WORKERS=5 +TOPK=20 +ADD_MODE="fine" +SEARCH_MODE="fine" +VERSION_DIR="test_1231" +ASYNC_MODE="sync" +CHAT_MODEL="gpt-4o-mini" + +# Add / Ingestion +echo "Running longbench_v2_ingestion.py..." +python -m evaluation.scripts.longbench_v2.longbench_v2_ingestion \ + --lib "$LIB" \ + --workers "$WORKERS" \ + --version-dir "$VERSION_DIR" \ + --mode "$ADD_MODE" \ + --async-mode "$ASYNC_MODE" + +# Search +echo "Running longbench_v2_search.py..." 
+python -m evaluation.scripts.longbench_v2.longbench_v2_search \ + --lib "$LIB" \ + --workers "$WORKERS" \ + --version-dir "$VERSION_DIR" \ + --top-k "$TOPK" \ + --mode "$SEARCH_MODE" \ + --limit 30 + +# Eval +echo "Running longbench_v2_eval.py..." +python -m evaluation.scripts.longbench_v2.longbench_v2_eval \ + --lib "$LIB" \ + --version-dir "$VERSION_DIR" \ + --workers "$WORKERS" \ + --chat-model "$CHAT_MODEL" + +echo "All scripts completed successfully!" diff --git a/evaluation/scripts/run_mmlongbench_eval.sh b/evaluation/scripts/run_mmlongbench_eval.sh new file mode 100755 index 000000000..ff357f404 --- /dev/null +++ b/evaluation/scripts/run_mmlongbench_eval.sh @@ -0,0 +1,44 @@ +#!/bin/bash +set -e + +ROOT_DIR=$(cd "$(dirname "$0")/../.." && pwd) +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR" + +# Common parameters +LIB="supermemory" +WORKERS=10 +TOPK=20 +ADD_MODE="fine" +SEARCH_MODE="fine" +VERSION_DIR="test_1230" +ASYNC_MODE="sync" +CHAT_MODEL="gpt-4o-mini" + +## Add / Ingestion +#echo "Running mmlongbench_ingestion.py..." +#python -m evaluation.scripts.mmlongbench.mmlongbench_ingestion \ +# --lib "$LIB" \ +# --workers "$WORKERS" \ +# --version-dir "$VERSION_DIR" \ +# --mode "$ADD_MODE" \ +# --async-mode "$ASYNC_MODE" +# +## Search +#echo "Running mmlongbench_search.py..." +#python -m evaluation.scripts.mmlongbench.mmlongbench_search \ +# --lib "$LIB" \ +# --workers "$WORKERS" \ +# --version-dir "$VERSION_DIR" \ +# --top-k "$TOPK" \ +# --mode "$SEARCH_MODE" + +# Eval +echo "Running mmlongbench_eval.py..." +python -m evaluation.scripts.mmlongbench.mmlongbench_eval \ + --lib "$LIB" \ + --version-dir "$VERSION_DIR" \ + --workers "$WORKERS" \ + --chat-model "$CHAT_MODEL" + +echo "All scripts completed successfully!" diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py index 157c3f8ea..a7168f343 100644 --- a/evaluation/scripts/utils/client.py +++ b/evaluation/scripts/utils/client.py @@ -1,5 +1,6 @@ import json import os +import re import sys import time import uuid @@ -56,30 +57,22 @@ def __init__(self, enable_graph=False): self.enable_graph = enable_graph def add(self, messages, user_id, timestamp, batch_size=2): - max_retries = 5 for i in range(0, len(messages), batch_size): batch_messages = messages[i : i + batch_size] - for attempt in range(max_retries): - try: - if self.enable_graph: - self.client.add( - messages=batch_messages, - timestamp=timestamp, - user_id=user_id, - enable_graph=True, - ) - else: - self.client.add( - messages=batch_messages, - timestamp=timestamp, - user_id=user_id, - ) - break - except Exception as e: - if attempt < max_retries - 1: - time.sleep(2**attempt) - else: - raise e + if self.enable_graph: + self.client.add( + messages=batch_messages, + timestamp=timestamp, + user_id=user_id, + enable_graph=True, + ) + else: + self.client.add( + messages=batch_messages, + timestamp=timestamp, + user_id=user_id, + infer=False, + ) def search(self, query, user_id, top_k): res = self.client.search( @@ -143,56 +136,95 @@ def string_to_uuid(self, s: str, salt="memobase_client"): class MemosApiClient: - def __init__(self): - self.memos_url = os.getenv("MEMOS_URL") - self.headers = {"Content-Type": "application/json", "Authorization": os.getenv("MEMOS_KEY")} + """Product Add API 封装""" + + def __init__(self, timeout: float = 600.0): + self.base_url = os.getenv("MEMOS_URL") + self.headers = {"Content-Type": "application/json"} + self.timeout = timeout + + def add( + self, + messages, + user_id, + writable_cube_ids: list[str], + source_type: str, 
+        mode: str,
+        async_mode: str,
+    ):
+        """
+        Call the /product/add endpoint.
+
+        Args:
+            messages: Memory messages to add
+            user_id: User ID
+            writable_cube_ids: List of writable cube IDs
+            source_type: Source type
+            mode: Mode (fine/coarse)
+            async_mode: Async mode (sync/async)
+        """
+        url = f"{self.base_url}/product/add"
+
+        payload = {
+            "user_id": user_id,
+            "writable_cube_ids": writable_cube_ids,
+            "messages": messages,
+            "info": {"source_type": source_type},
+            "mode": mode,
+            "async_mode": async_mode,
+        }
+
+        response = requests.post(
+            url,
+            data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
+            headers=self.headers,
+            timeout=self.timeout,
+        )
+
+        if response.status_code != 200:
+            raise RuntimeError(f"HTTP {response.status_code}: {response.text}")
-    def add(self, messages, user_id, conv_id, batch_size: int = 9999):
+        body = response.json()
+        if body.get("code") is not None and body.get("code") != 200:
+            raise RuntimeError(f"BUSINESS ERROR {body.get('code')}: {response.text}")
+
+        return body
+
+    def search(self, query, user_id, readable_cube_ids: list[str], top_k: int, mode: str):
         """
-        messages = [{"role": "assistant", "content": data, "chat_time": date_str}]
+        Call the /product/search endpoint.
+
+        Args:
+            query: Search query
+            user_id: User ID
+            readable_cube_ids: List of readable cube IDs; defaults to [user_id]
+            top_k: Number of results to return
+            mode: Search mode
         """
-        url = f"{self.memos_url}/product/add"
-        added_memories = []
-        for i in range(0, len(messages), batch_size):
-            batch_messages = messages[i : i + batch_size]
-            payload = json.dumps(
-                {
-                    "messages": batch_messages,
-                    "user_id": user_id,
-                    "mem_cube_id": user_id,
-                    "conversation_id": conv_id,
-                }
-            )
-            response = requests.request("POST", url, data=payload, headers=self.headers)
-            assert response.status_code == 200, response.text
-            assert json.loads(response.text)["message"] == "Memory added successfully", (
-                response.text
-            )
-            added_memories += json.loads(response.text)["data"]
-        return added_memories
-    def search(self, query, user_id, top_k):
-        """Search memories."""
-        url = f"{self.memos_url}/product/search"
-        payload = json.dumps(
-            {
-                "query": query,
-                "user_id": user_id,
-                "mem_cube_id": user_id,
-                "conversation_id": "",
-                "top_k": top_k,
-                "mode": os.getenv("SEARCH_MODE", "fast"),
-                "include_preference": True,
-                "pref_top_k": 6,
-            },
-            ensure_ascii=False,
-        )
-        response = requests.request("POST", url, data=payload, headers=self.headers)
-        assert response.status_code == 200, response.text
-        assert json.loads(response.text)["message"] == "Search completed successfully", (
-            response.text
+        url = f"{self.base_url}/product/search"
+
+        if readable_cube_ids is None:
+            readable_cube_ids = [user_id]
+
+        payload = {
+            "query": query,
+            "user_id": user_id,
+            "readable_cube_ids": readable_cube_ids,
+            "top_k": top_k,
+            "mode": mode,
+        }
+
+        response = requests.post(
+            url,
+            data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
+            headers=self.headers,
+            timeout=self.timeout,
         )
-        return json.loads(response.text)["data"]
+
+        if response.status_code != 200:
+            raise RuntimeError(f"HTTP {response.status_code}: {response.text}")
+
+        return response.json()


 class MemosApiOnlineClient:
@@ -276,44 +308,65 @@ def search(self, query, user_id, top_k):

 class SupermemoryClient:
     def __init__(self):
-        from supermemory import Supermemory
-
-        self.client = Supermemory(api_key=os.getenv("SUPERMEMORY_API_KEY"))
-
-    def add(self, messages, user_id):
-        content = "\n".join(
-            [f"{msg['chat_time']} {msg['role']}: {msg['content']}" for msg in messages]
-        )
-        max_retries = 5
-        for attempt in range(max_retries):
-            try:
-                self.client.memories.add(content=content,
container_tag=user_id) - break - except Exception as e: - if attempt < max_retries - 1: - time.sleep(2**attempt) - else: - raise e - - def search(self, query, user_id, top_k): - max_retries = 10 - for attempt in range(max_retries): - try: - results = self.client.search.memories( - q=query, - container_tag=user_id, - threshold=0, - rerank=True, - rewrite_query=True, - limit=top_k, - ) - context = "\n\n".join([r.memory for r in results.results]) - return context - except Exception as e: - if attempt < max_retries - 1: - time.sleep(2**attempt) - else: - raise e + self.api_key = os.getenv("SUPERMEMORY_API_KEY") + if not self.api_key: + raise ValueError( + "SUPERMEMORY_API_KEY environment variable is not set. Please set it in your .env file or environment." + ) + self.add_url = "https://api.supermemory.ai/v3/documents" + self.search_url = "https://api.supermemory.ai/v3/search" + + def _sanitize_tag(self, s: str) -> str: + t = str(s).strip() + t = os.path.splitext(t)[0] + t = t.replace(" ", "_") + t = re.sub(r"[^A-Za-z0-9_-]", "_", t) + t = re.sub(r"[_-]+", "_", t) + t = t.strip("_") + t = t.lower() + if not re.match(r"^[a-z0-9]", t or ""): + t = f"tag_{t}" if t else "tag_default" + return t + + def add(self, content: str, user_id: str): + payload = { + "content": content, + "raw": content, + "containerTag": self._sanitize_tag(user_id), + } + + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + resp = requests.post(self.add_url, json=payload, headers=headers) + resp.raise_for_status() + return resp.json() + + def search(self, query: str, user_id: str, top_k: int): + payload = { + "q": query, + "limit": top_k, + "containerTags": [self._sanitize_tag(user_id)], + "rerank": True, + } + + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + resp = requests.post(self.search_url, json=payload, headers=headers) + resp.raise_for_status() + data = resp.json() + + chunk_list = [] + res = [entry.get("chunks") for entry in data.get("results", [])] + for chunks in res: + for chunk in chunks: + chunk_list.append(chunk["content"]) + + return chunk_list class MemuClient: diff --git a/evaluation/scripts/utils/eval_score.py b/evaluation/scripts/utils/eval_score.py new file mode 100644 index 000000000..02ef6eb53 --- /dev/null +++ b/evaluation/scripts/utils/eval_score.py @@ -0,0 +1,246 @@ +import re + +from collections import defaultdict +from math import isclose + + +def levenshtein_distance(s1, s2): + if len(s1) > len(s2): + s1, s2 = s2, s1 + + distances = range(len(s1) + 1) + for i2, c2 in enumerate(s2): + distances_ = [i2 + 1] + for i1, c1 in enumerate(s1): + if c1 == c2: + distances_.append(distances[i1]) + else: + distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1]))) + distances = distances_ + return distances[-1] + + +def anls_compute(groundtruth, prediction, threshold=0.5): + dist = levenshtein_distance(groundtruth, prediction) + length = max(len(groundtruth.upper()), len(prediction.upper())) + value = 0.0 if length == 0 else float(dist) / float(length) + anls = 1.0 - value + if anls <= threshold: + anls = 0.0 + return anls + + +def is_float_equal( + reference, prediction, include_percentage: bool = False, is_close: float = False +) -> bool: + def get_precision(gt_ans: float) -> int: + precision = 3 + if "." 
in str(gt_ans): + precision = len(str(gt_ans).split(".")[-1]) + return precision + + reference = float(str(reference).strip().rstrip("%").strip()) + try: + prediction = float(str(prediction).strip().rstrip("%").strip()) + except Exception: + return False + + gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference] + for item in gt_result: + try: + if is_close and isclose(item, prediction, rel_tol=0.01): + return True + precision = max(min(get_precision(prediction), get_precision(item)), 2) + if round(prediction, precision) == round(item, precision): + return True + except Exception: + continue + return False + + +def get_clean_string(s): + s = str(s).lower().strip() + + for suffix in ["mile", "miles", "million"]: + if s.endswith(suffix): + s = s[: -len(suffix)].strip() + + s = re.sub(r"\s*\([^)]*\)", "", s).strip() + s = re.sub(r"^['\"]|['\"]$", "", s).strip() + s = s.lstrip("$").rstrip("%").strip() + + return s + + +def is_exact_match(s): + flag = False + # Website + if "https://" in s: + flag = True + # code file + if s.endswith((".py", ".ipynb")) or s.startswith("page"): + flag = True + # telephone number + if re.fullmatch(r"\b\d+(-\d+|\s\d+)?\b", s): + flag = True + # time + if "a.m." in s or "p.m." in s: + flag = True + # YYYY-MM-DD + if re.fullmatch(r"\b\d{4}[-\s]\d{2}[-\s]\d{2}\b", s): + flag = True + # YYYY-MM + if re.fullmatch(r"\b\d{4}[-\s]\d{2}\b", s): + flag = True + # Email address + if re.fullmatch(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", s): + flag = True + return flag + + +def isfloat(num): + try: + float(num) + return True + except ValueError: + return False + + +def eval_score(gt, pred, answer_type): + if answer_type == "Int": + try: + gt, pred = int(gt), int(float(pred)) + except Exception: + pred = "" + score = gt == pred + elif answer_type == "Float": + try: + gt = float(get_clean_string(str(gt))) + pred = float(get_clean_string(str(pred))) + except Exception: + pred = "" + score = is_float_equal(gt, pred, include_percentage=True, is_close=True) + elif answer_type in ["Str", "None"]: + gt = get_clean_string(gt) + pred = get_clean_string(pred) + score = gt == pred if is_exact_match(gt) else anls_compute(gt, pred) + else: + if isinstance(gt, str) and gt.startswith("["): + gt = eval(gt) + if not isinstance(gt, list): + gt = [gt] + if isinstance(pred, str) and pred.startswith("["): + pred = eval(pred) + if not isinstance(pred, list): + pred = [pred] + print(len(gt), len(pred)) + if len(gt) != len(pred): + score = 0.0 + else: + gt = sorted([get_clean_string(a) for a in gt]) + pred = sorted([get_clean_string(a) for a in pred]) + print(gt, pred) + if isfloat(gt[0]) or is_exact_match(gt[0]): + score = "-".join(gt) == "-".join(pred) + else: + score = min( + [anls_compute(gt_v, pred_v) for gt_v, pred_v in zip(gt, pred, strict=False)] + ) + + return float(score) + + +def eval_acc_and_f1(samples): + evaluated_samples = [sample for sample in samples if "score" in sample] + if not evaluated_samples: + return 0.0, 0.0 + + acc = sum([sample["score"] for sample in evaluated_samples]) / len(evaluated_samples) + try: + recall = sum( + [ + sample["score"] + for sample in evaluated_samples + if sample["answer"] != "Not answerable" + ] + ) / len([sample for sample in evaluated_samples if sample["answer"] != "Not answerable"]) + precision = sum( + [ + sample["score"] + for sample in evaluated_samples + if sample["answer"] != "Not answerable" + ] + ) / len([sample for sample in evaluated_samples if sample["pred"] != "Not answerable"]) + f1 = 2 
* recall * precision / (recall + precision) if (recall + precision) > 0.0 else 0.0 + except Exception: + f1 = 0.0 + + return acc, f1 + + +def show_results(samples, show_path=None): + for sample in samples: + sample["evidence_pages"] = eval(sample["evidence_pages"]) + sample["evidence_sources"] = eval(sample["evidence_sources"]) + + with open(show_path, "w") as f: + acc, f1 = eval_acc_and_f1(samples) + f.write(f"Overall Acc: {acc} | Question Number: {len(samples)}\n") + f.write(f"Overall F1-score: {f1} | Question Number: {len(samples)}\n") + f.write("-----------------------\n") + + acc_single_page, _ = eval_acc_and_f1( + [sample for sample in samples if len(sample["evidence_pages"]) == 1] + ) + acc_multi_page, _ = eval_acc_and_f1( + [ + sample + for sample in samples + if len(sample["evidence_pages"]) != 1 and sample["answer"] != "Not answerable" + ] + ) + acc_neg, _ = eval_acc_and_f1( + [sample for sample in samples if sample["answer"] == "Not answerable"] + ) + + f.write( + "Single-page | Accuracy: {} | Question Number: {}\n".format( + acc_single_page, + len([sample for sample in samples if len(sample["evidence_pages"]) == 1]), + ) + ) + f.write( + "Cross-page | Accuracy: {} | Question Number: {}\n".format( + acc_multi_page, + len( + [ + sample + for sample in samples + if len(sample["evidence_pages"]) != 1 + and sample["answer"] != "Not answerable" + ] + ), + ) + ) + f.write( + "Unanswerable | Accuracy: {} | Question Number: {}\n".format( + acc_neg, len([sample for sample in samples if sample["answer"] == "Not answerable"]) + ) + ) + f.write("-----------------------\n") + + source_sample_dict, document_type_dict = defaultdict(list), defaultdict(list) + for sample in samples: + for answer_source in sample["evidence_sources"]: + source_sample_dict[answer_source].append(sample) + document_type_dict[sample["doc_type"]].append(sample) + for type, sub_samples in source_sample_dict.items(): + f.write( + f"Evidence Sources: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n" + ) + + f.write("-----------------------\n") + for type, sub_samples in document_type_dict.items(): + f.write( + f"Document Type: {type} | Accuracy: {eval_acc_and_f1(sub_samples)[0]} | Question Number: {len(sub_samples)}\n" + ) diff --git a/evaluation/scripts/utils/extract_answer.py b/evaluation/scripts/utils/extract_answer.py new file mode 100644 index 000000000..e527d4b97 --- /dev/null +++ b/evaluation/scripts/utils/extract_answer.py @@ -0,0 +1,58 @@ +import os + +from pathlib import Path + +import openai + +from dotenv import load_dotenv + + +load_dotenv() + +client = openai.Client( + api_key=os.getenv("OPENAI_API_KEY", "sk-xxxxx"), + base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), +) + +PROMPT_PATH = Path("evaluation/scripts/utils/prompt_for_answer_extraction.md") +with open(PROMPT_PATH, encoding="utf-8") as f: + EXTRACTION_PROMPT = f.read() + + +def extract_answer(question: str, output: str, model_name: str = "gpt-4o-mini") -> str: + resp = client.chat.completions.create( + model=model_name, + messages=[ + {"role": "user", "content": EXTRACTION_PROMPT}, + {"role": "assistant", "content": f"\n\nQuestion:{question}\nAnalysis:{output}\n"}, + ], + temperature=0.0, + max_tokens=256, + top_p=1, + frequency_penalty=0, + presence_penalty=0, + ) + content = resp.choices[0].message.content or "" + return content + + +def parse_extracted_answer(extracted_res: str, fallback_output: str) -> str: + try: + head = extracted_res.split("Answer format:")[0] + ans = 
head.split("Extracted answer:")[1].strip() + if ans: + return ans + except Exception: + pass + text = (fallback_output or "").strip() + low = text.lower() + if " yes" in low or low.startswith("yes"): + return "yes" + if " no" in low or low.startswith("no"): + return "no" + for sep in ["\n", ". ", ".", "?", "!"]: + if sep in text: + cand = text.split(sep)[0].strip() + if cand: + return cand + return text diff --git a/evaluation/scripts/utils/metrics.py b/evaluation/scripts/utils/metrics.py new file mode 100644 index 000000000..135a60cec --- /dev/null +++ b/evaluation/scripts/utils/metrics.py @@ -0,0 +1,56 @@ +import threading + + +class Metrics: + def __init__(self): + self.times_ms: list[float] = [] + self.success_count = 0 + self.fail_count = 0 + self.errors = {} + self.lock = threading.Lock() + + def record(self, duration_s: float, success: bool, error_msg: str | None = None): + ms = duration_s * 1000.0 + with self.lock: + if success: + self.times_ms.append(ms) + self.success_count += 1 + else: + self.fail_count += 1 + if error_msg: + short_err = error_msg[:200] if len(error_msg) > 200 else error_msg + self.errors[short_err] = self.errors.get(short_err, 0) + 1 + + def summary(self) -> dict: + with self.lock: + if not self.times_ms: + return { + "stats": {}, + "counts": {"success": self.success_count, "failed": self.fail_count}, + "errors": dict(self.errors), + } + sorted_times = sorted(self.times_ms) + n = len(sorted_times) + + def percentile(p: int): + if n == 1: + return sorted_times[0] + k = max(0, min(n - 1, round((p / 100) * (n - 1)))) + return sorted_times[k] + + mean = sum(sorted_times) / n + variance = sum((x - mean) ** 2 for x in sorted_times) / (n - 1) if n > 1 else 0.0 + return { + "stats": { + "count": n, + "mean": mean, + "median": percentile(50), + "min": sorted_times[0], + "max": sorted_times[-1], + "p95": percentile(95), + "p99": percentile(99), + "std": variance**0.5, + }, + "counts": {"success": self.success_count, "failed": self.fail_count}, + "errors": dict(self.errors), + } diff --git a/evaluation/scripts/utils/prompt_for_answer_extraction.md b/evaluation/scripts/utils/prompt_for_answer_extraction.md new file mode 100644 index 000000000..a309c0935 --- /dev/null +++ b/evaluation/scripts/utils/prompt_for_answer_extraction.md @@ -0,0 +1,35 @@ +Given the question and analysis, you are tasked to extract answers with required formats from the free-form analysis. +- Your extracted answers should be one of the following formats: (1) Integer, (2) Float, (3) String and (4) List. If you find the analysis the question can not be answered from the given documents, type "Not answerable". Exception: If the analysis only tells you that it can not read/understand the images or documents, type "Fail to answer". +- Please make your response as concise as possible. Also note that your response should be formatted as below: +``` +Extracted answer: [answer] +Answer format: [answer format] +``` + +Please read the following example, then extract the answer from the model response and type it at the end of the prompt. + +--- +Question: List the primary questions asked about the services in this report. +Analysis: The primary questions asked about the services in the report for The Limes Residential Home are:\n\n1. Is the service safe?\n2. Is the service effective?\n3. Is the service caring?\n4. Is the service responsive?\n5. Is the service well-led? 
+Extracted answer: ['Is the servife safe?', 'Is the service effective', 'Is the serve caring?', 'Is the service responsive?', 'Is the service well-led?'] +Answer format: List + +--- +Question: How many regulations of the HSCA 2008 are breached in all according to this report? +Analysis: According to the report, the provider breached 10 Health and Social Care Act 2008 (Regulated Activities) Regulations in total. Here are the specifics:\n\n1. Regulation 13: Safeguarding service users from abuse and improper treatment\n2. Regulation 12: Safe care and treatment\n3. Regulation 18: Staffing\n4. Regulation 11: Need for consent\n5. Regulation 10: Dignity and respect\n6. Regulation 9: Person-centred care\n7. Regulation 17: Good governance\n8. Regulation 18 (CQC Registration Regulations 2009): Notification of other incidents\n9. Regulation 18: Failure to maintain an accurate and up-to-date care plan\n10. Regulation 11: Failure to implement the Mental Capacity Act 2005 code of practice effectively\n\nThese breaches involve issues concerning staffing, safeguarding, medicines management, dignity and respect, consent, care planning, governance, and failure to notify the CQC of incidents. +Extracted answer: 10 +Answer format: Integer + +--- +Question: According to the survey that is the percentage of Chinese who are paying more or about the same attention to politics after Trump's election? +Analysis: The survey provided does not specify the percentage of Chinese individuals specifically who are paying more or about the same attention to politics after Trump's election. The report focuses primarily on American demographics and does not include specific details about the Chinese population in relation to this question. If you need information about a different demographic or a summary of the findings from the American demographic, I can certainly help with that! +Extracted answer: Not answerable +Answer format: String + +--- +Question: How many quotations from male respondent over 50 years old are included in this report? +Analysis: The image you've provided appears to be a screenshot of a document with multiple charts. However, the text is too small and blurry to read accurately. If you can provide a clearer image or more context, I might be able to help you with your question. +Extracted answer: Fail to answer +Answer format: String + +--- diff --git a/evaluation/scripts/utils/prompts.py b/evaluation/scripts/utils/prompts.py index 32e6d6729..ba5f5db8f 100644 --- a/evaluation/scripts/utils/prompts.py +++ b/evaluation/scripts/utils/prompts.py @@ -65,6 +65,55 @@ {context} """ +MMLONGBENCH_ANSWER_PROMPT = """ + You are a helpful assistant that can answer questions based on the provided memories and images. + + {memories} + + Read the above memories and answer this question + Please make your answer as concise as possible. +""" + +LONGBENCH_V2_ANSWER_PROMPT = """ +Please read the following retrieved text chunks and answer the question below. + + +$DOC$ + + +What is the correct answer to this question: $Q$ +Choices: +(A) $C_A$ +(B) $C_B$ +(C) $C_C$ +(D) $C_D$ + +Format your response as follows: "The correct answer is (insert answer here)". +""" + + +HOTPOT_ANSWER_PROMPT = """ +You are answering a question from the HotpotQA dataset. + +The question may require multi-hop reasoning across multiple supporting facts. +Carefully read the provided context and identify the relevant evidence. +Reason step by step to connect the facts and determine the correct answer. 
+ +Important instructions: +- Use only the information provided in the context. +- Perform multi-step reasoning internally if needed. +- The final answer must be a short factual answer (e.g., a name, place, date, or entity). +- Do NOT include explanations, reasoning steps, or citations in the final output. + +Question: +{question} + +Context: +{context} + +Final Answer: + +""" ZEP_CONTEXT_TEMPLATE = """ FACTS and ENTITIES represent relevant context to the current conversation.