Skip to content

Commit 186fd7b

Browse files
committed
update
1 parent 8ba4665 commit 186fd7b

17 files changed

+595
-203
lines changed

00.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,15 @@
3232
from model.NERTrainerCallback import NERTrainerCallback
3333

3434
# 参数设置
35-
MODEL_NAME = "facebookai_xlm_roberta_base_pretrain_20240823"
35+
MODEL_NAME = "facebookai_xlm_roberta_base_pretrain_20240826"
3636
MODEL_PATH = f"assets/{MODEL_NAME}"
3737
OUTPUT_PATH = "output"
3838
DATASET_PATH = "dataset/ner"
3939
EPOCHS = 24
4040
PATIENCE = 12
4141
PATIENCE_KEEPER = 0
4242
BATCH_SIZE = 32
43-
GRADIENT_ACCUMULATION_SIZE = 64
43+
GRADIENT_ACCUMULATION_SIZE = 32
4444
FROZEN_LAYER = 0
4545
LEARNING_RATE = 2 * 1e-5
4646
DO_LOWER_CASE = False
@@ -65,13 +65,13 @@ def load_dataset(tokenizer):
6565
if file.name.endswith(".json"):
6666
with open(file.path, "r", encoding = "utf-8") as file:
6767
count = count + 1
68-
datas.extend(json.load(file))
68+
datas.extend(random.sample(json.load(file), 10000))
6969

7070
print(f"")
7171
print(f"找到数据文件 {count} 个,共 {len(datas)} 条数据 ...")
7272

7373
# 分割数据集
74-
train_datas, test_datas = train_test_split(datas, test_size = 0.02, shuffle = True, random_state = 42)
74+
train_datas, test_datas = train_test_split(datas, test_size = 0.025, shuffle = True, random_state = 42)
7575

7676
# 创建数据集和数据加载器
7777
print(f"")

01.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@
4141
("dataset/pretrain/en_r18_visual_novels", 20 * 10000),
4242
("dataset/pretrain/zh", 20 * 10000),
4343
("dataset/pretrain/zh_r18_pixiv", 20 * 10000),
44-
("dataset/pretrain/jp", 17 * 10000),
45-
("dataset/pretrain/jp_r18", 17 * 10000),
46-
("dataset/pretrain/jp_r18_rpgmaker", 6 * 10000),
44+
("dataset/pretrain/jp", 15 * 10000),
45+
("dataset/pretrain/jp_r18", 15 * 10000),
46+
("dataset/pretrain/jp_r18_rpgmaker", 10 * 10000),
47+
("dataset/pretrain/kr", 40 * 10000),
4748
]
4849

4950
# 加载分词器
@@ -175,11 +176,9 @@ def load_dataset(tokenizer):
175176
count = count + 1
176177
with open(f"{dir_path}/{MODEL_NAME}_{dir_name}.txt", "r", encoding = "utf-8") as file:
177178
datas_by_type = [line.strip() for line in tqdm(file, desc = path, total = num)]
178-
random.shuffle(datas_by_type)
179179
else:
180-
total = len([entry for entry in os.listdir(path) if os.path.isfile(os.path.join(path, entry))])
181-
182180
lines = []
181+
total = len([f for f in os.scandir(path) if f.name.endswith(".txt")])
183182
for file in tqdm(os.scandir(path), desc = path, total = total):
184183
if file.name.endswith(".txt"):
185184
with open(file.path, "r", encoding = "utf-8") as file:
@@ -201,6 +200,7 @@ def load_dataset(tokenizer):
201200
datas.extend(datas_by_type)
202201

