@@ -158,7 +158,7 @@ def _post_process_synthetic_data(data):
158158 return qas
159159
160160
161- async def traverse_graph_by_edge (
161+ async def traverse_graph_for_aggregated (
162162 llm_client : OpenAIModel ,
163163 tokenizer : Tokenizer ,
164164 graph_storage : NetworkXStorage ,
@@ -251,7 +251,6 @@ async def _process_single_batch(
251251 qas = _post_process_synthetic_data (content )
252252
253253 if len (qas ) == 0 :
254- print (content )
255254 logger .error (
256255 "Error occurred while processing batch, question or answer is None"
257256 )
@@ -307,7 +306,8 @@ async def _process_single_batch(
307306 return results
308307
309308
310- async def traverse_graph_atomically (
309+ # pylint: disable=too-many-branches, too-many-statements
310+ async def traverse_graph_for_atomic (
311311 llm_client : OpenAIModel ,
312312 tokenizer : Tokenizer ,
313313 graph_storage : NetworkXStorage ,
@@ -328,17 +328,28 @@ async def traverse_graph_atomically(
328328 :param max_concurrent
329329 :return: question and answer
330330 """
331- assert traverse_strategy .qa_form == "atomic"
332331
332+ assert traverse_strategy .qa_form == "atomic"
333333 semaphore = asyncio .Semaphore (max_concurrent )
334334
335+ def _parse_qa (qa : str ) -> tuple :
336+ if "Question:" in qa and "Answer:" in qa :
337+ question = qa .split ("Question:" )[1 ].split ("Answer:" )[0 ].strip ()
338+ answer = qa .split ("Answer:" )[1 ].strip ()
339+ elif "问题:" in qa and "答案:" in qa :
340+ question = qa .split ("问题:" )[1 ].split ("答案:" )[0 ].strip ()
341+ answer = qa .split ("答案:" )[1 ].strip ()
342+ else :
343+ return None , None
344+ return question .strip ('"' ), answer .strip ('"' )
345+
335346 async def _generate_question (node_or_edge : tuple ):
336347 if len (node_or_edge ) == 2 :
337348 des = node_or_edge [0 ] + ": " + node_or_edge [1 ]["description" ]
338- loss = node_or_edge [1 ]["loss" ]
349+ loss = node_or_edge [1 ]["loss" ] if "loss" in node_or_edge [ 1 ] else - 1.0
339350 else :
340351 des = node_or_edge [2 ]["description" ]
341- loss = node_or_edge [2 ]["loss" ]
352+ loss = node_or_edge [2 ]["loss" ] if "loss" in node_or_edge [ 2 ] else - 1.0
342353
343354 async with semaphore :
344355 try :
@@ -350,13 +361,8 @@ async def _generate_question(node_or_edge: tuple):
350361 )
351362 )
352363
353- if "Question:" in qa and "Answer:" in qa :
354- question = qa .split ("Question:" )[1 ].split ("Answer:" )[0 ].strip ()
355- answer = qa .split ("Answer:" )[1 ].strip ()
356- elif "问题:" in qa and "答案:" in qa :
357- question = qa .split ("问题:" )[1 ].split ("答案:" )[0 ].strip ()
358- answer = qa .split ("答案:" )[1 ].strip ()
359- else :
364+ question , answer = _parse_qa (qa )
365+ if question is None or answer is None :
360366 return {}
361367
362368 question = question .strip ('"' )
@@ -386,16 +392,18 @@ async def _generate_question(node_or_edge: tuple):
386392 if "<SEP>" in node [1 ]["description" ]:
387393 description_list = node [1 ]["description" ].split ("<SEP>" )
388394 for item in description_list :
389- tasks .append ((node [0 ], {"description" : item , "loss" : node [1 ]["loss" ]}))
395+ tasks .append ((node [0 ], {"description" : item }))
396+ if "loss" in node [1 ]:
397+ tasks [- 1 ][1 ]["loss" ] = node [1 ]["loss" ]
390398 else :
391399 tasks .append ((node [0 ], node [1 ]))
392400 for edge in edges :
393401 if "<SEP>" in edge [2 ]["description" ]:
394402 description_list = edge [2 ]["description" ].split ("<SEP>" )
395403 for item in description_list :
396- tasks .append (
397- ( edge [ 0 ], edge [ 1 ], { "description" : item , " loss": edge [2 ][ "loss" ]})
398- )
404+ tasks .append (( edge [ 0 ], edge [ 1 ], { "description" : item }))
405+ if " loss" in edge [2 ]:
406+ tasks [ - 1 ][ 2 ][ "loss" ] = edge [ 2 ][ "loss" ]
399407 else :
400408 tasks .append ((edge [0 ], edge [1 ], edge [2 ]))
401409
0 commit comments