1- import re
21import json
32import asyncio
43import argparse
87from openai import AsyncOpenAI
98from aiolimiter import AsyncLimiter
109
10+ from moudle .TextHelper import TextHelper
11+
# API settings
BATCH = 16
MODEL = "no"
TEMPERATURE = 0.50

# Task settings
TIMEOUT = 300
CHUNK_SIZE = 4

# Locks: asyncio lock guards shared state inside coroutines,
# threading lock guards file writes (see write()).
LOCK_ASYNCIO = asyncio.Lock()
LOCK_THREADING = threading.Lock()

# Concurrency limiter: at most BATCH requests in flight at once
SEMAPHORE = asyncio.Semaphore(BATCH)
def split(datas: list[str], size: int) -> list[list[str]]:
    """Partition `datas` into consecutive chunks of at most `size` items."""
    chunks = []
    for start in range(0, len(datas), size):
        chunks.append(datas[start:start + size])
    return chunks
3941
# Leniently load a JSON dict from an LLM reply
def safe_load_json_dict(json_str: str) -> dict:
    """Parse `json_str` into a dict, tolerating common LLM artifacts.

    Strips surrounding whitespace and Markdown code fences, tries
    `json.loads` first, then falls back to regex extraction of
    quoted "key": "value" pairs.
    """
    result = {}

    # Remove leading/trailing whitespace (spaces, tabs, newlines)
    json_str = json_str.strip()

    # Remove Markdown code-fence markers
    json_str = json_str.removeprefix("```json").removeprefix("```").strip()

    # Try strict parsing first
    try:
        result = json.loads(json_str)
    except Exception:
        pass

    # Otherwise fall back to regex extraction
    if len(result) == 0:
        for item in re.findall(r"['\"].+?['\"]\s*\:\s*['\"].+?['\"]\s*(?=[,}])", json_str, flags=re.IGNORECASE):
            # split(":", 1): only the FIRST colon separates key from value,
            # so values containing ":" (e.g. URLs, timestamps) stay intact.
            # (The original split(":") truncated such values.)
            key, value = item.split(":", 1)
            result[key.strip().strip("'\" ").strip()] = value.strip().strip("'\" ").strip()

    return result
63-
# Leniently load a JSON list from an LLM reply
def safe_load_json_list(json_str: str) -> list:
    """Parse `json_str` into a list, tolerating common LLM artifacts.

    Strips whitespace and Markdown code fences, tries `json.loads`
    first, then falls back to parsing each `{...}` fragment with
    safe_load_json_dict.
    """
    # Normalize: trim whitespace, then drop code-fence markers
    cleaned = json_str.strip()
    cleaned = cleaned.removeprefix("```json").removeprefix("```").strip()

    # Strict parse first
    parsed = []
    try:
        parsed = json.loads(cleaned)
    except Exception:
        pass

    # Fallback: pull out each {...} fragment and parse it individually
    if len(parsed) == 0:
        parsed = [safe_load_json_dict(fragment) for fragment in re.findall(r"\{.+?\}", cleaned, flags=re.IGNORECASE)]

    return parsed
# Write data to a file
def write(target: str, data: dict) -> None:
    """Serialize `data` to `target` as pretty-printed UTF-8 JSON."""
    # Hold the thread lock for the whole write so concurrent callers
    # never interleave output in the same file.
    with LOCK_THREADING, open(target, "w", encoding="utf-8") as handle:
        json.dump(data, handle, indent=4, ensure_ascii=False)
8647
# Issue one chat-completion request
async def request(prompt: str, content: str) -> tuple[dict, dict, Exception]:
    """Send `content` to the LLM under the system `prompt`.

    Returns (llm_request, llm_response, error). `error` is None on
    success; on failure the partially filled request/response dicts are
    returned alongside the exception so the caller can log them.
    (Fixed: the annotation now matches the actual return order.)
    """
    llm_request, llm_response, error = None, None, None
    try:
        messages = [
            {
                "role": "system",
                "content": prompt,
            },
            {
                "role": "user",
                "content": content,
            }
        ]

        llm_request = {
            "model": MODEL,
            "stream": False,
            "temperature": TEMPERATURE,
            "top_p": TOP_P,
            "max_tokens": 2048,
            "messages": messages,
        }

        # Fetch the reply
        completion = await OPENAICLIENT.chat.completions.create(**llm_request)

        # The OpenAI SDK returns a response object; to_dict() converts it
        # to a plain dict suitable for JSON logging.
        llm_response = completion.to_dict()
        result = TextHelper.safe_load_json_list(completion.choices[0].message.content.strip())
        if len(result) == 0:
            raise Exception("没有解析到有效 JSON 数据 ...")
    except Exception as e:
        error = e
    # NOTE(fix): the original returned from inside `finally`, which
    # silently swallows asyncio.CancelledError raised during the await;
    # returning here lets cancellation propagate normally.
    return llm_request, llm_response, error
