Commit 00a4b62

committed
Restore the caching mechanism
1 parent 13a919b commit 00a4b62

File tree

webmainbench/metrics/extractors/base_content_splitter.py
webmainbench/metrics/extractors/table_extractor.py

2 files changed: +180, -2 lines changed
webmainbench/metrics/extractors/base_content_splitter.py
Lines changed: 102 additions & 1 deletion
@@ -1 +1,102 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any
import os
import hashlib
import json
from openai import OpenAI


class BaseContentSplitter(ABC):
    """Abstract base class for extracting a specific type of content from text."""

    # Default LLM prompt template
    DEFAULT_LLM_PROMPT = """Please process the following content:
{content}
"""

    def __init__(self, config: Dict[str, Any] = None):
        """Initialize the extractor."""
        self.config = config or {}

        # Controls whether the LLM is used at all
        self.use_llm = self.config.get('use_llm', True)

        # Initialize the OpenAI client (if an LLM endpoint is configured)
        if self.use_llm and self.config.get('llm_base_url') and self.config.get('llm_api_key'):
            self.client = OpenAI(
                base_url=self.config.get('llm_base_url', ""),
                api_key=self.config.get('llm_api_key', "")
            )
        else:
            self.client = None

        self.cache_dir = self.config.get(
            'cache_dir',
            os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), '.cache')
        )
        os.makedirs(self.cache_dir, exist_ok=True)

    @abstractmethod
    def extract(self, text: str, field_name: str = None) -> str:
        """Extract a specific type of content."""
        pass

    @abstractmethod
    def extract_basic(self, text: str) -> List[str]:
        """Extract content with a basic method (typically regular expressions)."""
        pass

    def should_use_llm(self, field_name: str) -> bool:
        """Decide whether LLM-based enhancement should be applied."""
        if not self.use_llm:
            return False

        # Default policy: skip the LLM for groundtruth content, use it for everything else
        if field_name == "groundtruth_content":
            print("[DEBUG] groundtruth content detected, skipping LLM")
            return False
        return True

    def enhance_with_llm(self, basic_results: List[str], cache_key: str = None) -> List[str]:
        """Enhance the basic extraction results with the LLM."""
        if not basic_results:
            print("[DEBUG] input is empty, skipping LLM enhancement")
            return []

        # Build the cache key
        if cache_key is None:
            content_str = '\n'.join(basic_results)
            cache_key = hashlib.md5(content_str.encode('utf-8')).hexdigest()

        cache_file = os.path.join(self.cache_dir, f'{self.__class__.__name__.lower()}_cache_{cache_key}.json')

        # Check the cache
        if os.path.exists(cache_file):
            try:
                with open(cache_file, 'r', encoding='utf-8') as f:
                    cached_result = json.load(f)
                print(f"[DEBUG] loaded LLM-enhanced results from cache: {len(cached_result)} items")
                return cached_result
            except Exception as e:
                print(f"[DEBUG] failed to read cache: {e}")

        # Actual LLM enhancement
        try:
            enhanced_results = self._llm_enhance(basic_results)

            # Save to cache
            try:
                with open(cache_file, 'w', encoding='utf-8') as f:
                    json.dump(enhanced_results, f, ensure_ascii=False, indent=2)
                print(f"[DEBUG] cached LLM-enhanced results to: {cache_file}")
            except Exception as e:
                print(f"[DEBUG] failed to save cache: {e}")

            return enhanced_results
        except Exception as e:
            print(f"[DEBUG] LLM enhancement failed: {type(e).__name__}: {e}")
            return basic_results

    @abstractmethod
    def _llm_enhance(self, basic_results: List[str]) -> List[str]:
        """Concrete implementation of the LLM enhancement step."""
        pass
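A minimal sketch (not part of this commit) of how a concrete splitter would exercise the restored cache. The class name DemoSplitter, the $$...$$ regex, and the import path webmainbench.metrics.extractors.base_content_splitter are assumptions for illustration; the path is inferred from the relative import in table_extractor.py, and only the enhance_with_llm caching behaviour comes from the code above.

import re
from typing import List

# Assumed import path, inferred from `from .base_content_splitter import ...` in table_extractor.py.
from webmainbench.metrics.extractors.base_content_splitter import BaseContentSplitter


class DemoSplitter(BaseContentSplitter):
    """Hypothetical splitter used only to exercise the caching path."""

    def extract(self, text: str, field_name: str = None) -> str:
        parts = self.extract_basic(text)
        if self.should_use_llm(field_name):
            # First call computes an MD5 key over the joined parts, runs _llm_enhance,
            # and writes <classname>_cache_<key>.json under cache_dir; a later call with
            # the same input returns the cached JSON instead.
            parts = self.enhance_with_llm(parts)
        return '\n'.join(parts)

    def extract_basic(self, text: str) -> List[str]:
        # Toy basic extractor: grab $$...$$ blocks.
        return re.findall(r"\$\$.*?\$\$", text, flags=re.DOTALL)

    def _llm_enhance(self, basic_results: List[str]) -> List[str]:
        # A real implementation would call self.client here; this sketch passes results through.
        return basic_results


splitter = DemoSplitter(config={"use_llm": True})   # no llm_base_url/llm_api_key, so self.client stays None
print(splitter.extract("see $$E = mc^2$$ inline"))  # a second identical call would hit the cache file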
webmainbench/metrics/extractors/table_extractor.py
Lines changed: 78 additions & 1 deletion
@@ -1 +1,78 @@
# webmainbench/metrics/extractors/table_extractor.py
import re
from bs4 import BeautifulSoup
from typing import List, Dict, Any

from .base_content_splitter import BaseContentSplitter


class TableSplitter(BaseContentSplitter):
    """Extract tables from text."""

    def extract(self, text: str, field_name: str = None) -> str:
        """Extract tables."""
        tables = self.extract_basic(text)

        if self.should_use_llm(field_name):
            table_parts = self.enhance_with_llm(tables)
        else:
            table_parts = tables

        return '\n'.join(table_parts)

    def extract_basic(self, text: str) -> List[str]:
        """Basic table extraction."""
        table_parts = []

        # HTML table extraction
        soup = BeautifulSoup(text, "html.parser")
        for table in soup.find_all("table"):
            if not table.find_parent(["td", "tr", "tbody", "table"]):
                table_parts.append(str(table))

        # Markdown table extraction
        lines = text.split('\n')
        table_lines = []
        in_markdown_table = False

        def is_md_table_line(line):
            """Return True if the line could belong to a Markdown table."""
            if line.count("|") < 1:
                return False
            return True

        def is_md_separator_line(line):
            """Return True if the line is a Markdown header separator row."""
            parts = [p.strip() for p in line.split("|")]
            for p in parts:
                if p and not re.match(r"^:?\-{3,}:?$", p):
                    return False
            return True

        def save_table():
            """Save the current table if it is valid (the caller resets the buffer)."""
            nonlocal table_lines
            if len(table_lines) >= 2 and is_md_separator_line(table_lines[1]):
                md_table = '\n'.join(table_lines)
                table_parts.append(md_table)

        for line in lines:
            if is_md_table_line(line):
                table_lines.append(line)
                in_markdown_table = True
            else:
                if in_markdown_table:
                    save_table()
                table_lines = []
                in_markdown_table = False

        # Handle a Markdown table at the end of the document
        if in_markdown_table:
            save_table()

        return table_parts

    def _llm_enhance(self, basic_results: List[str]) -> List[str]:
        """LLM-based table enhancement (not implemented yet)."""
        print("[DEBUG] LLM enhancement for tables is not implemented yet; returning the basic results")
        return basic_results
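A short usage sketch for TableSplitter. The import path is taken from the file header comment above; the sample text is hypothetical, and use_llm is disabled so extract returns the basic results without touching the LLM cache.

from webmainbench.metrics.extractors.table_extractor import TableSplitter

sample = """Some intro text.

| name | score |
| ---- | ----- |
| foo  | 0.91  |

<table><tr><td>bar</td></tr></table>
"""

# With use_llm disabled, should_use_llm() returns False and the raw extract_basic output is joined.
splitter = TableSplitter(config={"use_llm": False})
print(splitter.extract(sample))
# Expected: the <table> element first (HTML pass), then the Markdown table (line-scanning pass).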
