Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions examples/basic_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,7 @@ def demo_llm_webkit_with_preprocessed_html_evaluation():
print("1. 从真实数据集加载预处理HTML数据...")

# 使用DataLoader加载真实的样本数据
dataset_path = Path("/home/lulindong/Pycharm_projects/cc/WebMainBench_1904_v1_WebMainBench_dataset_merge_with_llm_webkit.jsonl")
dataset_path = Path("/home/lulindong/Pycharm_projects/cc/WebMainBench_1848_v1_WebMainBench_dataset_merge_with_llm_webkit.jsonl")
print(f"📂 数据集文件: {dataset_path}")

if not dataset_path.exists():
Expand Down Expand Up @@ -1101,7 +1101,6 @@ def demo_llm_webkit_with_preprocessed_html_evaluation():
# demo_extractor_comparison()
# demo_dataset_with_extraction() # 演示保存带有抽取内容的数据集
# demo_multi_extraction() # 演示多个抽取器同时评测
# demo_lld_workers_extraction()
print("\n✅ 示例运行完成!")

except Exception as e:
Expand Down
97 changes: 97 additions & 0 deletions tests/test_code_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# tests/test/test_code_extraction.py
# !/usr/bin/env python
"""测试code提取功能"""

import unittest
import sys
import os

# 添加项目根目录到Python路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

from webmainbench.metrics.base import BaseMetric, MetricResult


class TestCodeExtractionMetric(BaseMetric):
"""测试用的具体实现类"""

def _setup(self) -> None:
pass

def _calculate_score(self, predicted: str, groundtruth: str, **kwargs) -> MetricResult:
return MetricResult(
metric_name=self.name,
score=1.0,
details={"test": True}
)


class TestCodeExtraction(unittest.TestCase):
"""测试code提取功能"""

def setUp(self):
self.metric = TestCodeExtractionMetric("test_metric")

def test_empty_text(self):
"""测试空文本"""
result = BaseMetric._extract_from_markdown("")
self.assertEqual(result['code'], '')
self.assertEqual(result['text'], '')

def test_inline_code(self):
"""测试行内代码"""
text = "这是一个`行内代码`的例子"
result = BaseMetric._extract_from_markdown(text)
print(result)
self.assertEqual(result['code'], '行内代码')
self.assertEqual(result['text'], '这是一个的例子')

def test_code_block(self):
"""测试代码块"""
text = """
I have the following string: `"aaaabbbb"`
How can I get the last four characters and store them in a string using Python?
Like this:
```python
>>> mystr = "abcdefghijkl"
>>> mystr[-4:]
'ijkl'
```
"""

result = BaseMetric._extract_from_markdown(text)

# 验证提取的代码
expected_code = ("""
"aaaabbbb"
>>> mystr = "abcdefghijkl"
>>> mystr[-4:]
'ijkl'
""")
self.assertEqual(result['code'], expected_code.strip())

# 验证清理后的文本
expected_text = """
I have the following string:
How can I get the last four characters and store them in a string using Python?
Like this:
"""
self.assertEqual(result['text'], expected_text.strip())
self.assertEqual(result['formula'], '')

def test_code_with_leading_trailing_spaces(self):
"""测试代码前后有空格的情况"""
text = "前面 ` code ` 后面"
result = BaseMetric._extract_from_markdown(text)
self.assertEqual(result['code'], 'code') # 应该去除空格
self.assertEqual(result['text'], '前面 后面')

def test_multiline_inline_code(self):
"""测试多行行内代码(不应该匹配)"""
text = "`第一行\n第二行`"
result = BaseMetric._extract_from_markdown(text)
self.assertEqual(result['code'], '') # 不应该匹配多行行内代码
self.assertEqual(result['text'], text) # 原样保留

if __name__ == '__main__':
unittest.main()
169 changes: 169 additions & 0 deletions tests/test_formula_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
#!/usr/bin/env python
"""测试Markdown公式提取功能"""

import unittest
import sys
import os

# 添加项目根目录到Python路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

from webmainbench.metrics.base import BaseMetric, MetricResult


class TestFormulaExtractionMetric(BaseMetric):
"""测试用的公式提取 metric 实现类"""

def _setup(self) -> None:
pass

def _calculate_score(self, predicted: str, groundtruth: str, **kwargs) -> MetricResult:
return MetricResult(
metric_name=self.name,
score=1.0,
details={"test": True}
)