86+
# On success
async def on_success(llm_request: dict, llm_response: dict, error: Exception, tasks: list[asyncio.Task], success: list[str], failure: list[str]) -> None:
    """Record a completed task under the asyncio lock and report progress."""
    async with LOCK_ASYNCIO:
        entry = {
            "request": llm_request,
            "response": llm_response,
        }
        success.append(entry)
        remaining = len(tasks) - len(success) - len(failure)
        print(f"成功 {len(success)} 个,失败 {len(failure)} 个,剩余 {remaining} 个任务 ...")
95+
# On failure
async def on_failure(llm_request: dict, llm_response: dict, error: Exception, tasks: list[asyncio.Task], success: list[str], failure: list[str]) -> None:
    """Record a failed task under the asyncio lock and report progress."""
    async with LOCK_ASYNCIO:
        entry = {
            "error": str(error),
            "request": llm_request,
            "response": llm_response,
        }
        failure.append(entry)
        remaining = len(tasks) - len(success) - len(failure)
        print(f"成功 {len(success)} 个,失败 {len(failure)} 个,剩余 {remaining} 个任务 ... {str(error)}")
105+
# Run one unit of work
async def start(target: str, prompt_llm_check: str, prompt_llm_recognize: str, lines: list[str], tasks: list[asyncio.Task], success: list[str], failure: list[str]) -> None:
    """Recognize entities in `lines`, verify them with a second LLM pass,
    record the outcome, and periodically flush progress logs to disk.

    Runs under both the semaphore and the rate limiter so the global
    concurrency/QPS budget is respected.
    """
    async with SEMAPHORE, ASYNCLIMITER:
        # Pass 1 - LLM recognition
        llm_request, llm_response, error = await request(
            prompt_llm_recognize,
            "\n".join(lines),
        )
        # Guard clause: stop early on failure (was `if error == None: pass`)
        if error is not None:
            await on_failure(llm_request, llm_response, error, tasks, success, failure)
            return

        # Build the payload for the check pass
        result = {
            "entities": TextHelper.safe_load_json_list(
                llm_response.get("choices")[0].get("message").get("content").strip()
            ),
            "sentences": "\n".join(lines),
        }

        # Pass 2 - LLM verification of the recognized entities
        llm_request, llm_response, error = await request(
            prompt_llm_check,
            json.dumps(result, indent=None, ensure_ascii=False),
        )
        if error is None:
            await on_success(llm_request, llm_response, error, tasks, success, failure)
        else:
            await on_failure(llm_request, llm_response, error, tasks, success, failure)

        # Periodically flush progress to disk (every 5 finished tasks).
        # `stem` is hoisted so the f-strings do not reuse `"` inside a
        # `"`-quoted f-string, which is only legal on Python 3.12+ (PEP 701).
        done = len(success) + len(failure)
        if done > 0 and done % 5 == 0:
            stem = target.replace(".txt", "")
            write(f"{stem}_failure.log", failure)
            write(f"{stem}_success.log", success)
146147
147148# 主函数
148149async def main (target : str ) -> None :
149- with open ("prompt/llm_ner.txt" , "r" , encoding = "utf-8" ) as reader :
150- prompt = reader .read ().strip ()
150+ with open ("prompt/llm_check.txt" , "r" , encoding = "utf-8" ) as reader :
151+ prompt_llm_check = reader .read ().strip ()
152+
153+ with open ("prompt/llm_recognize.txt" , "r" , encoding = "utf-8" ) as reader :
154+ prompt_llm_recognize = reader .read ().strip ()
151155
152156 with open (target , "r" , encoding = "utf-8" ) as reader :
153157 lines = [line .strip () for line in reader .readlines () if line .strip () != "" ]
@@ -161,16 +165,12 @@ async def main(target: str) -> None:
161165 # 执行并发任务
162166 tasks = []
163167 for lines in line_chunks :
164- tasks .append (asyncio .create_task (request ( lines , prompt , tasks , success , failure )))
168+ tasks .append (asyncio .create_task (start ( target , prompt_llm_check , prompt_llm_recognize , lines , tasks , success , failure )))
165169 await asyncio .gather (* tasks , return_exceptions = True )
166170
167- # 写入成功日志
168- with open (f"{ target .replace (".txt" , "" )} _success.log" , "w" , encoding = "utf-8" ) as writer :
169- writer .write (json .dumps (success , indent = 4 , ensure_ascii = False ))
170-
171- # 写入失败日志
172- with open (f"{ target .replace (".txt" , "" )} _failure.log" , "w" , encoding = "utf-8" ) as writer :
173- writer .write (json .dumps (failure , indent = 4 , ensure_ascii = False ))
171+ # 写入文件
172+ write (f"{ target .replace (".txt" , "" )} _failure.log" , failure )
173+ write (f"{ target .replace (".txt" , "" )} _success.log" , success )
174174
175175# 入口函数
176176if __name__ == "__main__" :
0 commit comments