Skip to content

Commit dfbd377

Browse files
committed
update
1 parent 02acf69 commit dfbd377

File tree

8 files changed

+212
-254
lines changed

8 files changed

+212
-254
lines changed

01_ner.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@
6161
# 数据
6262
EVAL_DATA = 4096
6363
DATASET_PATH = [
64-
("/mnt/e/ai/dataset/ner/zh", 2 * 10000 + EVAL_DATA / 4),
65-
("/mnt/e/ai/dataset/ner/en", 2 * 10000 + EVAL_DATA / 4),
66-
("/mnt/e/ai/dataset/ner/jp", 2 * 10000 + EVAL_DATA / 4),
67-
("/mnt/e/ai/dataset/ner/ko", 2 * 10000 + EVAL_DATA / 4),
64+
("/mnt/e/ai/dataset/ner/zh/20250102", 2 * 10000 + EVAL_DATA / 4),
65+
("/mnt/e/ai/dataset/ner/en/20250102", 2 * 10000 + EVAL_DATA / 4),
66+
("/mnt/e/ai/dataset/ner/ja/20250102", 2 * 10000 + EVAL_DATA / 4),
67+
("/mnt/e/ai/dataset/ner/ko/20250102", 2 * 10000 + EVAL_DATA / 4),
6868
]
6969

7070
# 加载模型
@@ -101,12 +101,12 @@ def sample(data: list[dict], limit: int) -> list[dict]:
101101
type_count = {}
102102
for item in data:
103103
for entity in item.get("entities", []):
104-
type_count[entity.get("entity_type")] = type_count.get(entity.get("entity_type"), 0) + 1
104+
type_count[entity.get("entity_group")] = type_count.get(entity.get("entity_group"), 0) + 1
105105
max_k = max(type_count, key = lambda k: type_count.get(k), default="")
106106

107107
# 拆分数据
108-
data_x = [item for item in data if any(entity.get("entity_type") != max_k for entity in item.get("entities", []))]
109-
data_y = [item for item in data if not any(entity.get("entity_type") != max_k for entity in item.get("entities", []))]
108+
data_x = [item for item in data if any(entity.get("entity_group") != max_k for entity in item.get("entities", []))]
109+
data_y = [item for item in data if not any(entity.get("entity_group") != max_k for entity in item.get("entities", []))]
110110

