Skip to content

Commit 2baf561

Browse files
authored
Merge pull request #21 from SHUzhangshuo/main
修改了_extract_from_markdown方法,并基于新的方法进行统计测试
2 parents 0784253 + 41f12bd commit 2baf561

File tree

5 files changed

+2271
-106
lines changed

5 files changed

+2271
-106
lines changed

examples/statics.py

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
from abc import ABCMeta
2+
import json
3+
from typing import Dict, Any, List, Optional, Union, override
4+
5+
from webmainbench.metrics.base import DocElementType, ParagraphTextType
6+
7+
def normalize_math_delimiters(text: str) -> str:
    """Convert ``[tex]..[/tex]`` and ``[itex]..[/itex]`` math markers to ``$$..$$`` / ``$..$``.

    Fallback handling for formulas that were split by ``<br>`` tags and thus
    not recognised as equations. Two cases are handled:

    1. display math: ``[tex]...[/tex]``   -> ``$$...$$``
    2. inline math:  ``[itex]...[/itex]`` -> ``$...$``

    The formula body keeps its original formatting (newlines and inner
    spaces); only leading/trailing whitespace of the body is stripped.

    Args:
        text (str): text possibly containing bracketed math markers.

    Returns:
        str: the text with math markers rewritten to dollar delimiters.
    """
    import re

    # DOTALL so a formula spanning several lines still matches; the
    # non-greedy body prevents adjacent formulas from being merged.
    conversions = (
        (re.compile(r'\[tex\](.*?)\[/tex\]', re.DOTALL),
         lambda m: f'$${m.group(1).strip()}$$'),
        (re.compile(r'\[itex\](.*?)\[/itex\]', re.DOTALL),
         lambda m: f'${m.group(1).strip()}$'),
    )
    for pattern, replacement in conversions:
        text = pattern.sub(replacement, text)
    return text
32+
33+
# The standard library already provides exactly this class
# (``class ABC(metaclass=ABCMeta): __slots__ = ()``) — reuse it instead
# of re-implementing it. Behaviour and interface are identical.
from abc import ABC
38+
39+
class StructureMapper(ABC):
    """Map a content_list structure to another representation.

    Subclasses turn the page-wise content_list (returned by ``_get_data``)
    into concrete formats such as html, txt or md.
    """

    def __init__(self):
        self.__txt_para_splitter = '\n'
        self.__md_para_splitter = '\n\n'
        self.__text_end = '\n'
        self.__list_item_start = '-'  # prefix of a markdown list item
        # Two spaces: prefix for the second and later paragraphs of a
        # multi-paragraph markdown list item.
        self.__list_para_prefix = '  '
        self.__md_special_chars = ['#', '`', '$']  # TODO also escape '|' when assembling tables
        self.__nodes_document_type = [
            DocElementType.MM_NODE_LIST,
            DocElementType.PARAGRAPH,
            DocElementType.LIST,
            DocElementType.SIMPLE_TABLE,
            DocElementType.COMPLEX_TABLE,
            DocElementType.TITLE,
            DocElementType.IMAGE,
            DocElementType.AUDIO,
            DocElementType.VIDEO,
            DocElementType.CODE,
            DocElementType.EQUATION_INTERLINE,
        ]
        self.__inline_types_document_type = [
            ParagraphTextType.EQUATION_INLINE,
            ParagraphTextType.CODE_INLINE,
        ]

    def to_html(self):
        raise NotImplementedError('This method must be implemented by the subclass.')

    def to_txt(self, exclude_nodes=(DocElementType.MM_NODE_LIST,), exclude_inline_types=()):
        """Convert the content_list to txt format.

        Args:
            exclude_nodes: iterable of node types to skip. The previous
                default was the bare string ``DocElementType.MM_NODE_LIST``,
                which made ``type not in exclude_nodes`` a *substring* test
                (e.g. a ``list`` node would be wrongly excluded because
                ``'list'`` is a substring of ``'mm_node_list'``); a one-element
                tuple restores proper membership semantics.
            exclude_inline_types: iterable of inline text types to skip.
                A tuple default avoids the shared-mutable-default pitfall.

        Returns:
            str: the text content, ending with a single trailing newline.
        """
        text_blocks: list[str] = []  # one entry per converted DocElementType node
        content_lst = self._get_data()
        for page in content_lst:
            for content_lst_node in page:
                if content_lst_node['type'] not in exclude_nodes:
                    # NOTE(review): __content_lst_node_2_txt is not defined in
                    # this excerpt — presumably provided by the full
                    # implementation; confirm against the original module.
                    txt_content = self.__content_lst_node_2_txt(content_lst_node, exclude_inline_types)
                    if txt_content:
                        text_blocks.append(txt_content)

        txt = self.__txt_para_splitter.join(text_blocks)
        txt = normalize_math_delimiters(txt)
        txt = txt.strip() + self.__text_end  # ensure exactly one trailing newline
        return txt