203202
# 生成数据集
203+
random.shuffle(datas)
204204
os.makedirs("cache", exist_ok = True)
205205
dataset_train = Dataset.from_dict({"line": datas})
206206
dataset_train_tokenized = dataset_train.map(

02.py

Lines changed: 43 additions & 184 deletions
Original file line numberDiff line numberDiff line change
@@ -6,185 +6,34 @@
66

77
from datetime import datetime
88

9-
import cohere
109
from rich import print
1110
from openai import AsyncOpenAI
1211
from aiolimiter import AsyncLimiter
1312

14-
PROMPT_JP = (
15-
"""
16-
请生成用于实体识别模型训练的日文合成语料,并检查其质量。
17-
生成时,请遵循以下内容要求、实体类别和质量标准:
13+
# 设置任务参数
14+
BATCH = 8
15+
TIMEOUT = 180
16+
MAX_LOOP = 32
17+
LOOP_SIZE = 128
18+
19+
MODEL = "glm-4-9b-chat"
20+
API_KEY = "sk-no-key-required"
21+
BASE_URL = "http://localhost:8080/v1"
22+
TOP_P = 0.95
23+
TEMPERATURE = 0.95
24+
PRESENCE_PENALTY = 0.95
25+
FREQUENCY_PENALTY = 0.00
26+
27+
PROMPT = {}
28+
for f in os.scandir(f"prompt"):
29+
if f.name.startswith("llm_corpus"):
30+
with open(f"prompt/{f.name}", "r", encoding = "utf-8") as file:
31+
PROMPT[f.name.replace(".txt", "")] = file.read().strip()
32+
33+
LANGUAGE = "kr"
34+
OUTPUT_PATH = f"dataset/ner/{LANGUAGE}"
35+
TARGET_PROMPT = PROMPT.get(f"llm_corpus_{LANGUAGE}")
1836

19-
内容要求:
20-
1、生成语句:
21-
生成10个语句,每个语句包含2-4个不同类别的实体。
22-
每个类别的实体在每个语句中最多出现一次,以确保多样性。
23-
24-
2、实体使用:
25-
每个语句中的实体词语之间不要相互包含。
26-
实体词语应为日文片假名或平假名形式,避免使用英文或汉字。
27-
28-
3、符号使用:
29-
除语法上必要的情况外,避免使用《》、「」、『』等符号包裹实体词语。
30-
31-
4、多样性与独特性:
32-
语句应展现多样性,避免重复或相似的语句结构和实体。
33-
使用随机性和不同的句子模板来增加多样性。
34-
35-
5、语句类型:
36-
语句类型应包括但不限于旁白、对话、场景描述、第一人称视角、第三人称视角等。
37-
38-
6、题材涵盖:
39-
语句题材应涵盖异世界、转生、穿越、奇幻、冒险、战争、科幻、历史、战国、中华风、中世纪、超能力、校园恋爱、运动竞技等轻小说常见题材。
40-
41-
实体类别:
42-
1、人名(PER):包括个体的人名,常见的名字、昵称、艺名、历史人物名字等,不包括代指人的称谓、头衔、职业和代词等。
43-
2、组织与团体(ORG):包括公司、机构、政府组织、非政府组织、学校、家族、门派等组织与团体。
44-
3、地点与设施(LOC):包括国家、城市、州、省、街道、自然地理实体(如河流、山脉)等地点或建筑物、地标、机场、桥梁、剧院、体育场等设施。
45-
4、产品与道具(PRD):包括物品、道具、商品、品牌、技术产品等。
46-
5、事件(EVT):包括历史事件、会议、发布会、庆典、比赛等。
47-
48-
49-
质量标准:
50-
1、生成的句子应具备高语言质量,确保流畅且自然。
51-
2、各类实体在句子中的分布应合理,避免单一类型实体的过多重复。
52-
3、确保生成的句子在日文语境中具有逻辑性和可读性,避免语法错误或不自然的表达。
53-
4、验证实体的使用是否符合其定义,确保它们在上下文中扮演合理的角色。
54-
55-
回复使用JSON格式,回复中仅需要以下数据,不要出现其他文字或者描述:
56-
[
57-
{
58-
"sentence": "<日文语句>",
59-
"entities": [
60-
{"name": "<实体名称>", "ner_type": "<PER/ORG/LOC/INS/PRD/EVT>"},
61-
{"name": "<实体名称>", "ner_type": "<PER/ORG/LOC/INS/PRD/EVT>"}
62-
]
63-
}
64-
]
65-
"""
66-
)
67-
68-
PROMPT_CN = (
69-
"""
70-
请生成用于实体识别模型训练的中文合成语料,并检查其质量。
71-
生成时,请遵循以下内容要求、实体类别和质量标准:
72-
73-
内容要求:
74-
1、生成语句:
75-
生成10个语句,每个语句包含2-4个不同类别的实体。
76-
每个类别的实体在每个语句中最多出现一次,以确保多样性。
77-
78-
2、实体使用:
79-
每个语句中的实体词语之间不要相互包含。
80-
81-
3、符号使用:
82-
除语法上必要的情况外,避免使用《》、「」、『』等符号包裹实体词语。
83-
84-
4、多样性与独特性:
85-
语句应展现多样性,避免重复或相似的语句结构和实体。
86-
使用随机性和不同的句子模板来增加多样性。
87-
88-
5、语句类型:
89-
语句类型应包括但不限于旁白、对话、场景描述、第一人称视角、第三人称视角等。
90-
91-
6、题材涵盖:
92-
语句题材应涵盖异世界、转生、穿越、奇幻、冒险、战争、科幻、历史、中世纪、超能力、校园恋爱、运动竞技等轻小说常见题材。
93-
94-
实体类别:
95-
1、人名(PER):包括个体的人名,常见的名字、昵称、艺名、历史人物名字等,不包括代指人的称谓、头衔、职业和代词等。
96-
2、组织与团体(ORG):包括公司、机构、政府组织、非政府组织、学校、家族、门派等组织与团体。
97-
3、地点与设施(LOC):包括国家、城市、州、省、街道、自然地理实体(如河流、山脉)等地点或建筑物、地标、机场、桥梁、剧院、体育场等设施。
98-
4、产品与道具(PRD):包括物品、道具、商品、品牌、技术产品等。
99-
5、事件(EVT):包括历史事件、会议、发布会、庆典、比赛等。
100-
101-
质量标准:
102-
1、生成的句子应具备高语言质量,确保流畅且自然。
103-
2、各类实体在句子中的分布应合理,避免单一类型实体的过多重复。
104-
3、确保生成的句子应具有逻辑性和可读性,避免语法错误或不自然的表达。
105-
4、验证实体的使用是否符合其定义,确保它们在上下文中扮演合理的角色。
106-
107-
回复使用JSON格式,回复中仅需要以下数据,不要出现其他文字或者描述:
108-
[
109-
{
110-
"sentence": "<中文句子>",
111-
"entities": [
112-
{"name": "<实体名称>", "ner_type": "<PER/ORG/LOC/PRD/EVT>"},
113-
{"name": "<实体名称>", "ner_type": "<PER/ORG/LOC/PRD/EVT>"}
114-
]
115-
}
116-
]
117-
"""
118-
)
119-
120-
PROMPT_EN = (
121-
"""
122-
请生成用于实体识别模型训练的英文合成语料,并检查其质量。
123-
生成时,请遵循以下内容要求、实体类别和质量标准:
124-
125-
内容要求:
126-
1、生成语句:
127-
生成10个语句,每个语句包含2-4个不同类别的实体。
128-
每个类别的实体在每个语句中最多出现一次,以确保多样性。
129-
130-
2、实体使用:
131-
每个语句中的实体词语之间不要相互包含。
132-
133-
3、符号使用:
134-
除语法上必要的情况外,避免使用《》、「」、『』等符号包裹实体词语。
135-
136-
4、多样性与独特性:
137-
语句应展现多样性,避免重复或相似的语句结构和实体。
138-
使用随机性和不同的句子模板来增加多样性。
139-
140-
5、语句类型:
141-
语句类型应包括但不限于旁白、对话、场景描述、第一人称视角、第三人称视角等。
142-
143-
6、题材涵盖:
144-
语句题材应涵盖异世界、转生、穿越、奇幻、冒险、战争、科幻、历史、中世纪、超能力、校园恋爱、运动竞技等轻小说常见题材。
145-
146-
实体类别:
147-
1、人名(PER):包括个体的人名,常见的名字、昵称、艺名、历史人物名字等,不包括代指人的称谓、头衔、职业和代词等。
148-
2、组织与团体(ORG):包括公司、机构、政府组织、非政府组织、学校、家族、门派等组织与团体。
149-
3、地点与设施(LOC):包括国家、城市、州、省、街道、自然地理实体(如河流、山脉)等地点或建筑物、地标、机场、桥梁、剧院、体育场等设施。
150-
4、产品与道具(PRD):包括物品、道具、商品、品牌、技术产品等。
151-
5、事件(EVT):包括历史事件、会议、发布会、庆典、比赛等。
152-
153-
质量标准:
154-
1、生成的句子应具备高语言质量,确保流畅且自然。
155-
2、各类实体在句子中的分布应合理,避免单一类型实体的过多重复。
156-
3、确保生成的句子应具有逻辑性和可读性,避免语法错误或不自然的表达。
157-
4、验证实体的使用是否符合其定义,确保它们在上下文中扮演合理的角色。
158-
159-
回复使用JSON格式,回复中仅需要以下数据,不要出现其他文字或者描述:
160-
[
161-
{
162-
"sentence": "<英文句子>",
163-
"entities": [
164-
{"name": "<实体名称>", "ner_type": "<PER/ORG/LOC/PRD/EVT>"},
165-
{"name": "<实体名称>", "ner_type": "<PER/ORG/LOC/PRD/EVT>"}
166-
]
167-
}
168-
]
169-
"""
170-
)
171-
172-
with open("02.json", "r", encoding = "utf-8") as f:
173-
data = json.load(f)
174-
MODEL = data.get("MODEL", "glm-4-9b-chat")
175-
API_KEY = data.get("API_KEY", "sk-no-key-required")
176-
BASE_URL = data.get("BASE_URL", "http://localhost:8080/v1")
177-
178-
BATCH = 32
179-
TIMEOUT = 120
180-
MAX_LOOP = 3
181-
LOOP_SIZE = 256
182-
TEMPERATURE = 1.25
183-
PRESENCE_PENALTY = 1.0
184-
185-
PROMPT = PROMPT_CN
186-
187-
names = set()
18837
semaphore = asyncio.Semaphore(BATCH)
18938
async_limiter = AsyncLimiter(max_rate = BATCH, time_period = 1)
19039
openai_handler = AsyncOpenAI(
@@ -223,19 +72,20 @@ def fix_broken_json_string(jsonstring):
22372

22473
return jsonstring
22574

75+
# 异步请求
22676
async def request():
22777
async with semaphore, async_limiter:
22878
completion = await openai_handler.chat.completions.create(
22979
model = MODEL,
23080
temperature = TEMPERATURE,
231-
top_p = 0.5,
232-
# max_tokens = 4096,
81+
top_p = TOP_P,
82+
max_tokens = 3 * 1024,
23383
presence_penalty = PRESENCE_PENALTY,
234-
frequency_penalty = 0,
84+
frequency_penalty = FREQUENCY_PENALTY,
23585
messages = [
23686
{
23787
"role": "user",
238-
"content": PROMPT.replace("{words}", ",".join(names))
88+
"content": TARGET_PROMPT
23989
},
24090
],
24191
)
@@ -252,13 +102,21 @@ async def request():
252102
print(message.content.strip())
253103
raise e
254104

255-
for v1 in result:
256-
print(v1)
257-
for v2 in v1.get("entities", []):
258-
names.add(v2["name"])
105+
for v in result:
106+
v["sentence"] = v["sentence"].replace("(PER)", "")
107+
v["sentence"] = v["sentence"].replace("(ORG)", "")
108+
v["sentence"] = v["sentence"].replace("(LOC)", "")
109+
v["sentence"] = v["sentence"].replace("(PRD)", "")
110+
v["sentence"] = v["sentence"].replace("(EVT)", "")
111+
v["sentence"] = v["sentence"].replace("「", "")
112+
v["sentence"] = v["sentence"].replace("」", "")
113+
v["sentence"] = v["sentence"].replace("『", "")
114+
v["sentence"] = v["sentence"].replace("』", "")
115+
print(f"{v}\n")
259116

260117
return result
261118

119+
# 异步任务完成回调
262120
def on_task_done(future, datas, loop, failed, successed):
263121
try:
264122
data = future.result()
@@ -270,13 +128,13 @@ def on_task_done(future, datas, loop, failed, successed):
270128
finally:
271129
print(f"正在进行第 {loop} 轮任务,成功 {len(successed)} 次 ... 失败 {len(failed)} 次 ...")
272130

131+
# 主函数
273132
async def main():
274133
loop = 0
275134
start_time = datetime.now().strftime("%Y%m%d_%H%M%S")
276135

277136
while loop < MAX_LOOP:
278137
loop = loop + 1
279-
names = set()
280138
failed = []
281139
successed = []
282140

@@ -292,10 +150,11 @@ async def main():
292150
await asyncio.gather(*tasks, return_exceptions = True)
293151

294152
# 写入本地
295-
file_path = f"dataset\\{start_time}_{MODEL.replace("/", "_").replace("-", "_")}_{loop:02d}.json"
153+
file_path = f"{OUTPUT_PATH}/{start_time}_{MODEL.replace("/", "_").replace("-", "_")}_{loop:02d}.json"
296154
with open(file_path, "w", encoding = "utf-8") as file:
297155
file.write(json.dumps(datas, indent = 4, ensure_ascii = False))
298156
print(f"第 {loop} 轮已完成,数据已写入 {file_path} ...")
299157

158+
# 入口函数
300159
if __name__ == "__main__":
301160
asyncio.run(main())

0 commit comments

Comments
 (0)