Skip to content

Commit 0d60eca

Browse files
authored
Merge pull request #28 from pekopoke/dev
Dev:fix match formula and code
2 parents 4cf0655 + 7e63abd commit 0d60eca

File tree

5 files changed

+320
-29
lines changed

5 files changed

+320
-29
lines changed

examples/basic_usage.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -957,7 +957,7 @@ def demo_llm_webkit_with_preprocessed_html_evaluation():
957957
print("1. 从真实数据集加载预处理HTML数据...")
958958

959959
# 使用DataLoader加载真实的样本数据
960-
dataset_path = Path("/home/lulindong/Pycharm_projects/cc/WebMainBench_1904_v1_WebMainBench_dataset_merge_with_llm_webkit.jsonl")
960+
dataset_path = Path("/home/lulindong/Pycharm_projects/cc/WebMainBench_1848_v1_WebMainBench_dataset_merge_with_llm_webkit.jsonl")
961961
print(f"📂 数据集文件: {dataset_path}")
962962

963963
if not dataset_path.exists():
@@ -1101,7 +1101,6 @@ def demo_llm_webkit_with_preprocessed_html_evaluation():
11011101
# demo_extractor_comparison()
11021102
# demo_dataset_with_extraction() # 演示保存带有抽取内容的数据集
11031103
# demo_multi_extraction() # 演示多个抽取器同时评测
1104-
# demo_lld_workers_extraction()
11051104
print("\n✅ 示例运行完成!")
11061105

11071106
except Exception as e:

tests/test_code_extraction.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# tests/test/test_code_extraction.py
2+
# !/usr/bin/env python
3+
"""测试code提取功能"""
4+
5+
import unittest
6+
import sys
7+
import os
8+
9+
# 添加项目根目录到Python路径
10+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
11+
12+
from webmainbench.metrics.base import BaseMetric, MetricResult
13+
14+
15+
class TestCodeExtractionMetric(BaseMetric):
16+
"""测试用的具体实现类"""
17+
18+
def _setup(self) -> None:
19+
pass
20+
21+
def _calculate_score(self, predicted: str, groundtruth: str, **kwargs) -> MetricResult:
22+
return MetricResult(
23+
metric_name=self.name,
24+
score=1.0,
25+
details={"test": True}
26+
)
27+
28+
29+
class TestCodeExtraction(unittest.TestCase):
30+
"""测试code提取功能"""
31+
32+
def setUp(self):
33+
self.metric = TestCodeExtractionMetric("test_metric")
34+
35+
def test_empty_text(self):
36+
"""测试空文本"""
37+
result = BaseMetric._extract_from_markdown("")
38+
self.assertEqual(result['code'], '')
39+
self.assertEqual(result['text'], '')
40+
41+
def test_inline_code(self):
42+
"""测试行内代码"""
43+
text = "这是一个`行内代码`的例子"
44+
result = BaseMetric._extract_from_markdown(text)
45+
print(result)
46+
self.assertEqual(result['code'], '行内代码')
47+
self.assertEqual(result['text'], '这是一个的例子')
48+
49+
def test_code_block(self):
50+
"""测试代码块"""
51+
text = """
52+
I have the following string: `"aaaabbbb"`
53+
How can I get the last four characters and store them in a string using Python?
54+
Like this:
55+
```python
56+
>>> mystr = "abcdefghijkl"
57+
>>> mystr[-4:]
58+
'ijkl'
59+
```
60+
"""
61+
62+
result = BaseMetric._extract_from_markdown(text)
63+
64+
# 验证提取的代码
65+
expected_code = ("""
66+
"aaaabbbb"
67+
>>> mystr = "abcdefghijkl"
68+
>>> mystr[-4:]
69+
'ijkl'
70+
""")
71+
self.assertEqual(result['code'], expected_code.strip())
72+
73+
# 验证清理后的文本
74+
expected_text = """
75+
I have the following string:
76+
How can I get the last four characters and store them in a string using Python?
77+
Like this:
78+
"""
79+
self.assertEqual(result['text'], expected_text.strip())
80+
self.assertEqual(result['formula'], '')
81+
82+
def test_code_with_leading_trailing_spaces(self):
83+
"""测试代码前后有空格的情况"""
84+
text = "前面 ` code ` 后面"
85+
result = BaseMetric._extract_from_markdown(text)
86+
self.assertEqual(result['code'], 'code') # 应该去除空格
87+
self.assertEqual(result['text'], '前面 后面')
88+
89+
def test_multiline_inline_code(self):
90+
"""测试多行行内代码(不应该匹配)"""
91+
text = "`第一行\n第二行`"
92+
result = BaseMetric._extract_from_markdown(text)
93+
self.assertEqual(result['code'], '') # 不应该匹配多行行内代码
94+
self.assertEqual(result['text'], text) # 原样保留
95+
96+
if __name__ == '__main__':
97+
unittest.main()

