 
 import os
 import json
+import random
 import asyncio
 import argparse
 from hashlib import md5
-
-from .inference.devapi import gptqa
-from .tasks.baseline_task import BaselineTask
-import random
 from tqdm.asyncio import tqdm as tqdm_async
 
+from baselines.EntiGraph.inference.devapi import gptqa
+from baselines.EntiGraph.tasks.baseline_task import BaselineTask
+
 
 def compute_content_hash(content, prefix: str = ""):
     return prefix + md5(content.encode()).hexdigest()
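
The helper above builds a prefixed md5 hexdigest of a string. A minimal usage sketch, assuming (not shown in this diff) that the hash is used as a de-duplication key for documents; the sample strings and the "doc-" prefix are illustrative:

seen = set()
unique_docs = []
for text in ["first document", "first document", "second document"]:
    key = compute_content_hash(text, prefix="doc-")  # "doc-" + md5 hexdigest of the text
    if key in seen:                                  # identical text -> identical hash -> skip
        continue
    seen.add(key)
    unique_docs.append(text)
print(unique_docs)                                   # ['first document', 'second document']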
@@ -37,7 +37,7 @@ async def generate_entities(document_content: str,
             response = json.loads(completion)
             can_read_entities = response['entities']
             return response
-        except Exception as e:
+        except Exception as e:  # pylint: disable=broad-except
             print(f"Failed to generate entities: {str(e)}")
             max_tries -= 1
@@ -101,17 +101,21 @@ async def generate_synthetic_data_for_document(input_file, data_type):
 
     async def generate_document_entities(doc):
         async with semaphore:
-            entities = await generate_entities(
-                doc.text,
-                task.openai_system_generate_entities,
-                model_name)
-            if not entities:
+            try:
+                entities = await generate_entities(
+                    doc.text,
+                    task.openai_system_generate_entities,
+                    model_name)
+                if not entities:
+                    return None
+                return {
+                    'document': doc.text,
+                    'entities': entities['entities'],
+                    'summary': entities['summary']
+                }
+            except Exception as e:  # pylint: disable=broad-except
+                print(f"Error: {e}")
                 return None
-            return {
-                'document': doc.text,
-                'entities': entities['entities'],
-                'summary': entities['summary']
-            }
 
     entities_list = []
     for result in tqdm_async(
@@ -128,31 +132,38 @@ async def generate_document_entities(doc):
     for doc in entities_list:
         entities = doc['entities']
         temp = []
-        for i in range(len(entities)):
+        for i, entity_i in enumerate(entities):
             for j in range(i + 1, len(entities)):
-                pair = (doc['document'], entities[i], entities[j])
+                entity_j = entities[j]
+                pair = (doc['document'], entity_i, entity_j)
                 temp.append(pair)
 
-        # 由于数量太多,会产生很多垃圾数据,增加计算成本,因此限制同一个文档随机选10个
+        # Computing all possible entity combinations is impractical, so we randomly sample at most 10 pairs per document
         pair_list.extend(random.sample(temp, min(len(temp), 10)))
 
 
     async def process_two_entity_relations(pair):
         async with semaphore:
-            document, entity1, entity2 = pair
-            response = await generate_two_entity_relations(
-                document, entity1, entity2,
-                task.openai_system_generate_two_entity_relations,
-                model_name)
-            return response
+            try:
+                document, entity1, entity2 = pair
+                response = await generate_two_entity_relations(
+                    document, entity1, entity2,
+                    task.openai_system_generate_two_entity_relations,
+                    model_name)
+                return response
+            except Exception as e:  # pylint: disable=broad-except
+                print(f"Error: {e}")
+                return None
 
     corpus = []
     for result in tqdm_async(
         asyncio.as_completed([process_two_entity_relations(pair) for pair in pair_list]),
         total=len(pair_list),
         desc="Generating two entity relations"
     ):
-        corpus.append(await result)
+        result = await result
+        if result:
+            corpus.append(result)
 
     # triple_list = []
     # for doc in entities_list:
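
Both hunks above share one control-flow pattern: sample at most 10 entity pairs per document with `random.sample`, bound concurrency with a shared `asyncio.Semaphore`, stream results through `asyncio.as_completed` under a tqdm progress bar, and drop the `None` returned by failed calls. A standalone sketch of that flow, with stub entities and a sleep standing in for the model call:

import asyncio
import random
from itertools import combinations
from tqdm.asyncio import tqdm as tqdm_async

async def main():
    entities = [f"entity_{i}" for i in range(8)]
    pairs = list(combinations(entities, 2))
    pairs = random.sample(pairs, min(len(pairs), 10))  # cap pairs per document

    semaphore = asyncio.Semaphore(4)                   # at most 4 in-flight calls

    async def describe_relation(pair):
        async with semaphore:
            try:
                await asyncio.sleep(0.1)               # stand-in for the model request
                return f"{pair[0]} relates to {pair[1]}"
            except Exception as e:  # pylint: disable=broad-except
                print(f"Error: {e}")
                return None

    corpus = []
    for fut in tqdm_async(
        asyncio.as_completed([describe_relation(p) for p in pairs]),
        total=len(pairs),
        desc="Generating two entity relations"
    ):
        result = await fut
        if result:                                     # skip failed (None) results
            corpus.append(result)
    print(f"{len(corpus)} relations generated")

asyncio.run(main())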
@@ -196,8 +207,9 @@ async def generate_qa_sft(content):
     ):
         try:
             result = await result
-            qa_sft_results.update(_post_process_synthetic_data(result))
-        except Exception as e:
+            if result:
+                qa_sft_results.update(_post_process_synthetic_data(result))
+        except Exception as e:  # pylint: disable=broad-except
             print(f"Error: {e}")
 
     return qa_sft_results
@@ -225,5 +237,5 @@ async def generate_qa_sft(content):
     results = loop.run_until_complete(generate_synthetic_data_for_document(args.input_file, args.data_type))
 
     # Save results
-    with open(args.output_file, "w") as f:
+    with open(args.output_file, "w", encoding='utf-8') as f:
         json.dump(results, f, indent=4, ensure_ascii=False)
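
A note on the final hunk: because `json.dump` is called with `ensure_ascii=False`, non-ASCII characters are written verbatim instead of being escaped, so the file must be opened with an encoding that can represent them; without `encoding='utf-8'`, the platform default (e.g. cp1252 on Windows) can raise `UnicodeEncodeError`. A small illustration with made-up content:

import json

data = {"summary": "café 漢字"}  # sample non-ASCII content

with open("out.json", "w", encoding="utf-8") as f:
    json.dump(data, f, indent=4, ensure_ascii=False)  # characters written as-is, not \uXXXX escapes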