from abc import ABC, abstractmethod
from typing import List, Dict, Any
import os
import hashlib
import json
from openai import OpenAI


class BaseContentSplitter(ABC):
    """Abstract base for pulling one particular kind of content out of text."""

    # Prompt template used by the LLM step (runtime string; left untranslated).
    DEFAULT_LLM_PROMPT = """请处理以下内容:
{content}
"""

    def __init__(self, config: Dict[str, Any] = None):
        """Store configuration, optionally build an OpenAI client, prepare the cache dir."""
        self.config = config or {}

        # Single switch that gates all LLM usage.
        self.use_llm = self.config.get('use_llm', True)

        # The client exists only when LLM usage is enabled and both endpoint
        # settings are configured; otherwise it stays None.
        endpoint_ready = self.config.get('llm_base_url') and self.config.get('llm_api_key')
        if self.use_llm and endpoint_ready:
            self.client = OpenAI(
                base_url=self.config.get('llm_base_url', ""),
                api_key=self.config.get('llm_api_key', ""))
        else:
            self.client = None

        fallback_cache = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))), '.cache')
        self.cache_dir = self.config.get('cache_dir', fallback_cache)
        os.makedirs(self.cache_dir, exist_ok=True)

    @abstractmethod
    def extract(self, text: str, field_name: str = None) -> str:
        """Extract this splitter's content type from the given text."""
        pass

    @abstractmethod
    def extract_basic(self, text: str) -> List[str]:
        """Baseline extraction (typically regex-driven)."""
        pass

    def should_use_llm(self, field_name: str) -> bool:
        """Report whether LLM enhancement should run for this field."""
        if not self.use_llm:
            return False
        if field_name == "groundtruth_content":
            # Groundtruth content is deliberately kept away from the LLM.
            print(f"[DEBUG] 检测到groundtruth内容,不使用LLM")
            return False
        return True

    def enhance_with_llm(self, basic_results: List[str], cache_key: str = None) -> List[str]:
        """Enhance baseline results through the LLM, backed by a JSON disk cache."""
        if not basic_results:
            print(f"[DEBUG] 输入内容为空,跳过LLM增强")
            return []

        # Fingerprint the input when no explicit key was supplied.
        if cache_key is None:
            joined = '\n'.join(basic_results)
            cache_key = hashlib.md5(joined.encode('utf-8')).hexdigest()

        cache_file = os.path.join(
            self.cache_dir,
            f'{self.__class__.__name__.lower()}_cache_{cache_key}.json')

        # Reuse a previously stored result when one exists.
        if os.path.exists(cache_file):
            try:
                with open(cache_file, 'r', encoding='utf-8') as f:
                    cached = json.load(f)
                print(f"[DEBUG] 从缓存加载LLM增强结果: {len(cached)} 个")
                return cached
            except Exception as e:
                print(f"[DEBUG] 缓存读取失败: {e}")

        try:
            enhanced = self._llm_enhance(basic_results)
            try:
                with open(cache_file, 'w', encoding='utf-8') as f:
                    json.dump(enhanced, f, ensure_ascii=False, indent=2)
                print(f"[DEBUG] LLM增强结果已缓存到: {cache_file}")
            except Exception as e:
                print(f"[DEBUG] 缓存保存失败: {e}")
            return enhanced
        except Exception as e:
            print(f"[DEBUG] LLM增强失败: {type(e).__name__}: {e}")
            return basic_results

    @abstractmethod
    def _llm_enhance(self, basic_results: List[str]) -> List[str]:
        """Concrete LLM enhancement, implemented by subclasses."""
        pass
from abc import ABC, abstractmethod
import hashlib
import json
import os
from typing import Any, Dict, List, Optional

