Skip to content

Commit 0fc3668

Browse files
committed
update trafilatura extract txt
2 parents f6b84eb + d7ba4af commit 0fc3668

File tree

7 files changed

+371
-295
lines changed

7 files changed

+371
-295
lines changed

examples/multi_extractor_compare.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,29 @@
11
from webmainbench import DataLoader, Evaluator, ExtractorFactory, DataSaver
22
from pathlib import Path
33

4+
# 全局LLM配置
5+
LLM_CONFIG = {
6+
'llm_base_url': '',
7+
'llm_api_key': '',
8+
'llm_model': '',
9+
'use_llm': True
10+
}
411

512
def all_extractor_comparison():
613
"""演示多抽取器对比"""
714

815
print("\n=== 多抽取器对比演示 ===\n")
916

1017
# 创建数据集
11-
dataset_path = Path("../data/WebMainBench_llm-webkit_v1_WebMainBench_7887_within_formula.jsonl")
18+
dataset_path = Path("../data/test_math.jsonl")
1219
dataset = DataLoader.load_jsonl(dataset_path)
13-
20+
1421
# 创建webkit抽取器
1522
config = {
1623
"use_preprocessed_html": True, # 🔑 关键配置:启用预处理HTML模式
1724
"preprocessed_html_field": "llm_webkit_html" # 指定预处理HTML字段名
1825
}
26+
1927
webkit_extractor = ExtractorFactory.create("llm-webkit", config=config)
2028
# 创建magic-extractor抽取器
2129
magic_extractor = ExtractorFactory.create("magic-html")

webmainbench/metrics/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
from .teds_metrics import TEDSMetric, StructureTEDSMetric
1212
from .calculator import MetricCalculator
1313
from .mainhtml_calculator import MainHTMLMetricCalculator
14+
from .base_content_splitter import BaseContentSplitter
15+
from .formula_extractor import FormulaSplitter
16+
from .code_extractor import CodeSplitter
17+
from .table_extractor import TableSplitter
1418

1519
__all__ = [
1620
"BaseMetric",
@@ -27,4 +31,8 @@
2731
"TextEditMetric",
2832
"MetricCalculator",
2933
"MainHTMLMetricCalculator",
34+
'BaseContentSplitter',
35+
'FormulaSplitter',
36+
'CodeSplitter',
37+
'TableSplitter',
3038
]

webmainbench/metrics/base.py

Lines changed: 21 additions & 223 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,7 @@
44

55
from abc import ABC, abstractmethod
66
from dataclasses import dataclass
7-
from typing import Dict, Any, List, Optional, Union
8-
import traceback
9-
import re
10-
from bs4 import BeautifulSoup
11-
import os
12-
import hashlib
7+
from typing import Dict, Any, List, Optional
138

149
@dataclass
1510
class MetricResult:
@@ -144,7 +139,7 @@ def split_content(text: str, content_list: List[Dict[str, Any]] = None, field_na
144139

145140
# 从markdown文本中提取,传递字段名称
146141
return BaseMetric._extract_from_markdown(text or "", field_name=field_name)
147-
142+
148143
@staticmethod
149144
def _extract_from_content_list(content_list: List[Dict[str, Any]]) -> Dict[str, str]:
150145
"""从content_list中递归提取各种类型的内容"""
@@ -194,233 +189,36 @@ def _recursive_extract(items):
194189
'table': '\n'.join(extracted['table']),
195190
'text': '\n'.join(extracted['text'])
196191
}
197-
192+
198193
@staticmethod
199194
def _extract_from_markdown(text: str, field_name: str = None) -> Dict[str, str]:
200195
"""从markdown文本中提取各种类型的内容"""
201196
if not text:
202197
return {'code': '', 'formula': '', 'table': '', 'text': ''}
203198