79+
80+
class ContentList(StructureMapper):
    """Tool-chain implementation for the content_list format."""

    def __init__(self, json_data_lst: list):
        """Wrap a list of page content lists; ``None`` becomes an empty list."""
        super().__init__()
        if json_data_lst is None:
            json_data_lst = []
        self.__content_list = json_data_lst

    def length(self) -> int:
        """Number of entries (kept for backward compatibility; prefer ``len()``)."""
        return len(self.__content_list)

    def __len__(self) -> int:
        # Pythonic counterpart of length(): enables len(content_list),
        # truthiness tests, etc. Backward-compatible addition.
        return len(self.__content_list)

    def append(self, content: dict):
        """Append one content entry."""
        self.__content_list.append(content)

    def __getitem__(self, key):
        return self.__content_list[key]  # read access

    def __setitem__(self, key, value):
        self.__content_list[key] = value  # write access

    def __delitem__(self, key):
        del self.__content_list[key]

    @override
    def _get_data(self) -> List[Dict]:
        """Return the wrapped page-wise content list (see StructureMapper)."""
        return self.__content_list
107+
108+
class Statics:
    """Count the occurrences of each element ``type`` in a content_list."""

    def __init__(self, statics: dict = None):
        """Initialise with an optional pre-existing statistics dict.

        Uses ``is not None`` (rather than truthiness) so a caller-supplied
        *empty* dict is adopted and filled in place instead of being
        silently replaced by a fresh one.
        """
        self.statics = statics if statics is not None else {}
        self._validate(self.statics)

    def _validate(self, statics: dict):
        """Validate the shape of ``statics``: it must be a dict mapping
        type keys to counts. Example::

            {
                "list": 1,
                "list.text": 2,
                "list.equation-inline": 1,
                "paragraph": 2,
                "paragraph.text": 2,
                "equation-interline": 2
            }

        Raises:
            ValueError: if ``statics`` is not a dict.
        """
        if not isinstance(statics, dict):
            raise ValueError('statics must be a dict')

    # NOTE: the double-underscore names below are NOT real dunder protocols
    # (Python never calls __additem__/__getall__/__clear__ implicitly).
    # They are kept unchanged because external code invokes them directly,
    # e.g. ``statics_total.__getall__()`` in statistics_test.py.
    def __additem__(self, key, value):
        self.statics[key] = value

    def __getitem__(self, key):
        return self.statics[key]

    def __getall__(self):
        return self.statics

    def __clear__(self):
        self.statics = {}

    def print(self):
        """Pretty-print the current statistics as JSON."""
        print(json.dumps(self.statics, indent=4))

    def _bump(self, key, amount=1):
        """Increment the counter for ``key`` by ``amount`` (starting at 0)."""
        self.statics[key] = self.statics.get(key, 0) + amount

    def merge_statics(self, statics: dict) -> dict:
        """Merge another statistics dict into this one.

        Non-numeric values are ignored.

        Args:
            statics: the per-content_list statistics to merge in.
        Returns:
            dict: the merged statistics (``self.statics``).
        """
        for key, value in statics.items():
            if isinstance(value, (int, float)):
                self._bump(key, value)
        return self.statics

    def get_statics(self, contentlist) -> dict:
        """Count element types in ``contentlist`` (clears previous data).

        Args:
            contentlist: a ContentList object or the raw list-of-pages data.
        Returns:
            dict: count per element type.
        """
        self.__clear__()
        return self._calculate_statics(contentlist)

    def add_statics(self, contentlist) -> dict:
        """Count element types in ``contentlist`` (accumulates onto existing data).

        Args:
            contentlist: a ContentList object or the raw list-of-pages data.
        Returns:
            dict: the accumulated statistics.
        """
        return self._calculate_statics(contentlist)

    def _calculate_statics(self, contentlist) -> dict:
        """Internal: walk ``contentlist`` and accumulate per-type counts.

        Args:
            contentlist: a ContentList object or the raw list-of-pages data.
        Returns:
            dict: the (accumulated) statistics.
        """
        def process_list_items(items, parent_type):
            """Recursively walk nested list items down to leaf dicts.

            Args:
                items: a (possibly nested) list of items, or a leaf dict.
                parent_type: parent element type used to build the stats key.
            """
            if isinstance(items, list):
                for item in items:
                    process_list_items(item, parent_type)
            elif isinstance(items, dict) and 't' in items:
                # Reached a leaf text/formula element.
                self._bump(f"{parent_type}.{items['t']}")

        # Accept either a ContentList-like object or raw list data.
        if hasattr(contentlist, '_get_data'):
            data = contentlist._get_data()
        else:
            data = contentlist

        for page in data:  # page: the content list of one page
            for element in page:  # element: one concrete node
                # 1. Count the base element type.
                element_type = element['type']
                self._bump(element_type)

                # 2. Count inner structure of compound elements.
                if element_type == DocElementType.PARAGRAPH:
                    # Per-paragraph inline text-type counts.
                    for item in element['content']:
                        self._bump(f"{DocElementType.PARAGRAPH}.{item['t']}")
                elif element_type == DocElementType.LIST:
                    process_list_items(element['content']['items'], DocElementType.LIST)
                elif element_type == DocElementType.COMPLEX_TABLE:
                    # Count tables actually flagged as complex.
                    if element.get('content', {}).get('is_complex', False):
                        self._bump(f'{DocElementType.COMPLEX_TABLE}.complex')

        return self.statics
