@@ -100,6 +100,33 @@ def get_loss_tercile(losses: list) -> (float, float):
100100
101101 return losses [q1_index ], losses [q2_index ]
102102
def assign_difficulty(subgraphs: list, difficulty_order: list) -> list:
    """
    Assign a difficulty label to every subgraph based on its average loss.

    The average loss of each subgraph is computed with ``get_average_loss``
    and the two cut points returned by ``get_loss_tercile`` split the
    subgraphs into three groups: loss below the first cut point, loss below
    the second cut point, and everything else.

    :param subgraphs: list of indexable items; elements 0 and 1 of each item
        are kept (presumably nodes and edges - confirm against the caller).
        The list is modified in place.
    :param difficulty_order: sequence whose entries 0, 1 and 2 are the labels
        used for the easy, medium and hard groups respectively.
    :return: the same ``subgraphs`` list, each entry replaced by a 3-tuple
        ``(item[0], item[1], label)``.
    """
    # Nothing to label; also avoids asking for terciles of an empty list.
    if not subgraphs:
        return subgraphs

    # Compute each loss once and reuse it for both the terciles and labeling.
    losses = [get_average_loss(subgraph) for subgraph in subgraphs]

    # Pass a copy so the per-subgraph order of ``losses`` is preserved even
    # if get_loss_tercile sorts its argument in place.
    q1, q2 = get_loss_tercile(list(losses))

    for i, (subgraph, loss) in enumerate(zip(subgraphs, losses)):
        if loss < q1:
            # easy
            label = difficulty_order[0]
        elif loss < q2:
            # medium
            label = difficulty_order[1]
        else:
            # hard
            label = difficulty_order[2]
        subgraphs[i] = (subgraph[0], subgraph[1], label)
    return subgraphs
129+
103130def get_average_loss (batch : tuple ) -> float :
104131 if loss_strategy == "only_edge" :
105132 return sum (edge [2 ]['loss' ] for edge in batch [1 ]) / len (batch [1 ])
@@ -258,24 +285,7 @@ async def _process_single_batch(
258285 traverse_strategy
259286 )
260287
261- losses = []
262- for batch in processing_batches :
263- loss = get_average_loss (batch )
264- losses .append (loss )
265- q1 , q2 = get_loss_tercile (losses )
266-
267- difficulty_order = traverse_strategy .difficulty_order
268- for i , batch in enumerate (processing_batches ):
269- loss = get_average_loss (batch )
270- if loss < q1 :
271- # easy
272- processing_batches [i ] = (batch [0 ], batch [1 ], difficulty_order [0 ])
273- elif loss < q2 :
274- # medium
275- processing_batches [i ] = (batch [0 ], batch [1 ], difficulty_order [1 ])
276- else :
277- # hard
278- processing_batches [i ] = (batch [0 ], batch [1 ], difficulty_order [2 ])
288+ processing_batches = assign_difficulty (processing_batches , traverse_strategy .difficulty_order )
279289
280290 for result in tqdm_async (asyncio .as_completed (
281291 [_process_single_batch (batch ) for batch in processing_batches ]
0 commit comments