@@ -135,7 +135,9 @@ def get_average_loss(batch: tuple, loss_strategy: str) -> float:
135135 ) / (len (batch [0 ]) + len (batch [1 ]))
136136 raise ValueError ("Invalid loss strategy" )
137137 except Exception as e : # pylint: disable=broad-except
138- logger .error ("Error calculating average loss: %s" , e )
138+ logger .warning (
139+ "Loss not found in some nodes or edges, setting loss to -1.0: %s" , e
140+ )
139141 return - 1.0
140142
141143
@@ -158,7 +160,7 @@ def _post_process_synthetic_data(data):
158160 return qas
159161
160162
161- async def traverse_graph_by_edge (
163+ async def traverse_graph_for_aggregated (
162164 llm_client : OpenAIModel ,
163165 tokenizer : Tokenizer ,
164166 graph_storage : NetworkXStorage ,
@@ -251,7 +253,6 @@ async def _process_single_batch(
251253 qas = _post_process_synthetic_data (content )
252254
253255 if len (qas ) == 0 :
254- print (content )
255256 logger .error (
256257 "Error occurred while processing batch, question or answer is None"
257258 )
@@ -307,7 +308,8 @@ async def _process_single_batch(
307308 return results
308309
309310
310- async def traverse_graph_atomically (
311+ # pylint: disable=too-many-branches, too-many-statements
312+ async def traverse_graph_for_atomic (
311313 llm_client : OpenAIModel ,
312314 tokenizer : Tokenizer ,
313315 graph_storage : NetworkXStorage ,
@@ -328,17 +330,28 @@ async def traverse_graph_atomically(
328330 :param max_concurrent
329331 :return: question and answer
330332 """
331- assert traverse_strategy .qa_form == "atomic"
332333
334+ assert traverse_strategy .qa_form == "atomic"
333335 semaphore = asyncio .Semaphore (max_concurrent )
334336
337+ def _parse_qa (qa : str ) -> tuple :
338+ if "Question:" in qa and "Answer:" in qa :
339+ question = qa .split ("Question:" )[1 ].split ("Answer:" )[0 ].strip ()
340+ answer = qa .split ("Answer:" )[1 ].strip ()
341+ elif "问题:" in qa and "答案:" in qa :
342+ question = qa .split ("问题:" )[1 ].split ("答案:" )[0 ].strip ()
343+ answer = qa .split ("答案:" )[1 ].strip ()
344+ else :
345+ return None , None
346+ return question .strip ('"' ), answer .strip ('"' )
347+
335348 async def _generate_question (node_or_edge : tuple ):
336349 if len (node_or_edge ) == 2 :
337350 des = node_or_edge [0 ] + ": " + node_or_edge [1 ]["description" ]
338- loss = node_or_edge [1 ]["loss" ]
351+ loss = node_or_edge [1 ]["loss" ] if "loss" in node_or_edge [ 1 ] else - 1.0
339352 else :
340353 des = node_or_edge [2 ]["description" ]
341- loss = node_or_edge [2 ]["loss" ]
354+ loss = node_or_edge [2 ]["loss" ] if "loss" in node_or_edge [ 2 ] else - 1.0
342355
343356 async with semaphore :
344357 try :
@@ -350,13 +363,8 @@ async def _generate_question(node_or_edge: tuple):
350363 )
351364 )
352365
353- if "Question:" in qa and "Answer:" in qa :
354- question = qa .split ("Question:" )[1 ].split ("Answer:" )[0 ].strip ()
355- answer = qa .split ("Answer:" )[1 ].strip ()
356- elif "问题:" in qa and "答案:" in qa :
357- question = qa .split ("问题:" )[1 ].split ("答案:" )[0 ].strip ()
358- answer = qa .split ("答案:" )[1 ].strip ()
359- else :
366+ question , answer = _parse_qa (qa )
367+ if question is None or answer is None :
360368 return {}
361369
362370 question = question .strip ('"' )
@@ -386,16 +394,18 @@ async def _generate_question(node_or_edge: tuple):
386394 if "<SEP>" in node [1 ]["description" ]:
387395 description_list = node [1 ]["description" ].split ("<SEP>" )
388396 for item in description_list :
389- tasks .append ((node [0 ], {"description" : item , "loss" : node [1 ]["loss" ]}))
397+ tasks .append ((node [0 ], {"description" : item }))
398+ if "loss" in node [1 ]:
399+ tasks [- 1 ][1 ]["loss" ] = node [1 ]["loss" ]
390400 else :
391401 tasks .append ((node [0 ], node [1 ]))
392402 for edge in edges :
393403 if "<SEP>" in edge [2 ]["description" ]:
394404 description_list = edge [2 ]["description" ].split ("<SEP>" )
395405 for item in description_list :
396- tasks .append (
397- ( edge [ 0 ], edge [ 1 ], { "description" : item , " loss": edge [2 ][ "loss" ]})
398- )
406+ tasks .append (( edge [ 0 ], edge [ 1 ], { "description" : item }))
407+ if " loss" in edge [2 ]:
408+ tasks [ - 1 ][ 2 ][ "loss" ] = edge [ 2 ][ "loss" ]
399409 else :
400410 tasks .append ((edge [0 ], edge [1 ], edge [2 ]))
401411
0 commit comments