
Commit 8ea4828

CaralHsi, CarltonXiang, and fridayL authored
feat: update evaluation; add general string parser (#715)
* hotfix: hotfix
* test: add routers api
* fix: doc fine mode bug
* fix: doc fine mode bug
* feat: init longbench_v2
* feat: more strict embedder truncation
* feat: parallel processing fine mode in multi-modal-fine
* feat: update parsers; add chunk info into source; remove origin_part
* feat: modify chunk_content in file-fine-parser
* fix: token counter bug
* feat: enlarge polardb
* feat: decrease parallel
* feat: add image parser in file
* feat: update file_content_parser
* feat: modify long_bench_v2
* feat: modify long_bench_v2
* fix: image bug
* feat: increase playground depth
* feat: set parsed_text to None in file parser
* fix: file_ids bug in file-mode
* feat: update evaluation
* feat: update evaluation
* feat: add general string prompt
* fix: test server router
* feat: update evaluation
* feat: decrease graph-db batch size to 5
* fix: default name in long_bench-v2/longbench_v2_search
* fix: test bug
* Update test_server_router.py
* Update test_product_router.py
* feat: comment

---------

Co-authored-by: HarveyXiang <[email protected]>
Co-authored-by: fridayL <[email protected]>
Co-authored-by: chunyu li <[email protected]>
1 parent 4e2d87f commit 8ea4828

File tree

9 files changed: +486 / -221 lines


evaluation/scripts/long_bench-v2/longbench_v2_ingestion.py

Lines changed: 2 additions & 2 deletions
@@ -33,7 +33,7 @@ def ingest_sample(
     # Get context and convert to messages
     context = sample.get("context", "")

-    # For memos, we ingest the context as document content
+    # For memos, we ingest the context as a raw document content
     messages = [
         {
             "type": "file",
@@ -185,7 +185,7 @@ def main(frame, version="default", num_workers=10, max_samples=None):
     parser.add_argument(
         "--workers",
         type=int,
-        default=3,
+        default=2,
         help="Number of parallel workers",
     )
     parser.add_argument(
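Note: the second hunk only lowers the default parallelism. A minimal sketch of what the changed default means in practice (the parser fragment mirrors the hunk; the invocations are illustrative, not from the commit):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--workers",
    type=int,
    default=2,  # was 3 before this commit
    help="Number of parallel workers",
)

# Omitting the flag now yields 2 workers; an explicit flag still wins.
print(parser.parse_args([]).workers)                  # -> 2
print(parser.parse_args(["--workers", "8"]).workers)  # -> 8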

evaluation/scripts/long_bench-v2/longbench_v2_metric.py

Lines changed: 94 additions & 63 deletions
@@ -4,75 +4,80 @@


 def calculate_accuracy(responses):
-    """Calculate accuracy metrics for LongBench v2."""
+    """Calculate accuracy metrics for LongBench v2.
+
+    Logic is aligned with longbench_stx.print_metrics, but returns a dict
+    and additionally computes by_domain statistics.
+    """
     total = len(responses)
     if total == 0:
         return {}

-    # Overall accuracy
-    correct = sum(1 for r in responses if r.get("judge", False))
-    overall_acc = round(100 * correct / total, 1)
-
-    # By difficulty
-    easy_items = [r for r in responses if r.get("difficulty") == "easy"]
-    hard_items = [r for r in responses if r.get("difficulty") == "hard"]
-    easy_acc = (
-        round(100 * sum(1 for r in easy_items if r.get("judge", False)) / len(easy_items), 1)
-        if easy_items
-        else 0.0
-    )
-    hard_acc = (
-        round(100 * sum(1 for r in hard_items if r.get("judge", False)) / len(hard_items), 1)
-        if hard_items
-        else 0.0
-    )
-
-    # By length
-    short_items = [r for r in responses if r.get("length") == "short"]
-    medium_items = [r for r in responses if r.get("length") == "medium"]
-    long_items = [r for r in responses if r.get("length") == "long"]
-
-    short_acc = (
-        round(100 * sum(1 for r in short_items if r.get("judge", False)) / len(short_items), 1)
-        if short_items
-        else 0.0
-    )
-    medium_acc = (
-        round(100 * sum(1 for r in medium_items if r.get("judge", False)) / len(medium_items), 1)
-        if medium_items
-        else 0.0
-    )
-    long_acc = (
-        round(100 * sum(1 for r in long_items if r.get("judge", False)) / len(long_items), 1)
-        if long_items
-        else 0.0
-    )
-
-    # By domain
+    # Counters (aligned with longbench_stx.print_metrics)
+    easy = hard = short = medium = long = 0
+    easy_acc = hard_acc = short_acc = medium_acc = long_acc = 0
+    total_prompt_tokens = 0
+
+    for pred in responses:
+        acc = int(pred.get("judge", False))
+        diff = pred.get("difficulty", "easy")
+        length = pred.get("length", "short")
+
+        pt = pred.get("prompt_tokens")
+        if isinstance(pt, int | float):
+            total_prompt_tokens += int(pt)
+
+        if diff == "easy":
+            easy += 1
+            easy_acc += acc
+        else:
+            hard += 1
+            hard_acc += acc
+
+        if length == "short":
+            short += 1
+            short_acc += acc
+        elif length == "medium":
+            medium += 1
+            medium_acc += acc
+        else:
+            long += 1
+            long_acc += acc
+
+    o_acc = round(100 * (easy_acc + hard_acc) / total, 2)
+    e_acc = round(100 * easy_acc / easy, 2) if easy > 0 else 0.0
+    h_acc = round(100 * hard_acc / hard, 2) if hard > 0 else 0.0
+    s_acc = round(100 * short_acc / short, 2) if short > 0 else 0.0
+    m_acc = round(100 * medium_acc / medium, 2) if medium > 0 else 0.0
+    l_acc = round(100 * long_acc / long, 2) if long > 0 else 0.0
+
+    # Additional by-domain stats (extra vs. stx)
     domain_stats = {}
-    for response in responses:
-        domain = response.get("domain", "Unknown")
+    for r in responses:
+        domain = r.get("domain", "Unknown")
         if domain not in domain_stats:
             domain_stats[domain] = {"total": 0, "correct": 0}
         domain_stats[domain]["total"] += 1
-        if response.get("judge", False):
+        if r.get("judge", False):
            domain_stats[domain]["correct"] += 1

     domain_acc = {
-        domain: round(100 * stats["correct"] / stats["total"], 1)
+        domain: round(100 * stats["correct"] / stats["total"], 2)
         for domain, stats in domain_stats.items()
     }

     return {
-        "overall": overall_acc,
-        "easy": easy_acc,
-        "hard": hard_acc,
-        "short": short_acc,
-        "medium": medium_acc,
-        "long": long_acc,
+        "overall": o_acc,
+        "easy": e_acc,
+        "hard": h_acc,
+        "short": s_acc,
+        "medium": m_acc,
+        "long": l_acc,
         "by_domain": domain_acc,
         "total_samples": total,
-        "correct_samples": correct,
+        "correct_samples": easy_acc + hard_acc,
+        "total_prompt_tokens": total_prompt_tokens,
+        "avg_prompt_tokens": round(total_prompt_tokens / total, 2) if total > 0 else 0.0,
     }
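A quick smoke test of the new counter-based logic (the import path and sample records are assumptions, not part of the commit; note that isinstance(pt, int | float) requires Python 3.10+):

# Hypothetical smoke test; records and import path are illustrative only.
from longbench_v2_metric import calculate_accuracy

responses = [
    {"judge": True, "difficulty": "easy", "length": "short", "prompt_tokens": 100},
    {"judge": False, "difficulty": "hard", "length": "long", "prompt_tokens": 200},
]
m = calculate_accuracy(responses)

assert m["overall"] == 50.0             # 1 of 2 judged correct
assert m["easy"] == 100.0 and m["hard"] == 0.0
assert m["correct_samples"] == 1        # easy_acc + hard_acc after the loop
assert m["avg_prompt_tokens"] == 150.0  # (100 + 200) / 2
assert m["by_domain"] == {"Unknown": 50.0}  # missing "domain" key -> "Unknown"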

@@ -92,11 +97,36 @@ def main(frame, version="default"):
     with open(responses_path, encoding="utf-8") as f:
         responses = json.load(f)

-    # Only keep entries with non-empty context (search_context) to align with response generation
-    filtered = [r for r in responses if str(r.get("search_context", "")).strip() != ""]
-
-    # Calculate metrics
-    metrics = calculate_accuracy(filtered)
+    # Only keep entries that actually have search results:
+    # - For new pipeline: non-empty memories_used list
+    # - For older runs: non-empty search_context string
+    def _has_search_results(r: dict) -> bool:
+        mems = r.get("memories_used")
+        if isinstance(mems, list) and any(str(m).strip() for m in mems):
+            return True
+        ctx = str(r.get("search_context", "")).strip()
+        return ctx != ""
+
+    filtered = [r for r in responses if _has_search_results(r)]
+
+    # Calculate metrics (handle case where no samples have search results)
+    if not filtered:
+        print("⚠️ No responses with valid search results were found. Metrics will be zeroed.")
+        metrics = {
+            "overall": 0.0,
+            "easy": 0.0,
+            "hard": 0.0,
+            "short": 0.0,
+            "medium": 0.0,
+            "long": 0.0,
+            "by_domain": {},
+            "total_samples": 0,
+            "correct_samples": 0,
+            "total_prompt_tokens": 0,
+            "avg_prompt_tokens": 0.0,
+        }
+    else:
+        metrics = calculate_accuracy(filtered)

     # Save metrics
     output_path = f"results/long_bench_v2/{frame}-{version}/{frame}_longbench_v2_metrics.json"
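How the new filter treats edge cases, with the helper copied verbatim from the hunk above (the sample records are illustrative):

def _has_search_results(r: dict) -> bool:
    mems = r.get("memories_used")
    if isinstance(mems, list) and any(str(m).strip() for m in mems):
        return True
    ctx = str(r.get("search_context", "")).strip()
    return ctx != ""

assert _has_search_results({"memories_used": ["mem-1"]})           # new pipeline
assert _has_search_results({"search_context": "retrieved text"})   # older runs
assert not _has_search_results({"memories_used": ["", "  "]})      # whitespace-only entries fall through
assert not _has_search_results({"memories_used": [], "search_context": "   "})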
@@ -112,12 +142,13 @@ def main(frame, version="default"):
     # Print summary table
     print("\n📊 Summary of Results:")
     print("-" * 80)
-    print(f"{'Overall Accuracy':<30s}: {metrics['overall']:.1f}%")
-    print(f"{'Easy':<30s}: {metrics['easy']:.1f}%")
-    print(f"{'Hard':<30s}: {metrics['hard']:.1f}%")
-    print(f"{'Short':<30s}: {metrics['short']:.1f}%")
-    print(f"{'Medium':<30s}: {metrics['medium']:.1f}%")
-    print(f"{'Long':<30s}: {metrics['long']:.1f}%")
+    print(f"{'Overall Accuracy':<30s}: {metrics['overall']:.2f}%")
+    print(f"{'Easy':<30s}: {metrics['easy']:.2f}%")
+    print(f"{'Hard':<30s}: {metrics['hard']:.2f}%")
+    print(f"{'Short':<30s}: {metrics['short']:.2f}%")
+    print(f"{'Medium':<30s}: {metrics['medium']:.2f}%")
+    print(f"{'Long':<30s}: {metrics['long']:.2f}%")
+    print(f"{'Avg Prompt Tokens':<30s}: {metrics.get('avg_prompt_tokens', 0.0):.2f}")
     print("\nBy Domain:")
     for domain, acc in metrics["by_domain"].items():
         print(f"  {domain:<28s}: {acc:.1f}%")
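For reference on the format changes: :<30s left-pads each label to 30 columns, :.2f is the new two-decimal precision, and metrics.get(...) keeps metric files written before avg_prompt_tokens existed from raising a KeyError. An illustrative run with made-up values:

metrics = {"overall": 50.0}  # made-up; no avg_prompt_tokens key present
print(f"{'Overall Accuracy':<30s}: {metrics['overall']:.2f}%")
# label padded to 30 columns, then "50.00%"
print(f"{'Avg Prompt Tokens':<30s}: {metrics.get('avg_prompt_tokens', 0.0):.2f}")
# .get falls back to 0.0, printing "0.00" instead of raising KeyError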