204-
# 收集所有需要移除的内容片段
205-
extracted_segments = []
206-
code_parts = []
207-
# # 同匹配行间代码块 ```...```
208-
# pattern = r'(```[\s\S]*?```)'
209-
# for match in re.finditer(pattern, text):
210-
# code_segment = match.group(0)
211-
# extracted_segments.append(code_segment)
212-
#
213-
# if code_segment.startswith('```'):
214-
# # 处理代码块(保留内部缩进)
215-
# lines = code_segment.split('\n')
216-
# # 移除首尾的```标记
217-
# content_lines = lines[1:-1]
218-
# # 保留原始缩进,只拼接内容
219-
# code_content = '\n'.join(content_lines)
220-
# else:
221-
# # 处理行内代码(只去除外层`和前后空格)
222-
# code_content = code_segment[1:-1].strip()
223-
#
224-
# if code_content: # 只添加非空内容
225-
# code_parts.append(code_content)
226-
227-
# 1. 首先处理三个反引号包裹的代码块(优先级最高)
228-
backtick_pattern = r'(```[\s\S]*?```)'
229-
for match in re.finditer(backtick_pattern, text):
230-
code_segment = match.group(0)
231-
232-
if code_segment.startswith('```'):
233-
# 处理代码块
234-
lines = code_segment.split('\n')
235-
# 移除首尾的```标记
236-
content_lines = lines[1:-1]
237-
code_content = '\n'.join(content_lines)
238-
else:
239-
# 处理行内代码
240-
code_content = code_segment[1:-1].strip()
241-
242-
if code_content:
243-
code_parts.append(code_content)
244-
245-
# 2. 处理缩进代码块 - 使用更精确的匹配
246-
# 匹配模式:前面有空行 + 连续的多行缩进内容 + 后面有空行
247-
# 关键:要求所有匹配的行都是缩进的
248-
indent_pattern = r'(?:\n\s*\n)((?:(?: {4,}|\t+)[^\n]*(?:\n|$)){2,})(?=\n\s*\n|$)'
249-
250-
for match in re.finditer(indent_pattern, text, re.MULTILINE):
251-
code_segment = match.group(1)
252-
253-
# 验证:确保所有行都是缩进的(避免混合缩进和非缩进行)
254-
lines = code_segment.split('\n')
255-
all_indented = all(
256-
line.startswith(' ') or line.startswith('\t') or not line.strip()
257-
for line in lines
258-
if line.strip() # 空行不算
259-
)
260-
261-
if not all_indented:
262-
continue # 跳过包含非缩进行的块
263-
264-
# 进一步验证代码特征
265-
non_empty_lines = [line.strip() for line in lines if line.strip()]
266-
if len(non_empty_lines) < 2: # 至少2行非空内容
267-
continue
268-
269-
# 检查是否有明显的非代码特征
270-
has_list_features = any(
271-
re.match(r'^[-•*]\s', line) or
272-
re.match(r'^\d+\.\s', line) or
273-
re.search(r'\$[\d,]', line) or
274-
re.search(r'\b(million|billion|thousand)\b', line, re.IGNORECASE)
275-
for line in non_empty_lines
276-
)
277-
278-
if has_list_features:
279-
continue # 跳过列表内容
280-
281-
# 清理代码段
282-
cleaned_lines = []
283-
for line in code_segment.split('\n'):
284-
if line.strip():
285-
if line.startswith(' '):
286-
cleaned_lines.append(line[4:])
287-
elif line.startswith('\t'):
288-
cleaned_lines.append(line[1:])
289-
else:
290-
cleaned_lines.append(line)
291-
292-
code_content = '\n'.join(cleaned_lines)
293-
if code_content.strip():
294-
code_parts.append(code_content)
295-
296-
# 提取公式 - 新的两步处理逻辑
297-
formula_parts = []
199+
# 加载 llm 配置
200+
from examples.multi_extractor_compare import LLM_CONFIG
201+
# 直接创建具体的提取器实例
202+
from .code_extractor import CodeSplitter
203+
from .formula_extractor import FormulaSplitter
204+
from .table_extractor import TableSplitter
298205

299-
# 第一步:先用正则提取公式
300-
regex_formulas = []
301-
latex_patterns = [
302-
r'(?<!\\)\$\$(.*?)(?<!\\)\$\$', # 行间 $$...$$
303-
r'(?<!\\)\\\[(.*?)(?<!\\)\\\]', # 行间 \[...\]
304-
r'(?<!\\)\$(.*?)(?<!\\)\$', # 行内 $...$
305-
r'(?<!\\)\\\((.*?)(?<!\\)\\\)', # 行内 \(...\)
306-
]
206+
code_extractor = CodeSplitter(LLM_CONFIG)
207+
formula_extractor = FormulaSplitter(LLM_CONFIG)
208+
table_extractor = TableSplitter(LLM_CONFIG)
307209

308-
for pattern in latex_patterns:
309-
for match in re.finditer(pattern, text, re.DOTALL):
310-
formula_full = match.group(0)
311-
formula_content = match.group(1)
312-
extracted_segments.append(formula_full)
313-
if formula_content.strip():
314-
regex_formulas.append(formula_content.strip())
210+
# 提取各类内容
211+
code_content = code_extractor.extract(text, field_name)
212+
formula_content = formula_extractor.extract(text, field_name)
213+
table_content = table_extractor.extract(text, field_name)
315214

