Skip to content

Commit 0051ce9

Browse files
authored
Merge pull request #25 from SHUzhangshuo/main
fix:extract from markdown
2 parents b6f3750 + 2008717 commit 0051ce9

File tree

3 files changed

+363
-33
lines changed

3 files changed

+363
-33
lines changed

examples/test_table_extract.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#!/usr/bin/env python3
2+
"""
3+
脚本:仅提取 WebMainBench 数据集中的表格内容到 table.md
4+
"""
5+
6+
import json
7+
import sys
8+
import os
9+
from pathlib import Path
10+
11+
# 添加父目录到 sys.path 以便导入 webmainbench
12+
sys.path.append(str(Path(__file__).parent.parent))
13+
14+
from webmainbench.metrics.base import BaseMetric
15+
16+
def extract_only_tables_from_dataset():
17+
"""只提取 WebMainBench 数据集中的表格内容并输出到 table.md(table为空的不记录)"""
18+
19+
# 路径配置
20+
dataset_path = "/home/zhangshuo/Desktop/vscodeworkspace/WebMainBench/data/WebMainBench_llm-webkit_v1_WebMainBench_dataset_merge_with_llm_webkit.jsonl"
21+
output_path = "table.md"
22+
23+
# 检查数据集文件是否存在
24+
if not os.path.exists(dataset_path):
25+
print(f"错误:未找到数据集文件 {dataset_path}")
26+
return
27+
28+
extracted_tables = []
29+
line_ids = []
30+
31+
# 按行读取 JSONL 文件
32+
with open(dataset_path, 'r', encoding='utf-8') as f:
33+
for line_num, line in enumerate(f, 1):
34+
try:
35+
data = json.loads(line.strip())
36+
37+
# 提取ID和内容
38+
item_id = data.get('track_id', f'line_{line_num}')
39+
content = data.get('llm_webkit_md', '')
40+
41+
# 使用 _extract_from_markdown 提取
42+
if content:
43+
extracted = BaseMetric._extract_from_markdown(content)
44+
table_content = extracted.get("table", "")
45+
# 只记录table不为空的项
46+
if table_content and table_content.strip():
47+
extracted_tables.append(table_content)
48+
line_ids.append((item_id, line_num))
49+
except json.JSONDecodeError as e:
50+
print(f"解析JSON出错,行{line_num}: {e}")
51+
continue
52+
except Exception as e:
53+
print(f"处理第{line_num}行时出错: {e}")
54+
continue
55+
56+
# 写入 table.md 文件,只输出 table 字段
57+
with open(output_path, 'w', encoding='utf-8') as f:
58+
f.write("# Extracted Table Content from WebMainBench Dataset\n\n")
59+
f.write(f"Total items processed: {len(extracted_tables)}\n\n")
60+
61+
for idx, (table_content, (item_id, line_num)) in enumerate(zip(extracted_tables, line_ids), 1):
62+
f.write(f"## Item {idx}\n")
63+
f.write(f"- **ID**: {item_id}\n")
64+
f.write(f"- **Line Number**: {line_num}\n")
65+
f.write(f"- **Extracted Table**:\n\n")
66+
f.write("```\n")
67+
f.write(table_content)
68+
f.write("\n```\n\n")
69+
f.write("---\n\n")
70+
71+
print(f"表格提取完成!共处理 {len(extracted_tables)} 条数据。")
72+
print(f"表格内容已保存到: {output_path}")
73+
74+
if __name__ == "__main__":
75+
extract_only_tables_from_dataset()

