@@ -319,20 +319,20 @@ async def traverse_graph_atomically(
319319 :param max_concurrent
320320 :return: question and answer
321321 """
322-
323322 assert traverse_strategy .qa_form == "atomic"
324323
325324 semaphore = asyncio .Semaphore (max_concurrent )
326-
327325 async def _generate_question (
328326 node_or_edge : tuple
329327 ):
330328 if len (node_or_edge ) == 2 :
331329 des = node_or_edge [0 ] + ": " + node_or_edge [1 ]['description' ]
332330 answer = node_or_edge [1 ]['description' ]
331+ loss = node_or_edge [1 ]['loss' ]
333332 else :
334333 des = node_or_edge [2 ]['description' ]
335334 answer = node_or_edge [2 ]['description' ]
335+ loss = node_or_edge [2 ]['loss' ]
336336
337337 async with semaphore :
338338 try :
@@ -356,7 +356,7 @@ async def _generate_question(
356356 compute_content_hash (question ): {
357357 "question" : question ,
358358 "answer" : answer ,
359- "loss" : - 1 ,
359+ "loss" : loss ,
360360 "difficulty" : "medium"
361361 }
362362 }
@@ -377,14 +377,14 @@ async def _generate_question(
377377 if "<SEP>" in node [1 ]['description' ]:
378378 description_list = node [1 ]['description' ].split ("<SEP>" )
379379 for item in description_list :
380- tasks .append ((node [0 ], {"description" : item }))
380+ tasks .append ((node [0 ], {"description" : item , 'loss' : node [ 1 ][ 'loss' ] }))
381381 else :
382382 tasks .append ((node [0 ], node [1 ]))
383383 for edge in edges :
384384 if "<SEP>" in edge [2 ]['description' ]:
385385 description_list = edge [2 ]['description' ].split ("<SEP>" )
386386 for item in description_list :
387- tasks .append ((edge [0 ], edge [1 ], {"description" : item }))
387+ tasks .append ((edge [0 ], edge [1 ], {"description" : item , 'loss' : edge [ 2 ][ 'loss' ] }))
388388 else :
389389 tasks .append ((edge [0 ], edge [1 ], edge [2 ]))
390390
0 commit comments