316-
# 第二步:根据字段类型决定是否需要API修正
317-
if field_name == "groundtruth_content":
318-
print(f"[DEBUG] 检测到groundtruth内容,仅使用正则提取公式")
319-
formula_parts = regex_formulas
320-
else:
321-
print(f"[DEBUG] 检测到md内容,使用正则+API修正模式")
322-
# 对于llm_webkit_md,将正则结果传递给API进行修正
323-
if regex_formulas:
324-
# 将正则提取的公式作为输入传递给API
325-
regex_formulas_text = '\n'.join(regex_formulas)
326-
print(f"[DEBUG] 正则提取到 {len(regex_formulas)} 个公式,准备API修正")
327-
328-
cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.cache')
329-
os.makedirs(cache_dir, exist_ok=True)
330-
331-
# 使用正则结果的哈希作为缓存文件名
332-
text_hash = hashlib.md5(regex_formulas_text.encode('utf-8')).hexdigest()
333-
cache_file = os.path.join(cache_dir, f'formula_correction_cache_{text_hash}.json')
334-
335-
try:
336-
from .formula_extractor import correct_formulas_with_llm
337-
corrected_formulas = correct_formulas_with_llm(regex_formulas, cache_file)
338-
formula_parts = corrected_formulas
339-
print(f"[DEBUG] API修正成功,最终得到 {len(formula_parts)} 个公式")
340-
except Exception as e:
341-
print(f"[DEBUG] API修正失败: {type(e).__name__}: {e},使用正则结果")
342-
formula_parts = regex_formulas
343-
else:
344-
print(f"[DEBUG] 正则未提取到公式,跳过API修正")
345-
formula_parts = []
346-
347-
# 提取表格
348-
table_parts = []
349-
350-
# ===== 1. 提取 HTML 表格 =====
351-
# 用 BeautifulSoup 替代正则,防止嵌套或匹配不全
352-
soup = BeautifulSoup(text, "html.parser")
353-
for table in soup.find_all("table"):
354-
# 判断当前表格的父级是否是表格内的标签(<td>、<tr>、<tbody>等)
355-
parent_is_table_related = table.find_parent(["td", "tr", "tbody", "table"]) is not None
356-
if not parent_is_table_related: # 父级不是表格相关标签 → 是外层表格
357-
html_table = str(table)
358-
extracted_segments.append(html_table)
359-
table_parts.append(html_table)
360-
361-
# ===== 2. 提取 Markdown 表格 =====
362-
lines = text.split('\n')
363-
table_lines = []
364-
in_markdown_table = False
365-
found_separator = False # 是否已找到分隔行
366-
367-
def is_md_table_line(line):
368-
"""判断是否可能是 Markdown 表格行"""
369-
if line.count("|") < 1:  # require at least one pipe character
370-
return False
371-
return True
372-
373-
def is_md_separator_line(line):
374-
"""判断是否为 Markdown 分隔行"""
375-
parts = [p.strip() for p in line.split("|")]
376-
# 检查是否所有部分都是分隔符格式
377-
for p in parts:
378-
if p and not re.match(r"^:?\-{3,}:?$", p):
379-
return False
380-
return True
381-
382-
def save_table():
383-
"""保存当前表格并清空缓存"""
384-
nonlocal table_lines
385-
# 只有当表格行数大于等于2,且第二行是分隔行时才保存
386-
if len(table_lines) >= 2 and is_md_separator_line(table_lines[1]):
387-
md_table = '\n'.join(table_lines)
388-
extracted_segments.append(md_table)
389-
table_parts.append(md_table)
390-
391-
for line in lines:
392-
if is_md_table_line(line):
393-
table_lines.append(line)
394-
in_markdown_table = True
395-
if is_md_separator_line(line):
396-
found_separator = True
397-
else:
398-
if in_markdown_table:
399-
save_table()
400-
table_lines = []
401-
in_markdown_table = False
402-
found_separator = False
403-
404-
# 处理文档末尾的 Markdown 表格
405-
if in_markdown_table:
406-
save_table()
407-
408-
# 提取剩余文本(移除所有已提取的内容片段)
409-
clean_text = text
410-
for segment in extracted_segments:
411-
clean_text = clean_text.replace(segment, '', 1)
412-
413-
# 清理多余的空行
414-
clean_text = re.sub(r'\n\s*\n', '\n\n', clean_text)
415-
clean_text = clean_text.strip()
416-
417215
return {
418-
'code': '\n'.join(code_parts),
419-
'formula': '\n'.join(formula_parts),
420-
'table': '\n'.join(table_parts),
421-
'text': text # 原始全部文本
216+
'code': code_content,
217+
'formula': formula_content,
218+
'table': table_content,
219+
'text': text # 保留原始全部文本
422220
}
423-
221+
424222
def aggregate_results(self, results: List[MetricResult]) -> MetricResult:
425223
"""
426224
Aggregate multiple metric results.

0 commit comments

Comments
 (0)