tests/test_formula_extraction.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
#!/usr/bin/env python
2+
"""测试Markdown公式提取功能"""
3+
4+
import unittest
5+
import sys
6+
import os
7+
8+
# 添加项目根目录到Python路径
9+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
10+
11+
from webmainbench.metrics.base import BaseMetric, MetricResult
12+
13+
14+
class TestFormulaExtractionMetric(BaseMetric):
15+
"""测试用的公式提取 metric 实现类"""
16+
17+
def _setup(self) -> None:
18+
pass
19+
20+
def _calculate_score(self, predicted: str, groundtruth: str, **kwargs) -> MetricResult:
21+
return MetricResult(
22+
metric_name=self.name,
23+
score=1.0,
24+
details={"test": True}
25+
)
26+
27+
28+
class TestFormulaExtraction(unittest.TestCase):
29+
"""测试Markdown公式提取功能"""
30+
31+
def setUp(self):
32+
self.metric = TestFormulaExtractionMetric("test_formula_metric")
33+
34+
def test_inline_formula_extraction(self):
35+
"""测试行内公式提取"""
36+
text = """这是行内公式示例: $E = mc^2$,这是普通文本。"""
37+
38+
result = self.metric._extract_from_markdown(text)
39+
40+
# 验证公式被提取
41+
self.assertIn('E = mc^2', result['formula'])
42+
43+
# 验证文本中公式标记被移除
44+
self.assertNotIn('$E = mc^2$', result['text'])
45+
self.assertIn('这是行内公式示例: ,这是普通文本。', result['text'])
46+
47+
def test_block_formula_extraction(self):
48+
"""测试行间公式提取"""
49+
text = """这是行间公式:
50+
$$
51+
\\int_{-\\infty}^{\\infty} e^{-x^2} dx = \\sqrt{\\pi}
52+
$$
53+
公式结束"""
54+
55+
result = self.metric._extract_from_markdown(text)
56+
57+
# 验证公式被提取
58+
self.assertIn('\\int_{-\\infty}^{\\infty} e^{-x^2} dx = \\sqrt{\\pi}', result['formula'])
59+
60+
# 修正:允许提取后有多个空行
61+
self.assertIn('这是行间公式:', result['text'])
62+
self.assertIn('公式结束', result['text'])
63+
# 检查原始公式位置是否被清空
64+
self.assertNotIn('$$', result['text'])
65+
66+
def test_escaped_dollar_signs(self):
67+
"""测试转义美元符号不被识别为公式"""
68+
text = """
69+
这是转义的美元符号: \\$100,不会被识别为公式。
70+
而这个是公式: $a + b = c$
71+
"""
72+
73+
result = self.metric._extract_from_markdown(text)
74+
# 验证转义的美元符号不被提取
75+
self.assertNotIn('100', result['formula'])
76+
# 验证正常公式被提取
77+
self.assertIn('a + b = c', result['formula'])
78+
# 验证转义符号保留在文本中
79+
self.assertIn('\\$100', result['text'])
80+
81+
def test_multiple_formulas(self):
82+
"""测试多个公式提取"""
83+
text = """公式1: $a = b + c$
84+
公式2: $$x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}$$
85+
公式3: $E_k = \\frac{1}{2}mv^2$"""
86+
87+
result = self.metric._extract_from_markdown(text)
88+
89+
# 验证所有公式被提取
90+
self.assertIn('a = b + c', result['formula'])
91+
self.assertIn('x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}', result['formula'])
92+
self.assertIn('E_k = \\frac{1}{2}mv^2', result['formula'])
93+
94+
# 验证公式间的分隔
95+
self.assertIn('\n', result['formula'])
96+
97+
def test_formula_with_special_characters(self):
98+
"""测试包含特殊字符的公式"""
99+
text = """复杂公式: $\\sum_{i=1}^n i = \\frac{n(n+1)}{2}$
100+
带希腊字母: $$\\alpha + \\beta = \\gamma$$"""
101+
102+
result = self.metric._extract_from_markdown(text)
103+
104+
# 验证特殊字符处理正确
105+
self.assertIn('\\sum_{i=1}^n i = \\frac{n(n+1)}{2}', result['formula'])
106+
self.assertIn('\\alpha + \\beta = \\gamma', result['formula'])
107+
108+
def test_formula_within_text(self):
109+
"""测试文本中的公式提取"""
110+
text = """根据相对论 $E = mc^2$,能量和质量可以互相转换。
111+
更复杂的情况如 $$\\nabla \\cdot \\mathbf{E} = \\frac{\\rho}{\\epsilon_0}$$ 所示。"""
112+
113+
result = self.metric._extract_from_markdown(text)
114+
115+
# 验证公式被提取
116+
self.assertIn('E = mc^2', result['formula'])
117+
self.assertIn('\\nabla \\cdot \\mathbf{E} = \\frac{\\rho}{\\epsilon_0}', result['formula'])
118+
119+
# 修正:允许提取后有多个空格
120+
self.assertIn('根据相对论 ,能量和质量可以互相转换。', result['text'])
121+
self.assertIn('更复杂的情况如 所示。', result['text'])
122+
123+
def test_empty_formulas(self):
124+
"""测试空公式处理"""
125+
text = """空行内公式: $ $
126+
空行间公式: $$ $$"""
127+
128+
result = self.metric._extract_from_markdown(text)
129+
130+
# 验证空公式被提取但内容为空
131+
self.assertTrue(result['formula'].strip() == '')
132+
133+
# 验证空公式标记从文本中移除
134+
self.assertNotIn('$ $', result['text'])
135+
self.assertNotIn('$$ $$', result['text'])
136+
137+
def test_formula_at_document_edges(self):
138+
"""测试文档开头和结尾的公式"""
139+
# 开头的公式
140+
text1 = """$start = 0$
141+
后续文本"""
142+
result1 = self.metric._extract_from_markdown(text1)
143+
self.assertIn('start = 0', result1['formula'])
144+
145+
# 结尾的公式
146+
text2 = """前置文本
147+
$$end = 1$$"""
148+
result2 = self.metric._extract_from_markdown(text2)
149+
self.assertIn('end = 1', result2['formula'])
150+
151+
def test_formula_within_table(self):
152+
"""测试表格中的公式提取"""
153+
text = """| 公式类型 | 示例 |
154+
|----------|------|
155+
| 行内公式 | $a + b = c$ |
156+
| 行间公式 | $$\\int_0^1 x dx = 0.5$$ |"""
157+
158+
result = self.metric._extract_from_markdown(text)
159+
160+
# 验证表格中的公式被提取
161+
self.assertIn('a + b = c', result['formula'])
162+
self.assertIn('\\int_0^1 x dx = 0.5', result['formula'])
163+
164+
# 验证表格结构仍然被正确提取
165+
self.assertIn('| 公式类型 | 示例 |', result['table'])
166+
167+
168+
if __name__ == '__main__':
169+
unittest.main()