tests/test_table_extraction.py

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
#!/usr/bin/env python
2+
"""测试Markdown表格提取功能"""
3+
4+
import unittest
5+
import sys
6+
import os
7+
8+
# 添加项目根目录到Python路径
9+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
10+
11+
from webmainbench.metrics.base import BaseMetric, MetricResult
12+
13+
14+
class TestTableExtractionMetric(BaseMetric):
15+
"""测试用的具体实现类"""
16+
17+
def _setup(self) -> None:
18+
pass
19+
20+
def _calculate_score(self, predicted: str, groundtruth: str, **kwargs) -> MetricResult:
21+
return MetricResult(
22+
metric_name=self.name,
23+
score=1.0,
24+
details={"test": True}
25+
)
26+
27+
28+
class TestTableExtraction(unittest.TestCase):
29+
"""测试Markdown表格提取功能"""
30+
31+
def setUp(self):
32+
self.metric = TestTableExtractionMetric("test_metric")
33+
34+
def test_basic_table_extraction(self):
35+
"""测试基本表格提取"""
36+
text = """文字内容
37+
38+
| 列1 | 列2 |
39+
|-----|-----|
40+
| 数据1 | 数据2 |
41+
42+
更多文字"""
43+
44+
result = self.metric._extract_from_markdown(text)
45+
46+
# 验证表格被提取
47+
self.assertIn('| 列1 | 列2 |', result['table'])
48+
self.assertIn('|-----|-----|', result['table'])
49+
self.assertIn('| 数据1 | 数据2 |', result['table'])
50+
51+
# 验证文本中表格被移除
52+
self.assertNotIn('| 列1 | 列2 |', result['text'])
53+
54+
def test_no_name_error(self):
55+
"""测试修复后的代码不会出现 'name table_lines is not defined' 错误"""
56+
text = """| A | B |
57+
|-----|-----|
58+
| 1 | 2 |"""
59+
60+
try:
61+
result = self.metric._extract_from_markdown(text)
62+
self.assertIsInstance(result, dict)
63+
self.assertIn('table', result)
64+
print(f"✅ 表格提取成功: {repr(result['table'])}")
65+
except NameError as e:
66+
if 'table_lines' in str(e):
67+
self.fail(f"仍然存在table_lines未定义的错误: {e}")
68+
else:
69+
raise
70+
71+
def test_html_table_extraction(self):
72+
"""测试HTML表格提取"""
73+
text = """这是HTML表格:
74+
<table>
75+
<tr><th>标题1</th><th>标题2</th></tr>
76+
<tr><td>数据1</td><td>数据2</td></tr>
77+
</table>
78+
这是普通文本。"""
79+
80+
result = self.metric._extract_from_markdown(text)
81+
82+
# 验证HTML表格被提取
83+
expected_table = """<table>
84+
<tr><th>标题1</th><th>标题2</th></tr>
85+
<tr><td>数据1</td><td>数据2</td></tr>
86+
</table>"""
87+
self.assertIn(expected_table, result['table'])
88+
89+
# 验证文本中HTML表格被移除
90+
self.assertNotIn('<table>', result['text'])
91+
92+
def test_complex_markdown_table(self):
93+
"""测试复杂Markdown表格"""
94+
text = """复杂表格:
95+
96+
| 姓名 | 年龄 | 职业 | 薪资 |
97+
|:-----|:----:|-----:|------|
98+
| 张三 | 25 | 工程师 | 15k |
99+
| 李四 | 30 | 设计师 | 18k |
100+
| 王五 | 28 | 产品经理 | 20k |
101+
102+
表格结束"""
103+
104+
result = self.metric._extract_from_markdown(text)
105+
106+
# 验证复杂表格被完整提取
107+
expected_table = """| 姓名 | 年龄 | 职业 | 薪资 |
108+
|:-----|:----:|-----:|------|
109+
| 张三 | 25 | 工程师 | 15k |
110+
| 李四 | 30 | 设计师 | 18k |
111+
| 王五 | 28 | 产品经理 | 20k |"""
112+
self.assertIn(expected_table, result['table'])
113+
114+
# 验证文本中表格被移除
115+
self.assertNotIn('| 姓名 | 年龄 | 职业 | 薪资 |', result['text'])
116+
117+
118+
119+
def test_table_with_alignment(self):
120+
"""测试带对齐的表格"""
121+
text = """对齐表格:
122+
| 左对齐 | 居中 | 右对齐 |
123+
|:-------|:----:|-------:|
124+
| 内容1 | 内容2 | 内容3 |"""
125+
126+
result = self.metric._extract_from_markdown(text)
127+
128+
# 验证对齐表格被提取
129+
expected_table = """| 左对齐 | 居中 | 右对齐 |
130+
|:-------|:----:|-------:|
131+
| 内容1 | 内容2 | 内容3 |"""
132+
self.assertIn(expected_table, result['table'])
133+
134+
def test_invalid_table_ignored(self):
135+
"""测试无效表格被忽略"""
136+
text = """这不是表格:| 列1 | 列2 |
137+
这也不是:|-----|
138+
这也不是:| 数据 |"""
139+
140+
result = self.metric._extract_from_markdown(text)
141+
142+
# 验证无效表格不被提取
143+
self.assertEqual(result['table'], '')
144+
145+
# 验证原始文本保持不变
146+
self.assertIn('| 列1 | 列2 |', result['text'])
147+
148+
def test_table_with_escaped_pipes(self):
149+
"""测试包含转义管道的表格"""
150+
text = """转义管道表格:
151+
| 列1 | 列2 \| 列3 | 列4 |
152+
|-----|-----|-----|
153+
| 数据1 | 数据2 | 数据3 |"""
154+
155+
result = self.metric._extract_from_markdown(text)
156+
157+
# 验证包含转义管道的表格被提取
158+
expected_table = """| 列1 | 列2 \\| 列3 | 列4 |
159+
|-----|-----|-----|
160+
| 数据1 | 数据2 | 数据3 |"""
161+
self.assertIn(expected_table, result['table'])
162+
163+
def test_table_at_document_end(self):
164+
"""测试文档末尾的表格"""
165+
text = """开始内容
166+
| 列1 | 列2 |
167+
|-----|-----|
168+
| 数据1 | 数据2 |"""
169+
170+
result = self.metric._extract_from_markdown(text)
171+
172+
# 验证文档末尾的表格被提取
173+
expected_table = """| 列1 | 列2 |
174+
|-----|-----|
175+
| 数据1 | 数据2 |"""
176+
self.assertIn(expected_table, result['table'])
177+
178+
179+
180+
def test_empty_and_whitespace_handling(self):
181+
"""测试空内容和空白处理"""
182+
# 测试空字符串
183+
result = self.metric._extract_from_markdown("")
184+
self.assertEqual(result['table'], '')
185+
self.assertEqual(result['text'], '')
186+
187+
# 测试只有空白字符
188+
result = self.metric._extract_from_markdown(" \n\n ")
189+
self.assertEqual(result['table'], '')
190+
self.assertEqual(result['text'], '')
191+
192+
def test_table_with_complex_content(self):
193+
"""测试包含复杂内容的表格"""
194+
text = """复杂内容表格:
195+
| 列1 | 列2 | 列3 |
196+
|-----|-----|-----|
197+
| 包含**粗体** | 包含`代码` | 包含[链接](url) |
198+
| 包含*斜体* | 包含$公式$ | 包含>引用 |"""
199+
200+
result = self.metric._extract_from_markdown(text)
201+
202+
# 验证复杂内容表格被提取
203+
expected_table = """| 列1 | 列2 | 列3 |
204+
|-----|-----|-----|
205+
| 包含**粗体** | 包含`代码` | 包含[链接](url) |
206+
| 包含*斜体* | 包含$公式$ | 包含>引用 |"""
207+
self.assertIn(expected_table, result['table'])
208+
209+
def test_nested_html_tables(self):
210+
"""测试嵌套HTML表格"""
211+
text = """嵌套表格:
212+
<table>
213+
<tr><td>外层表格</td></tr>
214+
<tr><td>
215+
<table>
216+
<tr><td>内层表格</td></tr>
217+
</table>
218+
</td></tr>
219+
</table>"""
220+
221+
result = self.metric._extract_from_markdown(text)
222+
print("result['table']",result['table'])
223+
# 验证嵌套表格被完整提取
224+
expected_table = """<table>
225+
<tr><td>外层表格</td></tr>
226+
<tr><td>
227+
<table>
228+
<tr><td>内层表格</td></tr>
229+
</table>
230+
</td></tr>
231+
</table>
232+
<table>
233+
<tr><td>内层表格</td></tr>
234+
</table>"""
235+
self.assertIn(expected_table, result['table'])
236+
237+
238+
if __name__ == '__main__':
239+
unittest.main()

0 commit comments

Comments
 (0)