77from utils import detect_main_language , compute_content_hash , logger
88from graphgen .operators .split_graph import get_batches_with_strategy
99
10- # TODO: move to config
11- # TODO: if add isolated nodes, the loss strategy should be changed to "both"
12- loss_strategy : str = "only_edge" # only_edge, both
1310
1411async def _pre_tokenize (graph_storage : NetworkXStorage ,
1512 tokenizer : Tokenizer ,
@@ -101,22 +98,23 @@ def get_loss_tercile(losses: list) -> (float, float):
10198
10299 return losses [q1_index ], losses [q2_index ]
103100
104- def assign_difficulty (subgraphs : list , difficulty_order : list ) -> list :
101+ def assign_difficulty (subgraphs : list , difficulty_order : list , loss_strategy : str ) -> list :
105102 """
106- Assign difficulty to subgraphs based on the loss
103+ Assign difficulty to subgraphs based on the loss.
107104
108105 :param subgraphs
109106 :param difficulty_order
107+ :param loss_strategy
110108 :return
111109 """
112110 losses = []
113111 for subgraph in subgraphs :
114- loss = get_average_loss (subgraph )
112+ loss = get_average_loss (subgraph , loss_strategy )
115113 losses .append (loss )
116114 q1 , q2 = get_loss_tercile (losses )
117115
118116 for i , subgraph in enumerate (subgraphs ):
119- loss = get_average_loss (subgraph )
117+ loss = get_average_loss (subgraph , loss_strategy )
120118 if loss < q1 :
121119 # easy
122120 subgraphs [i ] = (subgraph [0 ], subgraph [1 ], difficulty_order [0 ])
@@ -128,7 +126,7 @@ def assign_difficulty(subgraphs: list, difficulty_order: list) -> list:
128126 subgraphs [i ] = (subgraph [0 ], subgraph [1 ], difficulty_order [2 ])
129127 return subgraphs
130128
131- def get_average_loss (batch : tuple ) -> float :
129+ def get_average_loss (batch : tuple , loss_strategy : str ) -> float :
132130 if loss_strategy == "only_edge" :
133131 return sum (edge [2 ]['loss' ] for edge in batch [1 ]) / len (batch [1 ])
134132 if loss_strategy == "both" :
@@ -242,7 +240,7 @@ async def _process_single_batch(
242240 compute_content_hash (context ): {
243241 "question" : question ,
244242 "answer" : context ,
245- "loss" : get_average_loss (_process_batch ),
243+ "loss" : get_average_loss (_process_batch , traverse_strategy . loss_strategy ),
246244 "difficulty" : _process_batch [2 ],
247245 }
248246 }
@@ -268,7 +266,7 @@ async def _process_single_batch(
268266 final_results [compute_content_hash (qa ['question' ])] = {
269267 "question" : qa ['question' ],
270268 "answer" : qa ['answer' ],
271- "loss" : get_average_loss (_process_batch ),
269+ "loss" : get_average_loss (_process_batch , traverse_strategy . loss_strategy ),
272270 "difficulty" : _process_batch [2 ],
273271 }
274272 return final_results
@@ -286,7 +284,8 @@ async def _process_single_batch(
286284 traverse_strategy
287285 )
288286
289- processing_batches = assign_difficulty (processing_batches , traverse_strategy .difficulty_order )
287+ processing_batches = assign_difficulty (processing_batches , traverse_strategy .difficulty_order ,
288+ traverse_strategy .loss_strategy )
290289
291290 for result in tqdm_async (asyncio .as_completed (
292291 [_process_single_batch (batch ) for batch in processing_batches ]
@@ -436,7 +435,8 @@ async def traverse_graph_for_multi_hop(
436435 traverse_strategy
437436 )
438437
439- processing_batches = assign_difficulty (processing_batches , traverse_strategy .difficulty_order )
438+ processing_batches = assign_difficulty (processing_batches , traverse_strategy .difficulty_order ,
439+ traverse_strategy .loss_strategy )
440440
441441 async def _process_single_batch (
442442 _process_batch : tuple
@@ -487,7 +487,7 @@ async def _process_single_batch(
487487 compute_content_hash (question ): {
488488 "question" : question ,
489489 "answer" : answer ,
490- "loss" : get_average_loss (_process_batch ),
490+ "loss" : get_average_loss (_process_batch , traverse_strategy . loss_strategy ),
491491 "difficulty" : _process_batch [2 ],
492492 }
493493 }
0 commit comments