tests/test_metrics.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,8 @@ def test_code_edit_metric(self):
9191
self.assertTrue(code_result.success)
9292
self.assertIsInstance(code_result.score, float)
9393
# 验证固定内容的确定分数
94-
self.assertAlmostEqual(code_result.score, 0.918367, places=5,
95-
msg=f"code_edit分数应该是0.918367,实际: {code_result.score}")
94+
self.assertAlmostEqual(code_result.score, 0.9487179487179487, places=5,
95+
msg=f"code_edit分数应该是0.9487179487179487,实际: {code_result.score}")
9696

9797
# 验证详细信息
9898
self.assertEqual(code_result.details['content_type'], 'code')
@@ -164,8 +164,8 @@ def test_text_edit_metric(self):
164164
self.assertTrue(text_result.success)
165165
self.assertIsInstance(text_result.score, float)
166166
# 验证固定内容的确定分数
167-
self.assertAlmostEqual(text_result.score, 0.769231, places=5,
168-
msg=f"text_edit分数应该是0.769231,实际: {text_result.score}")
167+
self.assertAlmostEqual(text_result.score, 0.7692307692307692, places=5,
168+
msg=f"text_edit分数应该是0.7692307692307692,实际: {text_result.score}")
169169

170170
# 验证详细信息
171171
self.assertEqual(text_result.details['content_type'], 'text')
@@ -299,14 +299,14 @@ def hello_world():
299299
# 验证文本编辑距离(固定内容应该有确定分数)
300300
self.assertIn("text_edit", results)
301301
self.assertTrue(results["text_edit"].success)
302-
self.assertAlmostEqual(results["text_edit"].score, 1.000000, places=5,
303-
msg=f"text_edit分数应该是1.000000,实际: {results['text_edit'].score}")
302+
self.assertAlmostEqual(results["text_edit"].score, 1.0, places=5,
303+
msg=f"text_edit分数应该是1.0,实际: {results['text_edit'].score}")
304304