235+

examples/statistics_test.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
#!/usr/bin/env python3
2+
"""
3+
统计测试脚本:分析_extract_from_markdown方法的结果统计(统计整个数据集)
4+
并统计所有content_list的statics分布
5+
"""
6+
7+
import json
8+
import sys
9+
from pathlib import Path
10+
sys.path.insert(0, str(Path(__file__).parent))
11+
12+
from webmainbench.data.loader import DataLoader
13+
from webmainbench.metrics.base import BaseMetric
14+
15+
# 直接import merge_statics
16+
from examples.statics import Statics
17+
18+
def analyze_sample(sample, sample_index):
    """Analyse the extraction result of a single sample."""
    zero_chars = {'code': 0, 'formula': 0, 'table': 0, 'text': 0}
    result = {
        'id': sample.id,
        'index': sample_index,
        'has_groundtruth_content': bool(sample.groundtruth_content),
        'markdown_result': None,
        'markdown_zero_chars': zero_chars,
    }

    # Nothing to extract — return the skeleton record as-is.
    if not sample.groundtruth_content:
        return result

    try:
        extracted = BaseMetric._extract_from_markdown(sample.groundtruth_content)
        result['markdown_result'] = extracted
        # Flag every field that came back empty.
        for field in ('code', 'formula', 'table', 'text'):
            if len(extracted[field]) == 0:
                zero_chars[field] = 1
    except Exception as e:
        print(f"Markdown方法处理样本 {sample.id} 时出错: {e}")

    return result