from openai import OpenAI
7+
8+
class BaseContentSplitter(ABC):
    """Abstract base class for extracting one specific kind of content from text.

    Subclasses provide a cheap baseline extraction (`extract_basic`, typically
    regex-based) plus an LLM-powered refinement (`_llm_enhance`). The public
    `enhance_with_llm` wraps the refinement with an on-disk JSON cache and
    falls back to the baseline results on any failure.
    """

    # Default prompt template for the LLM step (runtime string consumed by
    # subclasses via ``.format(content=...)``; intentionally left untranslated).
    DEFAULT_LLM_PROMPT = """请处理以下内容:
{content}
"""

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Initialize the extractor.

        Args:
            config: Optional settings. Recognized keys:
                - ``use_llm`` (bool): master switch for LLM enhancement
                  (default ``True``).
                - ``llm_base_url`` / ``llm_api_key`` (str): OpenAI-compatible
                  endpoint; the client is created only when both are set.
                - ``cache_dir`` (str): cache directory; defaults to a
                  ``.cache`` folder one level above this file's directory.
        """
        self.config = config or {}

        # Master switch controlling whether LLM enhancement may run at all.
        self.use_llm = self.config.get('use_llm', True)

        # Build the OpenAI client only when LLM use is enabled AND both
        # endpoint settings are present; otherwise leave it None so callers
        # can test `self.client` before use.
        if self.use_llm and self.config.get('llm_base_url') and self.config.get('llm_api_key'):
            self.client = OpenAI(
                base_url=self.config['llm_base_url'],
                api_key=self.config['llm_api_key'],
            )
        else:
            self.client = None

        self.cache_dir = self.config.get(
            'cache_dir',
            os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), '.cache'),
        )
        os.makedirs(self.cache_dir, exist_ok=True)

    @abstractmethod
    def extract(self, text: str, field_name: Optional[str] = None) -> str:
        """Extract the specific kind of content from ``text``."""

    @abstractmethod
    def extract_basic(self, text: str) -> List[str]:
        """Extract content with the cheap baseline method (usually regex)."""

    def should_use_llm(self, field_name: str) -> bool:
        """Return True when LLM enhancement should be applied to ``field_name``.

        Groundtruth content is always excluded so reference answers are never
        rewritten by the model.
        """
        if not self.use_llm:
            return False
        if field_name == "groundtruth_content":
            print("[DEBUG] 检测到groundtruth内容,不使用LLM")
            return False
        return True

    def enhance_with_llm(self, basic_results: List[str], cache_key: Optional[str] = None) -> List[str]:
        """Enhance baseline extraction results via the LLM, with disk caching.

        Args:
            basic_results: Strings produced by ``extract_basic``.
            cache_key: Explicit cache key; when omitted, an MD5 digest of the
                joined input is used so identical inputs hit the same entry.

        Returns:
            The enhanced list; ``basic_results`` unchanged if enhancement
            fails (best-effort semantics); ``[]`` for empty input.
        """
        if not basic_results:
            print("[DEBUG] 输入内容为空,跳过LLM增强")
            return []

        # MD5 is used purely as a cache fingerprint, not for security.
        if cache_key is None:
            content_str = '\n'.join(basic_results)
            cache_key = hashlib.md5(content_str.encode('utf-8')).hexdigest()

        cache_file = os.path.join(
            self.cache_dir,
            f'{self.__class__.__name__.lower()}_cache_{cache_key}.json',
        )

        # Serve from cache when possible; a corrupt or non-list cache entry
        # is ignored and recomputed instead of being returned or raising.
        if os.path.exists(cache_file):
            try:
                with open(cache_file, 'r', encoding='utf-8') as f:
                    cached_result = json.load(f)
                if isinstance(cached_result, list):
                    print(f"[DEBUG] 从缓存加载LLM增强结果: {len(cached_result)} 个")
                    return cached_result
            except Exception as e:
                print(f"[DEBUG] 缓存读取失败: {e}")

        try:
            enhanced_results = self._llm_enhance(basic_results)

            # Cache-write failures are non-fatal; the result is still returned.
            try:
                with open(cache_file, 'w', encoding='utf-8') as f:
                    json.dump(enhanced_results, f, ensure_ascii=False, indent=2)
                print(f"[DEBUG] LLM增强结果已缓存到: {cache_file}")
            except Exception as e:
                print(f"[DEBUG] 缓存保存失败: {e}")

            return enhanced_results
        except Exception as e:
            # Fall back to the unenhanced baseline rather than propagating.
            print(f"[DEBUG] LLM增强失败: {type(e).__name__}: {e}")
            return basic_results

    @abstractmethod
    def _llm_enhance(self, basic_results: List[str]) -> List[str]:
        """Subclass hook: the actual LLM enhancement implementation."""
0 commit comments