|
4 | 4 |
|
5 | 5 | from abc import ABC, abstractmethod |
6 | 6 | from dataclasses import dataclass |
7 | | -from typing import Dict, Any, List, Optional, Union |
8 | | -import traceback |
9 | | -import re |
10 | | -from bs4 import BeautifulSoup |
11 | | -import os |
12 | | -import hashlib |
| 7 | +from typing import Dict, Any, List, Optional |
13 | 8 |
|
14 | 9 | @dataclass |
15 | 10 | class MetricResult: |
@@ -144,7 +139,7 @@ def split_content(text: str, content_list: List[Dict[str, Any]] = None, field_na |
144 | 139 |
|
145 | 140 | # 从markdown文本中提取,传递字段名称 |
146 | 141 | return BaseMetric._extract_from_markdown(text or "", field_name=field_name) |
147 | | - |
| 142 | + |
148 | 143 | @staticmethod |
149 | 144 | def _extract_from_content_list(content_list: List[Dict[str, Any]]) -> Dict[str, str]: |
150 | 145 | """从content_list中递归提取各种类型的内容""" |
@@ -194,233 +189,36 @@ def _recursive_extract(items): |
194 | 189 | 'table': '\n'.join(extracted['table']), |
195 | 190 | 'text': '\n'.join(extracted['text']) |
196 | 191 | } |
197 | | - |
| 192 | + |
198 | 193 | @staticmethod |
199 | 194 | def _extract_from_markdown(text: str, field_name: str = None) -> Dict[str, str]: |
200 | 195 | """从markdown文本中提取各种类型的内容""" |
201 | 196 | if not text: |
202 | 197 | return {'code': '', 'formula': '', 'table': '', 'text': ''} |
203 | 198 |
|
204 | | - # 收集所有需要移除的内容片段 |
205 | | - extracted_segments = [] |
206 | | - code_parts = [] |
207 | | - # # 同匹配行间代码块 ```...``` |
208 | | - # pattern = r'(```[\s\S]*?```)' |
209 | | - # for match in re.finditer(pattern, text): |
210 | | - # code_segment = match.group(0) |
211 | | - # extracted_segments.append(code_segment) |
212 | | - # |
213 | | - # if code_segment.startswith('```'): |
214 | | - # # 处理代码块(保留内部缩进) |
215 | | - # lines = code_segment.split('\n') |
216 | | - # # 移除首尾的```标记 |
217 | | - # content_lines = lines[1:-1] |
218 | | - # # 保留原始缩进,只拼接内容 |
219 | | - # code_content = '\n'.join(content_lines) |
220 | | - # else: |
221 | | - # # 处理行内代码(只去除外层`和前后空格) |
222 | | - # code_content = code_segment[1:-1].strip() |
223 | | - # |
224 | | - # if code_content: # 只添加非空内容 |
225 | | - # code_parts.append(code_content) |
226 | | - |
227 | | - # 1. 首先处理三个反引号包裹的代码块(优先级最高) |
228 | | - backtick_pattern = r'(```[\s\S]*?```)' |
229 | | - for match in re.finditer(backtick_pattern, text): |
230 | | - code_segment = match.group(0) |
231 | | - |
232 | | - if code_segment.startswith('```'): |
233 | | - # 处理代码块 |
234 | | - lines = code_segment.split('\n') |
235 | | - # 移除首尾的```标记 |
236 | | - content_lines = lines[1:-1] |
237 | | - code_content = '\n'.join(content_lines) |
238 | | - else: |
239 | | - # 处理行内代码 |
240 | | - code_content = code_segment[1:-1].strip() |
241 | | - |
242 | | - if code_content: |
243 | | - code_parts.append(code_content) |
244 | | - |
245 | | - # 2. 处理缩进代码块 - 使用更精确的匹配 |
246 | | - # 匹配模式:前面有空行 + 连续的多行缩进内容 + 后面有空行 |
247 | | - # 关键:要求所有匹配的行都是缩进的 |
248 | | - indent_pattern = r'(?:\n\s*\n)((?:(?: {4,}|\t+)[^\n]*(?:\n|$)){2,})(?=\n\s*\n|$)' |
249 | | - |
250 | | - for match in re.finditer(indent_pattern, text, re.MULTILINE): |
251 | | - code_segment = match.group(1) |
252 | | - |
253 | | - # 验证:确保所有行都是缩进的(避免混合缩进和非缩进行) |
254 | | - lines = code_segment.split('\n') |
255 | | - all_indented = all( |
256 | | - line.startswith(' ') or line.startswith('\t') or not line.strip() |
257 | | - for line in lines |
258 | | - if line.strip() # 空行不算 |
259 | | - ) |
260 | | - |
261 | | - if not all_indented: |
262 | | - continue # 跳过包含非缩进行的块 |
263 | | - |
264 | | - # 进一步验证代码特征 |
265 | | - non_empty_lines = [line.strip() for line in lines if line.strip()] |
266 | | - if len(non_empty_lines) < 2: # 至少2行非空内容 |
267 | | - continue |
268 | | - |
269 | | - # 检查是否有明显的非代码特征 |
270 | | - has_list_features = any( |
271 | | - re.match(r'^[-•*]\s', line) or |
272 | | - re.match(r'^\d+\.\s', line) or |
273 | | - re.search(r'\$[\d,]', line) or |
274 | | - re.search(r'\b(million|billion|thousand)\b', line, re.IGNORECASE) |
275 | | - for line in non_empty_lines |
276 | | - ) |
277 | | - |
278 | | - if has_list_features: |
279 | | - continue # 跳过列表内容 |
280 | | - |
281 | | - # 清理代码段 |
282 | | - cleaned_lines = [] |
283 | | - for line in code_segment.split('\n'): |
284 | | - if line.strip(): |
285 | | - if line.startswith(' '): |
286 | | - cleaned_lines.append(line[4:]) |
287 | | - elif line.startswith('\t'): |
288 | | - cleaned_lines.append(line[1:]) |
289 | | - else: |
290 | | - cleaned_lines.append(line) |
291 | | - |
292 | | - code_content = '\n'.join(cleaned_lines) |
293 | | - if code_content.strip(): |
294 | | - code_parts.append(code_content) |
295 | | - |
296 | | - # 提取公式 - 新的两步处理逻辑 |
297 | | - formula_parts = [] |
| 199 | + # 加载 llm 配置 |
| 200 | + from examples.multi_extractor_compare import LLM_CONFIG |
| 201 | + # 直接创建具体的提取器实例 |
| 202 | + from .code_extractor import CodeSplitter |
| 203 | + from .formula_extractor import FormulaSplitter |
| 204 | + from .table_extractor import TableSplitter |
298 | 205 |
|
299 | | - # 第一步:先用正则提取公式 |
300 | | - regex_formulas = [] |
301 | | - latex_patterns = [ |
302 | | - r'(?<!\\)\$\$(.*?)(?<!\\)\$\$', # 行间 $$...$$ |
303 | | - r'(?<!\\)\\\[(.*?)(?<!\\)\\\]', # 行间 \[...\] |
304 | | - r'(?<!\\)\$(.*?)(?<!\\)\$', # 行内 $...$ |
305 | | - r'(?<!\\)\\\((.*?)(?<!\\)\\\)', # 行内 \(...\) |
306 | | - ] |
| 206 | + code_extractor = CodeSplitter(LLM_CONFIG) |
| 207 | + formula_extractor = FormulaSplitter(LLM_CONFIG) |
| 208 | + table_extractor = TableSplitter(LLM_CONFIG) |
307 | 209 |
|
308 | | - for pattern in latex_patterns: |
309 | | - for match in re.finditer(pattern, text, re.DOTALL): |
310 | | - formula_full = match.group(0) |
311 | | - formula_content = match.group(1) |
312 | | - extracted_segments.append(formula_full) |
313 | | - if formula_content.strip(): |
314 | | - regex_formulas.append(formula_content.strip()) |
| 210 | + # 提取各类内容 |
| 211 | + code_content = code_extractor.extract(text, field_name) |
| 212 | + formula_content = formula_extractor.extract(text, field_name) |
| 213 | + table_content = table_extractor.extract(text, field_name) |
315 | 214 |
|
316 | | - # 第二步:根据字段类型决定是否需要API修正 |
317 | | - if field_name == "groundtruth_content": |
318 | | - print(f"[DEBUG] 检测到groundtruth内容,仅使用正则提取公式") |
319 | | - formula_parts = regex_formulas |
320 | | - else: |
321 | | - print(f"[DEBUG] 检测到md内容,使用正则+API修正模式") |
322 | | - # 对于llm_webkit_md,将正则结果传递给API进行修正 |
323 | | - if regex_formulas: |
324 | | - # 将正则提取的公式作为输入传递给API |
325 | | - regex_formulas_text = '\n'.join(regex_formulas) |
326 | | - print(f"[DEBUG] 正则提取到 {len(regex_formulas)} 个公式,准备API修正") |
327 | | - |
328 | | - cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.cache') |
329 | | - os.makedirs(cache_dir, exist_ok=True) |
330 | | - |
331 | | - # 使用正则结果的哈希作为缓存文件名 |
332 | | - text_hash = hashlib.md5(regex_formulas_text.encode('utf-8')).hexdigest() |
333 | | - cache_file = os.path.join(cache_dir, f'formula_correction_cache_{text_hash}.json') |
334 | | - |
335 | | - try: |
336 | | - from .formula_extractor import correct_formulas_with_llm |
337 | | - corrected_formulas = correct_formulas_with_llm(regex_formulas, cache_file) |
338 | | - formula_parts = corrected_formulas |
339 | | - print(f"[DEBUG] API修正成功,最终得到 {len(formula_parts)} 个公式") |
340 | | - except Exception as e: |
341 | | - print(f"[DEBUG] API修正失败: {type(e).__name__}: {e},使用正则结果") |
342 | | - formula_parts = regex_formulas |
343 | | - else: |
344 | | - print(f"[DEBUG] 正则未提取到公式,跳过API修正") |
345 | | - formula_parts = [] |
346 | | - |
347 | | - # 提取表格 |
348 | | - table_parts = [] |
349 | | - |
350 | | - # ===== 1. 提取 HTML 表格 ===== |
351 | | - # 用 BeautifulSoup 替代正则,防止嵌套或匹配不全 |
352 | | - soup = BeautifulSoup(text, "html.parser") |
353 | | - for table in soup.find_all("table"): |
354 | | - # 判断当前表格的父级是否是表格内的标签(<td>、<tr>、<tbody>等) |
355 | | - parent_is_table_related = table.find_parent(["td", "tr", "tbody", "table"]) is not None |
356 | | - if not parent_is_table_related: # 父级不是表格相关标签 → 是外层表格 |
357 | | - html_table = str(table) |
358 | | - extracted_segments.append(html_table) |
359 | | - table_parts.append(html_table) |
360 | | - |
361 | | - # ===== 2. 提取 Markdown 表格 ===== |
362 | | - lines = text.split('\n') |
363 | | - table_lines = [] |
364 | | - in_markdown_table = False |
365 | | - found_separator = False # 是否已找到分隔行 |
366 | | - |
367 | | - def is_md_table_line(line): |
368 | | - """判断是否可能是 Markdown 表格行""" |
369 | | - if line.count("|") < 1: # 至少三个竖线 |
370 | | - return False |
371 | | - return True |
372 | | - |
373 | | - def is_md_separator_line(line): |
374 | | - """判断是否为 Markdown 分隔行""" |
375 | | - parts = [p.strip() for p in line.split("|")] |
376 | | - # 检查是否所有部分都是分隔符格式 |
377 | | - for p in parts: |
378 | | - if p and not re.match(r"^:?\-{3,}:?$", p): |
379 | | - return False |
380 | | - return True |
381 | | - |
382 | | - def save_table(): |
383 | | - """保存当前表格并清空缓存""" |
384 | | - nonlocal table_lines |
385 | | - # 只有当表格行数大于等于2,且第二行是分隔行时才保存 |
386 | | - if len(table_lines) >= 2 and is_md_separator_line(table_lines[1]): |
387 | | - md_table = '\n'.join(table_lines) |
388 | | - extracted_segments.append(md_table) |
389 | | - table_parts.append(md_table) |
390 | | - |
391 | | - for line in lines: |
392 | | - if is_md_table_line(line): |
393 | | - table_lines.append(line) |
394 | | - in_markdown_table = True |
395 | | - if is_md_separator_line(line): |
396 | | - found_separator = True |
397 | | - else: |
398 | | - if in_markdown_table: |
399 | | - save_table() |
400 | | - table_lines = [] |
401 | | - in_markdown_table = False |
402 | | - found_separator = False |
403 | | - |
404 | | - # 处理文档末尾的 Markdown 表格 |
405 | | - if in_markdown_table: |
406 | | - save_table() |
407 | | - |
408 | | - # 提取剩余文本(移除所有已提取的内容片段) |
409 | | - clean_text = text |
410 | | - for segment in extracted_segments: |
411 | | - clean_text = clean_text.replace(segment, '', 1) |
412 | | - |
413 | | - # 清理多余的空行 |
414 | | - clean_text = re.sub(r'\n\s*\n', '\n\n', clean_text) |
415 | | - clean_text = clean_text.strip() |
416 | | - |
417 | 215 | return { |
418 | | - 'code': '\n'.join(code_parts), |
419 | | - 'formula': '\n'.join(formula_parts), |
420 | | - 'table': '\n'.join(table_parts), |
421 | | - 'text': text # 原始全部文本 |
| 216 | + 'code': code_content, |
| 217 | + 'formula': formula_content, |
| 218 | + 'table': table_content, |
| 219 | + 'text': text # 保留原始全部文本 |
422 | 220 | } |
423 | | - |
| 221 | + |
424 | 222 | def aggregate_results(self, results: List[MetricResult]) -> MetricResult: |
425 | 223 | """ |
426 | 224 | Aggregate multiple metric results. |
|
0 commit comments