class TestFormulaExtraction(unittest.TestCase):
"""测试Markdown公式提取功能"""

def setUp(self):
self.metric = TestFormulaExtractionMetric("test_formula_metric")

def test_inline_formula_extraction(self):
"""测试行内公式提取"""
text = """这是行内公式示例: $E = mc^2$,这是普通文本。"""

result = self.metric._extract_from_markdown(text)

# 验证公式被提取
self.assertIn('E = mc^2', result['formula'])

# 验证文本中公式标记被移除
self.assertNotIn('$E = mc^2$', result['text'])
self.assertIn('这是行内公式示例: ,这是普通文本。', result['text'])

def test_block_formula_extraction(self):
"""测试行间公式提取"""
text = """这是行间公式:
$$
\\int_{-\\infty}^{\\infty} e^{-x^2} dx = \\sqrt{\\pi}
$$
公式结束"""

result = self.metric._extract_from_markdown(text)

# 验证公式被提取
self.assertIn('\\int_{-\\infty}^{\\infty} e^{-x^2} dx = \\sqrt{\\pi}', result['formula'])

# 修正:允许提取后有多个空行
self.assertIn('这是行间公式:', result['text'])
self.assertIn('公式结束', result['text'])
# 检查原始公式位置是否被清空
self.assertNotIn('$$', result['text'])

def test_escaped_dollar_signs(self):
"""测试转义美元符号不被识别为公式"""
text = """
这是转义的美元符号: \\$100,不会被识别为公式。
而这个是公式: $a + b = c$
"""

result = self.metric._extract_from_markdown(text)
# 验证转义的美元符号不被提取
self.assertNotIn('100', result['formula'])
# 验证正常公式被提取
self.assertIn('a + b = c', result['formula'])
# 验证转义符号保留在文本中
self.assertIn('\\$100', result['text'])

def test_multiple_formulas(self):
"""测试多个公式提取"""
text = """公式1: $a = b + c$
公式2: $$x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}$$
公式3: $E_k = \\frac{1}{2}mv^2$"""

result = self.metric._extract_from_markdown(text)

# 验证所有公式被提取
self.assertIn('a = b + c', result['formula'])
self.assertIn('x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}', result['formula'])
self.assertIn('E_k = \\frac{1}{2}mv^2', result['formula'])

# 验证公式间的分隔
self.assertIn('\n', result['formula'])

def test_formula_with_special_characters(self):
"""测试包含特殊字符的公式"""
text = """复杂公式: $\\sum_{i=1}^n i = \\frac{n(n+1)}{2}$
带希腊字母: $$\\alpha + \\beta = \\gamma$$"""

result = self.metric._extract_from_markdown(text)

# 验证特殊字符处理正确
self.assertIn('\\sum_{i=1}^n i = \\frac{n(n+1)}{2}', result['formula'])
self.assertIn('\\alpha + \\beta = \\gamma', result['formula'])

def test_formula_within_text(self):
"""测试文本中的公式提取"""
text = """根据相对论 $E = mc^2$,能量和质量可以互相转换。
更复杂的情况如 $$\\nabla \\cdot \\mathbf{E} = \\frac{\\rho}{\\epsilon_0}$$ 所示。"""

result = self.metric._extract_from_markdown(text)

# 验证公式被提取
self.assertIn('E = mc^2', result['formula'])
self.assertIn('\\nabla \\cdot \\mathbf{E} = \\frac{\\rho}{\\epsilon_0}', result['formula'])

# 修正:允许提取后有多个空格
self.assertIn('根据相对论 ,能量和质量可以互相转换。', result['text'])
self.assertIn('更复杂的情况如 所示。', result['text'])

def test_empty_formulas(self):
"""测试空公式处理"""
text = """空行内公式: $ $
空行间公式: $$ $$"""

result = self.metric._extract_from_markdown(text)

# 验证空公式被提取但内容为空
self.assertTrue(result['formula'].strip() == '')

# 验证空公式标记从文本中移除
self.assertNotIn('$ $', result['text'])
self.assertNotIn('$$ $$', result['text'])

def test_formula_at_document_edges(self):
"""测试文档开头和结尾的公式"""
# 开头的公式
text1 = """$start = 0$
后续文本"""
result1 = self.metric._extract_from_markdown(text1)
self.assertIn('start = 0', result1['formula'])

