diff --git a/examples/basic_usage.py b/examples/basic_usage.py index 7231a76..157737b 100755 --- a/examples/basic_usage.py +++ b/examples/basic_usage.py @@ -957,7 +957,7 @@ def demo_llm_webkit_with_preprocessed_html_evaluation(): print("1. 从真实数据集加载预处理HTML数据...") # 使用DataLoader加载真实的样本数据 - dataset_path = Path("/home/lulindong/Pycharm_projects/cc/WebMainBench_1904_v1_WebMainBench_dataset_merge_with_llm_webkit.jsonl") + dataset_path = Path("/home/lulindong/Pycharm_projects/cc/WebMainBench_1848_v1_WebMainBench_dataset_merge_with_llm_webkit.jsonl") print(f"📂 数据集文件: {dataset_path}") if not dataset_path.exists(): @@ -1101,7 +1101,6 @@ def demo_llm_webkit_with_preprocessed_html_evaluation(): # demo_extractor_comparison() # demo_dataset_with_extraction() # 演示保存带有抽取内容的数据集 # demo_multi_extraction() # 演示多个抽取器同时评测 - # demo_lld_workers_extraction() print("\n✅ 示例运行完成!") except Exception as e: diff --git a/tests/test_code_extraction.py b/tests/test_code_extraction.py new file mode 100644 index 0000000..2263439 --- /dev/null +++ b/tests/test_code_extraction.py @@ -0,0 +1,97 @@ +# tests/test/test_code_extraction.py +# !/usr/bin/env python +"""测试code提取功能""" + +import unittest +import sys +import os + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +from webmainbench.metrics.base import BaseMetric, MetricResult + + +class TestCodeExtractionMetric(BaseMetric): + """测试用的具体实现类""" + + def _setup(self) -> None: + pass + + def _calculate_score(self, predicted: str, groundtruth: str, **kwargs) -> MetricResult: + return MetricResult( + metric_name=self.name, + score=1.0, + details={"test": True} + ) + + +class TestCodeExtraction(unittest.TestCase): + """测试code提取功能""" + + def setUp(self): + self.metric = TestCodeExtractionMetric("test_metric") + + def test_empty_text(self): + """测试空文本""" + result = BaseMetric._extract_from_markdown("") + self.assertEqual(result['code'], '') + self.assertEqual(result['text'], '') + + def test_inline_code(self): + """测试行内代码""" + text = "这是一个`行内代码`的例子" + result = BaseMetric._extract_from_markdown(text) + print(result) + self.assertEqual(result['code'], '行内代码') + self.assertEqual(result['text'], '这是一个的例子') + + def test_code_block(self): + """测试代码块""" + text = """ +I have the following string: `"aaaabbbb"` +How can I get the last four characters and store them in a string using Python? +Like this: +```python +>>> mystr = "abcdefghijkl" +>>> mystr[-4:] +'ijkl' +``` + """ + + result = BaseMetric._extract_from_markdown(text) + + # 验证提取的代码 + expected_code = (""" +"aaaabbbb" +>>> mystr = "abcdefghijkl" +>>> mystr[-4:] +'ijkl' + """) + self.assertEqual(result['code'], expected_code.strip()) + + # 验证清理后的文本 + expected_text = """ +I have the following string: +How can I get the last four characters and store them in a string using Python? +Like this: +""" + self.assertEqual(result['text'], expected_text.strip()) + self.assertEqual(result['formula'], '') + + def test_code_with_leading_trailing_spaces(self): + """测试代码前后有空格的情况""" + text = "前面 ` code ` 后面" + result = BaseMetric._extract_from_markdown(text) + self.assertEqual(result['code'], 'code') # 应该去除空格 + self.assertEqual(result['text'], '前面 后面') + + def test_multiline_inline_code(self): + """测试多行行内代码(不应该匹配)""" + text = "`第一行\n第二行`" + result = BaseMetric._extract_from_markdown(text) + self.assertEqual(result['code'], '') # 不应该匹配多行行内代码 + self.assertEqual(result['text'], text) # 原样保留 + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_formula_extraction.py b/tests/test_formula_extraction.py new file mode 100644 index 0000000..fd19bd5 --- /dev/null +++ b/tests/test_formula_extraction.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python +"""测试Markdown公式提取功能""" + +import unittest +import sys +import os + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +from webmainbench.metrics.base import BaseMetric, MetricResult + + +class TestFormulaExtractionMetric(BaseMetric): + """测试用的公式提取 metric 实现类""" + + def _setup(self) -> None: + pass + + def _calculate_score(self, predicted: str, groundtruth: str, **kwargs) -> MetricResult: + return MetricResult( + metric_name=self.name, + score=1.0, + details={"test": True} + ) + + +class TestFormulaExtraction(unittest.TestCase): + """测试Markdown公式提取功能""" + + def setUp(self): + self.metric = TestFormulaExtractionMetric("test_formula_metric") + + def test_inline_formula_extraction(self): + """测试行内公式提取""" + text = """这是行内公式示例: $E = mc^2$,这是普通文本。""" + + result = self.metric._extract_from_markdown(text) + + # 验证公式被提取 + self.assertIn('E = mc^2', result['formula']) + + # 验证文本中公式标记被移除 + self.assertNotIn('$E = mc^2$', result['text']) + self.assertIn('这是行内公式示例: ,这是普通文本。', result['text']) + + def test_block_formula_extraction(self): + """测试行间公式提取""" + text = """这是行间公式: +$$ +\\int_{-\\infty}^{\\infty} e^{-x^2} dx = \\sqrt{\\pi} +$$ +公式结束""" + + result = self.metric._extract_from_markdown(text) + + # 验证公式被提取 + self.assertIn('\\int_{-\\infty}^{\\infty} e^{-x^2} dx = \\sqrt{\\pi}', result['formula']) + + # 修正:允许提取后有多个空行 + self.assertIn('这是行间公式:', result['text']) + self.assertIn('公式结束', result['text']) + # 检查原始公式位置是否被清空 + self.assertNotIn('$$', result['text']) + + def test_escaped_dollar_signs(self): + """测试转义美元符号不被识别为公式""" + text = """ + 这是转义的美元符号: \\$100,不会被识别为公式。 +而这个是公式: $a + b = c$ +""" + + result = self.metric._extract_from_markdown(text) + # 验证转义的美元符号不被提取 + self.assertNotIn('100', result['formula']) + # 验证正常公式被提取 + self.assertIn('a + b = c', result['formula']) + # 验证转义符号保留在文本中 + self.assertIn('\\$100', result['text']) + + def test_multiple_formulas(self): + """测试多个公式提取""" + text = """公式1: $a = b + c$ +公式2: $$x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}$$ +公式3: $E_k = \\frac{1}{2}mv^2$""" + + result = self.metric._extract_from_markdown(text) + + # 验证所有公式被提取 + self.assertIn('a = b + c', result['formula']) + self.assertIn('x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}', result['formula']) + self.assertIn('E_k = \\frac{1}{2}mv^2', result['formula']) + + # 验证公式间的分隔 + self.assertIn('\n', result['formula']) + + def test_formula_with_special_characters(self): + """测试包含特殊字符的公式""" + text = """复杂公式: $\\sum_{i=1}^n i = \\frac{n(n+1)}{2}$ +带希腊字母: $$\\alpha + \\beta = \\gamma$$""" + + result = self.metric._extract_from_markdown(text) + + # 验证特殊字符处理正确 + self.assertIn('\\sum_{i=1}^n i = \\frac{n(n+1)}{2}', result['formula']) + self.assertIn('\\alpha + \\beta = \\gamma', result['formula']) + + def test_formula_within_text(self): + """测试文本中的公式提取""" + text = """根据相对论 $E = mc^2$,能量和质量可以互相转换。 +更复杂的情况如 $$\\nabla \\cdot \\mathbf{E} = \\frac{\\rho}{\\epsilon_0}$$ 所示。""" + + result = self.metric._extract_from_markdown(text) + + # 验证公式被提取 + self.assertIn('E = mc^2', result['formula']) + self.assertIn('\\nabla \\cdot \\mathbf{E} = \\frac{\\rho}{\\epsilon_0}', result['formula']) + + # 修正:允许提取后有多个空格 + self.assertIn('根据相对论 ,能量和质量可以互相转换。', result['text']) + self.assertIn('更复杂的情况如 所示。', result['text']) + + def test_empty_formulas(self): + """测试空公式处理""" + text = """空行内公式: $ $ +空行间公式: $$ $$""" + + result = self.metric._extract_from_markdown(text) + + # 验证空公式被提取但内容为空 + self.assertTrue(result['formula'].strip() == '') + + # 验证空公式标记从文本中移除 + self.assertNotIn('$ $', result['text']) + self.assertNotIn('$$ $$', result['text']) + + def test_formula_at_document_edges(self): + """测试文档开头和结尾的公式""" + # 开头的公式 + text1 = """$start = 0$ +后续文本""" + result1 = self.metric._extract_from_markdown(text1) + self.assertIn('start = 0', result1['formula']) + + # 结尾的公式 + text2 = """前置文本 +$$end = 1$$""" + result2 = self.metric._extract_from_markdown(text2) + self.assertIn('end = 1', result2['formula']) + + def test_formula_within_table(self): + """测试表格中的公式提取""" + text = """| 公式类型 | 示例 | +|----------|------| +| 行内公式 | $a + b = c$ | +| 行间公式 | $$\\int_0^1 x dx = 0.5$$ |""" + + result = self.metric._extract_from_markdown(text) + + # 验证表格中的公式被提取 + self.assertIn('a + b = c', result['formula']) + self.assertIn('\\int_0^1 x dx = 0.5', result['formula']) + + # 验证表格结构仍然被正确提取 + self.assertIn('| 公式类型 | 示例 |', result['table']) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 5c38b24..93ee6f9 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -91,8 +91,8 @@ def test_code_edit_metric(self): self.assertTrue(code_result.success) self.assertIsInstance(code_result.score, float) # 验证固定内容的确定分数 - self.assertAlmostEqual(code_result.score, 0.918367, places=5, - msg=f"code_edit分数应该是0.918367,实际: {code_result.score}") + self.assertAlmostEqual(code_result.score, 0.9487179487179487, places=5, + msg=f"code_edit分数应该是0.9487179487179487,实际: {code_result.score}") # 验证详细信息 self.assertEqual(code_result.details['content_type'], 'code') @@ -164,8 +164,8 @@ def test_text_edit_metric(self): self.assertTrue(text_result.success) self.assertIsInstance(text_result.score, float) # 验证固定内容的确定分数 - self.assertAlmostEqual(text_result.score, 0.769231, places=5, - msg=f"text_edit分数应该是0.769231,实际: {text_result.score}") + self.assertAlmostEqual(text_result.score, 0.7692307692307692, places=5, + msg=f"text_edit分数应该是0.7692307692307692,实际: {text_result.score}") # 验证详细信息 self.assertEqual(text_result.details['content_type'], 'text') @@ -299,14 +299,14 @@ def hello_world(): # 验证文本编辑距离(固定内容应该有确定分数) self.assertIn("text_edit", results) self.assertTrue(results["text_edit"].success) - self.assertAlmostEqual(results["text_edit"].score, 1.000000, places=5, - msg=f"text_edit分数应该是1.000000,实际: {results['text_edit'].score}") + self.assertAlmostEqual(results["text_edit"].score, 1.0, places=5, + msg=f"text_edit分数应该是1.0,实际: {results['text_edit'].score}") # 验证代码编辑距离(缺少python标识符导致轻微差异) self.assertIn("code_edit", results) self.assertTrue(results["code_edit"].success) - self.assertAlmostEqual(results["code_edit"].score, 0.905797, places=5, - msg=f"code_edit分数应该是0.905797,实际: {results['code_edit'].score}") + self.assertAlmostEqual(results["code_edit"].score, 1.0, places=5, + msg=f"code_edit分数应该是1.0,实际: {results['code_edit'].score}") def test_table_sample_edit_distance(self): """测试表格样本的编辑距离""" @@ -367,14 +367,14 @@ def test_formula_sample_edit_distance(self): # 验证公式编辑距离(符号转义导致的固定低分) self.assertIn("formula_edit", results) self.assertTrue(results["formula_edit"].success) - self.assertAlmostEqual(results["formula_edit"].score, 0.000000, places=5, - msg=f"formula_edit分数应该是0.000000,实际: {results['formula_edit'].score}") + self.assertAlmostEqual(results["formula_edit"].score, 0.0, places=5, + msg=f"formula_edit分数应该是0.0,实际: {results['formula_edit'].score}") # 验证文本编辑距离(去除公式后的纯文本,也受符号转义影响) self.assertIn("text_edit", results) self.assertTrue(results["text_edit"].success) - self.assertAlmostEqual(results["text_edit"].score, 0.320000, places=5, - msg=f"text_edit分数应该是0.320000,实际: {results['text_edit'].score}") + self.assertAlmostEqual(results["text_edit"].score, 0.31999999999999995, places=5, + msg=f"text_edit分数应该是0.31999999999999995,实际: {results['text_edit'].score}") def test_overall_score_calculation(self): """测试综合分数计算""" diff --git a/webmainbench/metrics/base.py b/webmainbench/metrics/base.py index a2d27ee..22d3f73 100644 --- a/webmainbench/metrics/base.py +++ b/webmainbench/metrics/base.py @@ -201,34 +201,60 @@ def _extract_from_markdown(text: str) -> Dict[str, str]: # 收集所有需要移除的内容片段 extracted_segments = [] - - # 提取代码 code_parts = [] - # 代码块 ```code``` - for match in re.finditer(r'```[\s\S]*?```', text): - code_block = match.group(0) - extracted_segments.append(code_block) - code_parts.append(code_block.strip('`').strip()) + # 同时匹配行内代码 `...` 和代码块 ```...``` + pattern = r'(```[\s\S]*?```|`[^`\n]+`)' + for match in re.finditer(pattern, text): + code_segment = match.group(0) + extracted_segments.append(code_segment) + + if code_segment.startswith('```'): + # 处理代码块(保留内部缩进) + lines = code_segment.split('\n') + # 移除首尾的```标记 + content_lines = lines[1:-1] + # 保留原始缩进,只拼接内容 + code_content = '\n'.join(content_lines) + else: + # 处理行内代码(只去除外层`和前后空格) + code_content = code_segment[1:-1].strip() + + if code_content: # 只添加非空内容 + code_parts.append(code_content) - # 行内代码 `code` - for match in re.finditer(r'`([^`]+)`', text): - inline_code_full = match.group(0) # 包含反引号的完整匹配 - inline_code_content = match.group(1) # 只是内容 - extracted_segments.append(inline_code_full) - code_parts.append(inline_code_content) + # # 提取代码 + # code_parts = [] + # # 代码块 ```code``` + # for match in re.finditer(r'```[\s\S]*?```', text): + # code_block = match.group(0) + # extracted_segments.append(code_block) + # code_parts.append(code_block.strip('`').strip()) + # + # # 行内代码 `code` + # for match in re.finditer(r'`([^`]+)`', text): + # inline_code_full = match.group(0) # 包含反引号的完整匹配 + # inline_code_content = match.group(1) # 只是内容 + # extracted_segments.append(inline_code_full) + # code_parts.append(inline_code_content) # 提取公式 formula_parts = [] # 统一的公式提取模式 latex_patterns = [ - r'(?