42+
43+
def generate_summary_report(analysis_results, statics_total):
    """Print the aggregated analysis report.

    Args:
        analysis_results: list of per-sample dicts produced by analyze_sample.
        statics_total: Statics-like object exposing ``__getall__()``.
    """
    total_samples = len(analysis_results)

    # Basic counts.
    has_groundtruth_content = sum(1 for r in analysis_results if r['has_groundtruth_content'])

    # Per-field zero-character counts.
    markdown_zero_stats = {'code': 0, 'formula': 0, 'table': 0, 'text': 0}
    for result in analysis_results:
        for key in ['code', 'formula', 'table', 'text']:
            markdown_zero_stats[key] += result['markdown_zero_chars'][key]

    # Report header.
    print("=" * 60)
    print("数据集分析汇总报告")
    print("=" * 60)
    print(f"总样本数: {total_samples}")
    # Guard against an empty result set (previously raised ZeroDivisionError).
    gt_pct = has_groundtruth_content / total_samples * 100 if total_samples > 0 else 0
    print(f"有groundtruth_content的样本: {has_groundtruth_content} ({gt_pct:.1f}%)")

    print("\n" + "=" * 60)
    print("0字符统计 (Markdown方法)")
    print("=" * 60)
    for key in ['code', 'formula', 'table', 'text']:
        count = markdown_zero_stats[key]
        percentage = count / has_groundtruth_content * 100 if has_groundtruth_content > 0 else 0
        print(f"{key:>10}: {count:>3} 个样本 ({percentage:>5.1f}%)")

    # Samples where every extracted field is empty.
    print("\n" + "=" * 60)
    print("所有字段都为0的样本")
    print("=" * 60)

    markdown_all_zero = []
    for result in analysis_results:
        if result['markdown_result']:
            if all(result['markdown_zero_chars'].values()):
                markdown_all_zero.append(result['id'])

    print(f"Markdown方法全为0的样本: {len(markdown_all_zero)}")
    if markdown_all_zero:
        print("样本ID:", markdown_all_zero[:10])  # show at most the first 10

    # Aggregated content_list type distribution.
    print("\n" + "=" * 60)
    print("所有content_list statics统计(类型分布总和)")
    print("=" * 60)
    # Use __getall__() to obtain the plain dict rather than serialising the
    # Statics object itself. (The redundant local `import json`, which
    # shadowed the module-level import, has been removed.)
    statics_dict = statics_total.__getall__()
    print(json.dumps(statics_dict, ensure_ascii=False, indent=2))
96+
97+
def main():
    """Entry point: run the dataset-wide statistics analysis and print a report."""
    print("开始数据集统计分析...")

    try:
        # Load the dataset (hard-coded path; adjust if the data lives elsewhere).
        data_file = "WebMainBench/data/WebMainBench_llm-webkit_v1_WebMainBench_dataset_merge_with_llm_webkit.jsonl"
        dataset = DataLoader.load_jsonl(data_file)

        print(f"成功加载 {len(dataset.samples)} 个样本")

        # Analyse the whole dataset (no longer limited to the first 100 samples).
        test_samples = dataset.samples
        print(f"分析全部 {len(test_samples)} 个样本...")

        # Analyse each sample.
        analysis_results = []
        statics = Statics()
        for i, sample in enumerate(test_samples):
            if i % 100 == 0:  # show progress every 100 samples
                print(f"处理进度: {i}/{len(test_samples)}")

            result = analyze_sample(sample, i)
            analysis_results.append(result)

            # Accumulate content_list statistics for this sample.
            if getattr(sample, "groundtruth_content_list", None):
                cl = sample.groundtruth_content_list
                # Format compatibility: wrap a flat (page-less) content list
                # into a single page so Statics sees list-of-pages input.
                # NOTE(review): assumes paged input is a list of lists — confirm.
                if cl and isinstance(cl, list) and (len(cl) == 0 or not isinstance(cl[0], list)):
                    cl_input = [cl]
                else:
                    cl_input = cl
                # Use add_statics so counts accumulate across samples.
                statics_result = statics.add_statics(cl_input)
                if i < 5:  # print only the first 5 samples' stats as examples
                    print(f"样本 {i} 统计结果:")
                    print(json.dumps(statics_result, ensure_ascii=False, indent=2))
        # Emit the aggregated report.

        generate_summary_report(analysis_results, statics)

        print("\n" + "=" * 60)
        print("分析完成!")
        print("=" * 60)

    except Exception as e:
        # Top-level boundary: report the failure with a full traceback.
        print(f"分析失败: {e}")
        import traceback
        traceback.print_exc()

if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)