111111
# 随机取样
112112
if len(data_x) >= limit:
@@ -140,8 +140,8 @@ def load_dataset(tokenizer: PreTrainedTokenizerFast) -> tuple[Dataset, Dataset,
140140
types = set()
141141
for v in data:
142142
for entity in v.get("entities", []):
143-
if entity.get("entity_type") != "":
144-
types.add(entity.get("entity_type"))
143+
if entity.get("entity_group") != "":
144+
types.add(entity.get("entity_group"))
145145
id2label = {0: "O"}
146146
for c in list(sorted(types)):
147147
id2label[len(id2label)] = f"B-{c}"
@@ -211,14 +211,14 @@ def load_dataset_map_function(samples: dict, tokenizer: PreTrainedTokenizerFast,
211211
result = []
212212
for entity in entities:
213213
surface = entity.get("surface", "")
214-
entity_type = entity.get("entity_type", "")
214+
entity_group = entity.get("entity_group", "")
215215

216216
# 获取实体词语在字符串中的位置
217217
char_start = sentence.find(surface)
218218
char_end = char_start + len(surface)
219219

220220
# 有效性检查
221-
if char_start < 0 or surface == "" or entity_type == "":
221+
if char_start < 0 or surface == "" or entity_group == "":
222222
continue
223223

224224
# 通过字符位置反查 Token 位置
@@ -228,7 +228,7 @@ def load_dataset_map_function(samples: dict, tokenizer: PreTrainedTokenizerFast,
228228
if token_start == -1 or token_end == -1:
229229
continue
230230

231-
result.append((token_start, token_end, entity_type))
231+
result.append((token_start, token_end, entity_group))
232232

233233
# 生成 labels
234234
labels = [0 for _ in range(len(input_ids))]

92_corpus_ner.py

Lines changed: 118 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import re
21
import json
32
import asyncio
43
import argparse
@@ -8,6 +7,8 @@
87
from openai import AsyncOpenAI
98
from aiolimiter import AsyncLimiter
109

10+
from moudle.TextHelper import TextHelper
11+
1112
# 设置接口
1213
BATCH = 16
1314
MODEL = "no"
@@ -17,11 +18,12 @@
1718
TEMPERATURE = 0.50
1819

1920
# 设置任务参数
20-
TIMEOUT = 180
21-
CHUNK_SIZE = 10
21+
TIMEOUT = 300
22+
CHUNK_SIZE = 4
2223

23-
# 线程锁
24-
LOCK = threading.Lock()
24+
# 锁
25+
LOCK_ASYNCIO = asyncio.Lock()
26+
LOCK_THREADING = threading.Lock()
2527

2628
# 限制器
2729
SEMAPHORE = asyncio.Semaphore(BATCH)
@@ -37,117 +39,119 @@
3739
def split(datas: list[str], size: int) -> list[list[str]]:
3840
return [datas[i:i + size] for i in range(0, len(datas), size)]
3941

40-
# 安全加载 JSON 字典
41-
def safe_load_json_dict(json_str: str) -> dict:
42-
result = {}
43-
44-
# 移除首尾空白符(含空格、制表符、换行符)
45-
json_str = json_str.strip()
46-
47-
# 移除代码标识
48-
json_str = json_str.removeprefix("```json").removeprefix("```").strip()
49-
50-
# 先尝试使用 json.loads 解析
51-
try:
52-
result = json.loads(json_str)
53-
except Exception:
54-
pass
55-
56-
# 否则使用正则表达式匹配
57-
if len(result) == 0:
58-
for item in re.findall(r"['\"].+?['\"]\s*\:\s*['\"].+?['\"]\s*(?=[,}])", json_str, flags = re.IGNORECASE):
59-
p = item.split(":")
60-
result[p[0].strip().strip("'\"").strip()] = p[1].strip().strip("'\"").strip()
61-
62-
return result
63-
64-
# 安全加载 JSON 列表
65-
def safe_load_json_list(json_str: str) -> list:
66-
result = []
67-
68-
# 移除首尾空白符(含空格、制表符、换行符)
69-
json_str = json_str.strip()
70-
71-
# 移除代码标识
72-
json_str = json_str.removeprefix("```json").removeprefix("```").strip()
73-
74-
# 先尝试使用 json.loads 解析
75-
try:
76-
result = json.loads(json_str)
77-
except Exception:
78-
pass
79-
80-
# 否则使用正则表达式匹配
81-
if len(result) == 0:
82-
for item in re.findall(r"\{.+?\}", json_str, flags = re.IGNORECASE):
83-
result.append(safe_load_json_dict(item))
84-
85-
return result
42+
# 写入文件
43+
def write(target: str, data: dict) -> None:
44+
with LOCK_THREADING:
45+
with open(target, "w", encoding = "utf-8") as writer:
46+
writer.write(json.dumps(data, indent = 4, ensure_ascii = False))
8647

8748
# 发起请求
88-
async def request(lines: list[str], prompt: str, tasks: list[asyncio.Task], success: list[str], failure: list[str]) -> None:
89-
async with SEMAPHORE, ASYNCLIMITER:
90-
try:
91-
llm_request, llm_response, error = None, None, None
92-
93-
messages = [
94-
{
95-
"role": "system",
96-
"content": prompt,
97-
},
98-
{
99-
"role": "user",
100-
"content": "\n".join(lines),
101-
}
102-
]
103-
104-
llm_request = {
105-
"model": MODEL,
106-
"stream": False,
107-
"temperature": TEMPERATURE,
108-
"top_p": TOP_P,
109-
"max_tokens": 4096,
110-
# "frequency_penalty" : 0.2 if retry == True else 0,
111-
"messages": messages,
49+
async def request(prompt: str, content: str) -> tuple[Exception, dict, dict]:
50+
try:
51+
llm_request, llm_response, error = None, None, None
52+
53+
messages = [
54+
{
55+
"role": "system",
56+
"content": prompt,
57+
},
58+
{
59+
"role": "user",
60+
"content": content,
11261
}
113-
114-
completion = await OPENAICLIENT.chat.completions.create(**llm_request)
115-
116-
# OpenAI 的 API 返回的对象通常是 OpenAIObject 类型
117-
# 该类有一个内置方法可以将其转换为字典
118-
llm_response = completion.to_dict()
119-
usage = completion.usage
120-
content = completion.choices[0].message.content.strip()
121-
122-
# 检查是否超过最大 token 限制
123-
if usage.completion_tokens >= 4096:
124-
raise Exception("超过最大 token 限制")
125-
126-
json_list = safe_load_json_list(content)
127-
if len(json_list) == 0:
128-
raise Exception("无法解析 JSON 列表")
129-
except Exception as e:
130-
error = e
131-
finally:
132-
with LOCK:
133-
if error == None:
134-
success.append({
135-
"request": llm_request,
136-
"response": llm_response,
137-
})
138-
print(f"成功 {len(success)} 个,失败 {len(failure)} 个,剩余 {len(tasks) - len(success) - len(failure)} 个任务 ... ")
139-
else:
140-
failure.append({
141-
"error": str(error),
142-
"request": llm_request,
143-
"response": llm_response,
144-
})
145-
print(f"成功 {len(success)} 个,失败 {len(failure)} 个,剩余 {len(tasks) - len(success) - len(failure)} 个任务 ... {str(error)}")
62+
]
63+
64+
llm_request = {
65+
"model": MODEL,
66+
"stream": False,
67+
"temperature": TEMPERATURE,
68+
"top_p": TOP_P,
69+
"max_tokens": 2048,
70+
"messages": messages,
71+
}
72+
73+
# 获取回复
74+
completion = await OPENAICLIENT.chat.completions.create(**llm_request)
75+
76+
# OpenAI 的 API 返回的对象通常是 OpenAIObject 类型
77+
# 该类有一个内置方法可以将其转换为字典
78+
llm_response = completion.to_dict()
79+
result = TextHelper.safe_load_json_list(completion.choices[0].message.content.strip())
80+
if len(result) == 0:
81+
raise Exception("没有解析到有效 JSON 数据 ...")
82+
except Exception as e:
83+
error = e
84+
finally:
85+
return llm_request, llm_response, error
86+
87+
# 成功时
88+
async def on_success(llm_request: dict, llm_response: dict, error: Exception, tasks: list[asyncio.Task], success: list[str], failure: list[str]) -> None:
89+
async with LOCK_ASYNCIO:
90+
success.append({
91+
"request": llm_request,
92+
"response": llm_response,
93+
})
94+
print(f"成功 {len(success)} 个,失败 {len(failure)} 个,剩余 {len(tasks) - len(success) - len(failure)} 个任务 ...")
95+
96+
# 失败时
97+
async def on_failure(llm_request: dict, llm_response: dict, error: Exception, tasks: list[asyncio.Task], success: list[str], failure: list[str]) -> None:
98+
async with LOCK_ASYNCIO:
99+
failure.append({
100+
"error": str(error),
101+
"request": llm_request,
102+
"response": llm_response,
103+
})
104+
print(f"成功 {len(success)} 个,失败 {len(failure)} 个,剩余 {len(tasks) - len(success) - len(failure)} 个任务 ... {str(error)}")
105+
106+
# 执行任务
107+
async def start(target: str, prompt_llm_check: str, prompt_llm_recognize: str, lines: list[str], tasks: list[asyncio.Task], success: list[str], failure: list[str]) -> None:
108+
async with SEMAPHORE, ASYNCLIMITER:
109+
# 获取 LLM 识别结果
110+
error = None
111+
llm_request, llm_response, error = await request(
112+
prompt_llm_recognize,
113+
"\n".join(lines),
114+
)
115+
116+
if error == None:
117+
pass
118+
else:
119+
await on_failure(llm_request, llm_response, error, tasks, success, failure)
120+
return
121+
122+
# 数据处理
123+
result = {}
124+
result["entities"] = TextHelper.safe_load_json_list(llm_response.get("choices")[0].get("message").get("content").strip())
125+
result["sentences"] = "\n".join(lines)
126+
127+
# 获取 LLM 检查结果
128+
error = None
129+
llm_request, llm_response, error = await request(
130+
prompt_llm_check,
131+
json.dumps(
132+
result,
133+
indent = None,
134+
ensure_ascii = False,
135+
),
136+
)
137+
138+
if error == None:
139+
await on_success(llm_request, llm_response, error, tasks, success, failure)
140+
else:
141+
await on_failure(llm_request, llm_response, error, tasks, success, failure)
142+
143+
# 写入文件
144+
if len(success) + len(failure) > 0 and (len(success) + len(failure)) % 5 == 0:
145+
write(f"{target.replace(".txt", "")}_failure.log", failure)
146+
write(f"{target.replace(".txt", "")}_success.log", success)
146147

147148
# 主函数
148149
async def main(target: str) -> None:
149-
with open("prompt/llm_ner.txt", "r", encoding = "utf-8") as reader:
150-
prompt = reader.read().strip()
150+
with open("prompt/llm_check.txt", "r", encoding = "utf-8") as reader:
151+
prompt_llm_check = reader.read().strip()
152+
153+
with open("prompt/llm_recognize.txt", "r", encoding = "utf-8") as reader:
154+
prompt_llm_recognize = reader.read().strip()
151155

152156
with open(target, "r", encoding = "utf-8") as reader:
153157
lines = [line.strip() for line in reader.readlines() if line.strip() != ""]
@@ -161,16 +165,12 @@ async def main(target: str) -> None:
161165
# 执行并发任务
162166
tasks = []
163167
for lines in line_chunks:
164-
tasks.append(asyncio.create_task(request(lines, prompt, tasks, success, failure)))
168+
tasks.append(asyncio.create_task(start(target, prompt_llm_check, prompt_llm_recognize, lines, tasks, success, failure)))
165169
await asyncio.gather(*tasks, return_exceptions = True)
166170

167-
# 写入成功日志
168-
with open(f"{target.replace(".txt", "")}_success.log", "w", encoding = "utf-8") as writer:
169-
writer.write(json.dumps(success, indent = 4, ensure_ascii = False))
170-
171-
# 写入失败日志
172-
with open(f"{target.replace(".txt", "")}_failure.log", "w", encoding = "utf-8") as writer:
173-
writer.write(json.dumps(failure, indent = 4, ensure_ascii = False))
171+
# 写入文件
172+
write(f"{target.replace(".txt", "")}_failure.log", failure)
173+
write(f"{target.replace(".txt", "")}_success.log", success)
174174

175175
# 入口函数
176176
if __name__ == "__main__":

0 commit comments

Comments
 (0)