77from templates import DESCRIPTION_REPHRASING_PROMPT
88
99
10- async def quiz_relations (
10+ async def quiz (
1111 teacher_llm_client : OpenAIModel ,
1212 graph_storage : NetworkXStorage ,
1313 rephrase_storage : JsonKVStorage ,
@@ -26,16 +26,12 @@ async def quiz_relations(
2626
2727 semaphore = asyncio .Semaphore (max_concurrent )
2828
29- async def _quiz_single_relation (
30- edge : tuple ,
29+ async def _process_single_quiz (
3130 des : str ,
3231 prompt : str ,
3332 gt : str
3433 ):
3534 async with semaphore :
36- source_id = edge [0 ]
37- target_id = edge [1 ]
38-
3935 try :
4036 # 如果在rephrase_storage中已经存在,直接取出
4137 descriptions = await rephrase_storage .get_by_id (des )
@@ -49,11 +45,12 @@ async def _quiz_single_relation(
4945 return {des : [(new_description , gt )]}
5046
5147 except Exception as e : # pylint: disable=broad-except
52- logger .error ("Error when quizzing edge %s -> %s : %s" , source_id , target_id , e )
48+ logger .error ("Error when quizzing description %s: %s" , des , e )
5349 return None
5450
5551
5652 edges = await graph_storage .get_all_edges ()
53+ nodes = await graph_storage .get_all_nodes ()
5754
5855 results = defaultdict (list )
5956 tasks = []
@@ -68,19 +65,36 @@ async def _quiz_single_relation(
6865 for i in range (max_samples ):
6966 if i > 0 :
7067 tasks .append (
71- _quiz_single_relation ( edge , description ,
68+ _process_single_quiz ( description ,
7269 DESCRIPTION_REPHRASING_PROMPT [language ]['TEMPLATE' ].format (
7370 input_sentence = description ), 'yes' )
7471 )
75- tasks .append (_quiz_single_relation ( edge , description ,
72+ tasks .append (_process_single_quiz ( description ,
7673 DESCRIPTION_REPHRASING_PROMPT [language ]['ANTI_TEMPLATE' ].format (
7774 input_sentence = description ), 'no' ))
7875
76+ for node in nodes :
77+ node_data = node [1 ]
78+ description = node_data ["description" ]
79+ language = "English" if detect_main_language (description ) == "en" else "Chinese"
80+
81+ results [description ] = [(description , 'yes' )]
82+
83+ for i in range (max_samples ):
84+ if i > 0 :
85+ tasks .append (
86+ _process_single_quiz (description ,
87+ DESCRIPTION_REPHRASING_PROMPT [language ]['TEMPLATE' ].format (
88+ input_sentence = description ), 'yes' )
89+ )
90+ tasks .append (_process_single_quiz (description ,
91+ DESCRIPTION_REPHRASING_PROMPT [language ]['ANTI_TEMPLATE' ].format (
92+ input_sentence = description ), 'no' ))
7993
8094 for result in tqdm_async (
8195 asyncio .as_completed (tasks ),
8296 total = len (tasks ),
83- desc = "Quizzing relations "
97+ desc = "Quizzing descriptions "
8498 ):
8599 new_result = await result
86100 if new_result :
@@ -91,4 +105,5 @@ async def _quiz_single_relation(
91105 results [key ] = list (set (value ))
92106 await rephrase_storage .upsert ({key : results [key ]})
93107
108+
94109 return rephrase_storage
0 commit comments