
Commit cac7942

refactor(baselines): refactor some logic
1 parent 74bfcce · commit cac7942

5 files changed: +53 additions, −123 deletions


baselines/EntiGraph/entigraph.py

Lines changed: 40 additions & 28 deletions
```diff
@@ -3,15 +3,15 @@
 
 import os
 import json
+import random
 import asyncio
 import argparse
 from hashlib import md5
-
-from .inference.devapi import gptqa
-from .tasks.baseline_task import BaselineTask
-import random
 from tqdm.asyncio import tqdm as tqdm_async
 
+from baselines.EntiGraph.inference.devapi import gptqa
+from baselines.EntiGraph.tasks.baseline_task import BaselineTask
+
 
 def compute_content_hash(content, prefix: str = ""):
     return prefix + md5(content.encode()).hexdigest()
@@ -37,7 +37,7 @@ async def generate_entities(document_content: str,
             response = json.loads(completion)
             can_read_entities = response['entities']
             return response
-        except Exception as e:
+        except Exception as e:  # pylint: disable=broad-except
             print(f"Failed to generate entities: {str(e)}")
             max_tries -= 1
 
@@ -101,17 +101,21 @@ async def generate_synthetic_data_for_document(input_file, data_type):
 
     async def generate_document_entities(doc):
         async with semaphore:
-            entities = await generate_entities(
-                doc.text,
-                task.openai_system_generate_entities,
-                model_name)
-            if not entities:
+            try:
+                entities = await generate_entities(
+                    doc.text,
+                    task.openai_system_generate_entities,
+                    model_name)
+                if not entities:
+                    return None
+                return {
+                    'document': doc.text,
+                    'entities': entities['entities'],
+                    'summary': entities['summary']
+                }
+            except Exception as e:  # pylint: disable=broad-except
+                print(f"Error: {e}")
                 return None
-            return {
-                'document': doc.text,
-                'entities': entities['entities'],
-                'summary': entities['summary']
-            }
 
     entities_list = []
     for result in tqdm_async(
@@ -128,31 +132,38 @@ async def generate_document_entities(doc):
     for doc in entities_list:
         entities = doc['entities']
         temp = []
-        for i in range(len(entities)):
+        for i, entity_i in enumerate(entities):
             for j in range(i + 1, len(entities)):
-                pair = (doc['document'], entities[i], entities[j])
+                entity_j = entities[j]
+                pair = (doc['document'], entity_i, entity_j)
                 temp.append(pair)
 
-        # There are too many pairs; they would produce lots of junk data and add compute cost, so limit each document to 10 random picks
+        # Computing all possible entity combinations is impractical, so we randomly sample 10 pairs
         pair_list.extend(random.sample(temp, min(len(temp), 10)))
 
 
     async def process_two_entity_relations(pair):
         async with semaphore:
-            document, entity1, entity2 = pair
-            response = await generate_two_entity_relations(
-                document, entity1, entity2,
-                task.openai_system_generate_two_entity_relations,
-                model_name)
-            return response
+            try:
+                document, entity1, entity2 = pair
+                response = await generate_two_entity_relations(
+                    document, entity1, entity2,
+                    task.openai_system_generate_two_entity_relations,
+                    model_name)
+                return response
+            except Exception as e:  # pylint: disable=broad-except
+                print(f"Error: {e}")
+                return None
 
     corpus= []
     for result in tqdm_async(
         asyncio.as_completed([process_two_entity_relations(pair) for pair in pair_list]),
         total=len(pair_list),
         desc="Generating two entity relations"
     ):
-        corpus.append(await result)
+        result = await result
+        if result:
+            corpus.append(result)
 
     # triple_list = []
     # for doc in entities_list:
@@ -196,8 +207,9 @@ async def generate_qa_sft(content):
     ):
         try:
             result = await result
-            qa_sft_results.update(_post_process_synthetic_data(result))
-        except Exception as e:
+            if result:
+                qa_sft_results.update(_post_process_synthetic_data(result))
+        except Exception as e:  # pylint: disable=broad-except
             print(f"Error: {e}")
 
     return qa_sft_results
@@ -225,5 +237,5 @@ async def generate_qa_sft(content):
     results = loop.run_until_complete(generate_synthetic_data_for_document(args.input_file, args.data_type))
 
     # Save results
-    with open(args.output_file, "w") as f:
+    with open(args.output_file, "w", encoding='utf-8') as f:
         json.dump(results, f, indent=4, ensure_ascii=False)
```
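The pair loop above enumerates every unordered entity pair, which grows as n(n−1)/2 for n entities, so the `random.sample(temp, min(len(temp), 10))` cap bounds the number of relation-generation calls per document. An equivalent form of that sampling step using `itertools.combinations` (a sketch for illustration, not the committed code):

```python
import random
from itertools import combinations

def sample_entity_pairs(document: str, entities: list[str], k: int = 10):
    """Build all unordered entity pairs for one document, then keep at most k."""
    pairs = [(document, a, b) for a, b in combinations(entities, 2)]
    return random.sample(pairs, min(len(pairs), k))

# Example: 5 entities yield 10 pairs; the cap keeps per-document cost bounded.
print(len(sample_entity_pairs("doc", ["A", "B", "C", "D", "E"], k=10)))
```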

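The common thread across these hunks is error isolation: each semaphore-bounded coroutine catches its own exception and returns None, and every consumer filters out the Nones rather than letting one failed API call abort the whole `asyncio.as_completed` loop. A minimal self-contained sketch of that pattern (the flaky task below is simulated, not taken from the repo):

```python
import asyncio
import random

semaphore = asyncio.Semaphore(4)

async def safe_task(i: int):
    """Bounded worker that swallows its own errors and returns None instead."""
    async with semaphore:
        try:
            if random.random() < 0.3:  # simulate a flaky API call
                raise RuntimeError(f"task {i} failed")
            return i * i
        except Exception as e:  # mirrors the broad-except in the diff
            print(f"Error: {e}")
            return None

async def main():
    results = []
    for fut in asyncio.as_completed([safe_task(i) for i in range(10)]):
        result = await fut
        if result is not None:  # drop failures, as the refactored loops do
            results.append(result)
    print(sorted(results))

asyncio.run(main())
```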
baselines/EntiGraph/entigraph_utils/io_utils.py

Lines changed: 0 additions & 85 deletions
This file was deleted.

baselines/EntiGraph/entigraph_utils/prompt_utils.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -1,3 +1,5 @@
+# pylint: disable=C0301
+
 QUALITY_FEW_SHOT_COT_PROMPT = """## Example 1
 ### Question
 In the context of "Les Misérables", written by Victor Hugo in 1862, what is the main setting of the novel? There is only one correct choice.
```
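For reference, C0301 is pylint's line-too-long check; disabling it module-wide is a common choice for prompt files, since few-shot strings like the one above cannot be wrapped without changing the prompt text itself.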

baselines/EntiGraph/inference/devapi.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,6 +1,6 @@
-from openai import AsyncOpenAI
-import dotenv
 import os
+import dotenv
+from openai import AsyncOpenAI
 
 dotenv.load_dotenv()
 
```
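This change only reorders the imports into the conventional standard-library-first grouping (PEP 8). For context, a minimal sketch of how a module like this typically wires its client; the explicit environment-variable handling below is an assumption, not taken from the diff:

```python
import os

import dotenv
from openai import AsyncOpenAI

dotenv.load_dotenv()  # copy variables from a local .env file into os.environ

# AsyncOpenAI() reads OPENAI_API_KEY from the environment by default;
# passing it explicitly is equivalent, just more visible.
client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
```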
baselines/EntiGraph/tasks/baseline_task.py

Lines changed: 9 additions & 8 deletions
```diff
@@ -1,8 +1,9 @@
 # Rewrite from https://github.com/ZitongYang/Synthetic_Continued_Pretraining/blob/main/tasks/quality.py
+
 import json
 from hashlib import md5
 
-from .task_abc import Document, Task
+from baselines.EntiGraph.tasks.task_abc import Document, Task
 from baselines.EntiGraph.entigraph_utils.prompt_utils import (
     OPENAI_API_SYSTEM_QUALITY_GENERATE_ENTITIES,
     OPENAI_API_SYSTEM_QUALITY_GENERATE_TWO_ENTITY_RELATIONS,
@@ -16,14 +17,19 @@ class BaselineTask(Task):
     openai_system_quality_qa_sft = OPENAI_API_SYSTEM_QUALITY_QA_SFT
     llama_cot_prompt = QUALITY_FEW_SHOT_COT_PROMPT
 
+    def __init__(self, input_file: str, data_type: str):
+        self._data = self._load_split(input_file, data_type)
+        self._create_documents()
+        self._dedup()
+
     @staticmethod
     def _load_split(input_file: str, data_type: str):
         if data_type == 'raw':
-            with open(input_file, "r") as f:
+            with open(input_file, "r", encoding='utf-8') as f:
                 data = [json.loads(line) for line in f]
                 data = [[chunk] for chunk in data]
         elif data_type == 'chunked':
-            with open(input_file, "r") as f:
+            with open(input_file, "r", encoding='utf-8') as f:
                 data = json.load(f)
 
         documents = []
@@ -49,11 +55,6 @@ def _dedup(self):
         self.documents = list(deuped_documents.values())
 
 
-    def __init__(self, input_file: str, data_type: str):
-        self._data = self._load_split(input_file, data_type)
-        self._create_documents()
-        self._dedup()
-
     def performance_stats(self):
         pass
 
```
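Moving `__init__` above the static helpers puts construction at the top of the class, matching how it is consumed. A hedged usage sketch (the input path and flag values are invented for illustration):

```python
from baselines.EntiGraph.tasks.baseline_task import BaselineTask

# 'raw' expects a JSON Lines file, one chunk per line (see _load_split above).
task = BaselineTask(input_file="corpus.jsonl", data_type="raw")

# __init__ loads, builds, and deduplicates documents in one pass,
# so the instance is ready to use immediately.
print(len(task.documents))  # _dedup populates self.documents
```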