# 结尾的公式
text2 = """前置文本
$$end = 1$$"""
result2 = self.metric._extract_from_markdown(text2)
self.assertIn('end = 1', result2['formula'])

def test_formula_within_table(self):
"""测试表格中的公式提取"""
text = """| 公式类型 | 示例 |
|----------|------|
| 行内公式 | $a + b = c$ |
| 行间公式 | $$\\int_0^1 x dx = 0.5$$ |"""

result = self.metric._extract_from_markdown(text)

# 验证表格中的公式被提取
self.assertIn('a + b = c', result['formula'])
self.assertIn('\\int_0^1 x dx = 0.5', result['formula'])

# 验证表格结构仍然被正确提取
self.assertIn('| 公式类型 | 示例 |', result['table'])


if __name__ == '__main__':
unittest.main()
24 changes: 12 additions & 12 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ def test_code_edit_metric(self):
self.assertTrue(code_result.success)
self.assertIsInstance(code_result.score, float)
# 验证固定内容的确定分数
self.assertAlmostEqual(code_result.score, 0.918367, places=5,
msg=f"code_edit分数应该是0.918367,实际: {code_result.score}")
self.assertAlmostEqual(code_result.score, 0.9487179487179487, places=5,
msg=f"code_edit分数应该是0.9487179487179487,实际: {code_result.score}")

# 验证详细信息
self.assertEqual(code_result.details['content_type'], 'code')
Expand Down Expand Up @@ -164,8 +164,8 @@ def test_text_edit_metric(self):
self.assertTrue(text_result.success)
self.assertIsInstance(text_result.score, float)
# 验证固定内容的确定分数
self.assertAlmostEqual(text_result.score, 0.769231, places=5,
msg=f"text_edit分数应该是0.769231,实际: {text_result.score}")
self.assertAlmostEqual(text_result.score, 0.7692307692307692, places=5,
msg=f"text_edit分数应该是0.7692307692307692,实际: {text_result.score}")

# 验证详细信息
self.assertEqual(text_result.details['content_type'], 'text')
Expand Down Expand Up @@ -299,14 +299,14 @@ def hello_world():
# 验证文本编辑距离(固定内容应该有确定分数)
self.assertIn("text_edit", results)
self.assertTrue(results["text_edit"].success)
self.assertAlmostEqual(results["text_edit"].score, 1.000000, places=5,
msg=f"text_edit分数应该是1.000000,实际: {results['text_edit'].score}")
self.assertAlmostEqual(results["text_edit"].score, 1.0, places=5,
msg=f"text_edit分数应该是1.0,实际: {results['text_edit'].score}")

# 验证代码编辑距离(缺少python标识符导致轻微差异)
self.assertIn("code_edit", results)
self.assertTrue(results["code_edit"].success)
self.assertAlmostEqual(results["code_edit"].score, 0.905797, places=5,
msg=f"code_edit分数应该是0.905797,实际: {results['code_edit'].score}")
self.assertAlmostEqual(results["code_edit"].score, 1.0, places=5,
msg=f"code_edit分数应该是1.0,实际: {results['code_edit'].score}")

def test_table_sample_edit_distance(self):
"""测试表格样本的编辑距离"""
Expand Down Expand Up @@ -367,14 +367,14 @@ def test_formula_sample_edit_distance(self):
# 验证公式编辑距离(符号转义导致的固定低分)
self.assertIn("formula_edit", results)
self.assertTrue(results["formula_edit"].success)
self.assertAlmostEqual(results["formula_edit"].score, 0.000000, places=5,
msg=f"formula_edit分数应该是0.000000,实际: {results['formula_edit'].score}")
self.assertAlmostEqual(results["formula_edit"].score, 0.0, places=5,
msg=f"formula_edit分数应该是0.0,实际: {results['formula_edit'].score}")

# 验证文本编辑距离(去除公式后的纯文本,也受符号转义影响)
self.assertIn("text_edit", results)
self.assertTrue(results["text_edit"].success)
self.assertAlmostEqual(results["text_edit"].score, 0.320000, places=5,
msg=f"text_edit分数应该是0.320000,实际: {results['text_edit'].score}")
self.assertAlmostEqual(results["text_edit"].score, 0.31999999999999995, places=5,
msg=f"text_edit分数应该是0.31999999999999995,实际: {results['text_edit'].score}")

def test_overall_score_calculation(self):
"""测试综合分数计算"""
Expand Down
Loading