305305
# 验证代码编辑距离(缺少python标识符导致轻微差异)
306306
self.assertIn("code_edit", results)
307307
self.assertTrue(results["code_edit"].success)
308-
self.assertAlmostEqual(results["code_edit"].score, 0.905797, places=5,
309-
msg=f"code_edit分数应该是0.905797,实际: {results['code_edit'].score}")
308+
self.assertAlmostEqual(results["code_edit"].score, 1.0, places=5,
309+
msg=f"code_edit分数应该是1.0,实际: {results['code_edit'].score}")
310310

311311
def test_table_sample_edit_distance(self):
312312
"""测试表格样本的编辑距离"""
@@ -367,14 +367,14 @@ def test_formula_sample_edit_distance(self):
367367
# 验证公式编辑距离(符号转义导致的固定低分)
368368
self.assertIn("formula_edit", results)
369369
self.assertTrue(results["formula_edit"].success)
370-
self.assertAlmostEqual(results["formula_edit"].score, 0.000000, places=5,
371-
msg=f"formula_edit分数应该是0.000000,实际: {results['formula_edit'].score}")
370+
self.assertAlmostEqual(results["formula_edit"].score, 0.0, places=5,
371+
msg=f"formula_edit分数应该是0.0,实际: {results['formula_edit'].score}")
372372

373373
# 验证文本编辑距离(去除公式后的纯文本,也受符号转义影响)
374374
self.assertIn("text_edit", results)
375375
self.assertTrue(results["text_edit"].success)
376-
self.assertAlmostEqual(results["text_edit"].score, 0.320000, places=5,
377-
msg=f"text_edit分数应该是0.320000,实际: {results['text_edit'].score}")
376+
self.assertAlmostEqual(results["text_edit"].score, 0.31999999999999995, places=5,
377+
msg=f"text_edit分数应该是0.31999999999999995,实际: {results['text_edit'].score}")
378378

379379
def test_overall_score_calculation(self):
380380
"""测试综合分数计算"""

0 commit comments

Comments
 (0)