Commit 5885285
Ce add benchmark test (#3262)
* add repetition early stop cases
* add repetition early stop cases
* add bad cases
* add bad cases
* add evil cases
* add benchmark gsm8k
1 parent 55ac449 commit 5885285

File tree

3 files changed (+192, -0 lines changed)

test/ce/server/gsm8k.parquet

409 KB
Binary file not shown.

test/ce/server/gsm8k.py

Lines changed: 188 additions & 0 deletions
@@ -0,0 +1,188 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @author DDDivano
# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python


import os
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from urllib.parse import urlparse, urlunparse

import openai
from datasets import load_dataset
from tqdm import tqdm

BASELINE = {
    "0.3B": 0.05,
    "21B": 0.49,
    "300B": 0.96,
}
baseline = BASELINE.get(os.environ.get("MODEL"), None)
base_url = os.environ.get("URL", None)
atol = 0.03
if baseline is None:
    raise ValueError(f"Invalid MODEL value '{os.environ.get('MODEL')}', expected one of {list(BASELINE.keys())}")
if base_url is None:
    raise ValueError(
        "Environment variable 'URL' is not set. "
        "Please specify the inference service address, e.g., 'http://localhost:8191/v1'."
    )
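
# Note (editorial, not part of the original commit): BASELINE maps each model size
# to its expected GSM8K accuracy as a fraction; the assertion at the end of this
# script fails if the measured average deviates from that baseline by more than
# atol = 0.03, i.e. three percentage points.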


def strip_path_suffix(url: str, suffix: str = "chat/completions") -> str:
    """
    Strip the given path suffix (e.g. chat/completions) from a URL.
    """
    parsed = urlparse(url)
    # Remove the trailing suffix (making sure only the end of the path is removed)
    if parsed.path.endswith("/" + suffix):
        new_path = parsed.path[: -(len(suffix) + 1)]  # +1 for the slash
    else:
        new_path = parsed.path
    # Rebuild the URL
    cleaned_url = urlunparse(
        (
            parsed.scheme,
            parsed.netloc,
            new_path.rstrip("/"),  # drop any trailing slash
            "",
            "",
            "",  # ignore params/query/fragment
        )
    )
    return cleaned_url
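
# Example (illustrative, not in the original commit):
#   strip_path_suffix("http://localhost:8191/v1/chat/completions")
#   returns "http://localhost:8191/v1", the base_url form the OpenAI client expects.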


# ========== OpenAI client configuration ==========
client = openai.OpenAI(
    api_key="DDDivano",
    # base_url="http://占位:8187/v1"
    base_url=strip_path_suffix(base_url),
)

model_name = "eb"
max_samples = 690
max_tokens = 12288
max_workers = 33

# ========== Load the dataset ==========
dataset = load_dataset("parquet", data_files="gsm8k.parquet", split="train")
dataset = dataset.select(range(min(len(dataset), max_samples)))
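
# Note (editorial): the parquet file is expected to hold GSM8K-style records with
# "question" and "answer" columns, where the reference answer ends in the usual
# "#### <number>" line parsed below.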


# ========== Extract the final answer in "#### <number>" format from the ground truth ==========
def extract_gt_answer(text):
    match = re.search(r"####\s*([\d,]+(?:\.\d+)?)", text)
    if match:
        return match.group(1).replace(",", "").strip()
    return None


# ========== Extract the number from the last line of the model output ==========
def extract_model_answer(text):
    if not text:
        return None
    text = text.replace(",", "").replace("$", "")
    lines = text.strip().splitlines()
    last_line = lines[-1] if lines else text
    match = re.search(r"-?\d+(?:\.\d+)?", last_line)
    return match.group(0) if match else None
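
# Example (illustrative, not in the original commit): for a ground-truth answer ending
# in "#### 72", extract_gt_answer returns "72"; for a model reply whose last line is
# "So the final answer is 72.", extract_model_answer also returns "72".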


# ========== Numeric comparison ==========
def is_answer_equal(pred, gt, tol=1e-6):
    if pred is None or gt is None:
        return False
    try:
        return abs(float(pred) - float(gt)) < tol
    except (TypeError, ValueError):
        return pred == gt


# ========== Build the prompt ==========
def build_prompt(sample):
    # The prompt (in Chinese) asks the model to solve the math problem and to put
    # the final numeric answer at the very end of its reply.
    return f"以下是一个数学问题,请直接给出最终答案。一定要把最终答案数字在最后输出。\n\n问题:{sample['question']}\n\n答案:"


# ========== Model request ==========
def query_model(prompt):
    try:
        response = client.chat.completions.create(
            model=model_name,
            messages=[
                # System message (in Chinese): "You are a mathematics expert who answers
                # math questions rigorously."
                {"role": "system", "content": "你是一个数学专家,擅长严谨地解答数学问题。"},
                {"role": "user", "content": prompt},
            ],
            temperature=1.0,
            top_p=0.8,
            max_tokens=max_tokens,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        return f"[Error] {e}"
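
# Note (editorial): failed requests are returned as "[Error] ..." strings rather than
# raising, so they flow through the normal answer extraction and are simply scored as
# incorrect instead of aborting the run.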


# ========== Per-sample evaluation ==========
def evaluate_sample(sample):
    prompt = build_prompt(sample)
    model_output = query_model(prompt)

    gt_value = extract_gt_answer(sample["answer"])
    pred_value = extract_model_answer(model_output)
    is_correct = is_answer_equal(pred_value, gt_value)

    result = {
        "question": sample["question"],
        "gt_answer": gt_value,
        "model_answer": pred_value,
        "raw_gt_answer": sample["answer"],
        "raw_model_output": model_output,
        "is_correct": is_correct,
    }

    return result


# ========== Main flow ==========

acc = []
times = 3

for i in range(times):
    correct = 0
    total = 0
    results = []

    print(f"🚀 Starting evaluation with {max_workers} threads...")

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(evaluate_sample, sample) for sample in dataset]
        for future in tqdm(as_completed(futures), total=len(futures), desc="Evaluating"):
            result = future.result()
            results.append(result)
            total += 1
            if result["is_correct"]:
                correct += 1
            else:
                print("\n❌ Wrong prediction:")
                print(f"Q: {result['question']}")
                print(f"GT: {result['gt_answer']}")
                print(f"Model: {result['model_answer']}")
                print(f"Full GT: {result['raw_gt_answer']}")
                print(f"Model Output: {result['raw_model_output']}")

    # ========== Report accuracy for this run ==========
    accuracy = correct / total * 100 if total > 0 else 0.0
    print(f"\n🎯 Evaluation Complete: Accuracy = {accuracy:.2f}% ({correct}/{total})")
    acc.append(accuracy)

avg_acc = round(sum(acc) / times / 100, 4)  # convert the average percentage back to a fraction
print(f"平均准确率:{avg_acc * 100:.2f}%")

# The per-run accuracies are percentages, while avg_acc and BASELINE are fractions,
# so atol = 0.03 corresponds to a tolerance of three percentage points.
assert (
    abs(avg_acc - baseline) <= atol
), f"模型准确率 {avg_acc:.2f} 与基准 {baseline:.2f} 相差 {abs(avg_acc - baseline):.2f},超出容忍范围 {atol:.2f}"

# with open("eval_result_math.json", "w", encoding="utf-8") as f:
#     json.dump(results, f, indent=2, ensure_ascii=False)

test/ce/server/requirements.txt

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
sympy
tqdm
openai
datasets
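
With these dependencies installed, a local run of the benchmark might look like the sketch below. This is an editorial illustration, not part of the commit: the model tag, port, and working directory are assumptions, chosen to match the BASELINE keys and the example URL in the script's error message.

# Hypothetical local invocation of test/ce/server/gsm8k.py (values are assumptions).
import os
import subprocess

env = dict(os.environ)
env["MODEL"] = "0.3B"                                     # must be one of "0.3B", "21B", "300B"
env["URL"] = "http://localhost:8191/v1/chat/completions"  # the script strips the chat/completions suffix
subprocess.run(["python", "gsm8k.py"], check=True, env=env, cwd="test/ce/server")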

0 commit comments