@@ -139,16 +139,118 @@ def load_preds_labels_from_nonwandb(
139139 return preds , labels
140140
141141
def get_label_names(data_module):
    """Read the class labels (ChEBI IDs) used by a data module.

    Args:
        data_module: object exposing a ``raw_dir`` directory path.

    Returns:
        List of int class IDs parsed from ``classes.txt`` in the raw
        directory, or ``None`` if that file does not exist.
    """
    # build the path once instead of joining twice
    classes_path = os.path.join(data_module.raw_dir, "classes.txt")
    if not os.path.exists(classes_path):
        return None
    with open(classes_path) as fin:
        return [int(line.strip()) for line in fin]
147+
148+
def get_chebi_graph(data_module, label_names):
    """Build the ChEBI class hierarchy restricted to the given labels.

    Args:
        data_module: object exposing ``raw_dir`` and
            ``extract_class_hierarchy(obo_path)``.
        label_names: node IDs to keep in the returned subgraph.

    Returns:
        The hierarchy subgraph induced by ``label_names``, or ``None`` if
        ``chebi.obo`` is missing from the raw directory.
    """
    # build the path once instead of joining twice
    obo_path = os.path.join(data_module.raw_dir, "chebi.obo")
    if not os.path.exists(obo_path):
        return None
    chebi_graph = data_module.extract_class_hierarchy(obo_path)
    return chebi_graph.subgraph(label_names)
156+
157+
def get_disjoint_groups(
    disjoints_owl_file=os.path.join("data", "chebi-disjoints.owl"),
):
    """Parse ChEBI disjointness axioms from an OWL file.

    Uses a naive ``<``-tag splitter rather than a real XML parser, so it
    also picks up axioms inside XML comments (see the manual fix below).

    Args:
        disjoints_owl_file: path to the OWL file; defaults to the location
            the script has always used (``data/chebi-disjoints.owl``).

    Returns:
        List of lists of int ChEBI IDs; each inner list is a set of
        mutually disjoint classes (pairs from ``owl:disjointWith`` plus
        n-ary ``owl:AllDisjointClasses`` groups).
    """
    with open(disjoints_owl_file, "r") as f:
        plaintext = f.read()

    # Pairwise axioms: an owl:Class / rdf:Description tag sets the current
    # left-hand class; each following owl:disjointWith tag yields one pair.
    disjoint_pairs = []
    left = None
    for seg in plaintext.split("<"):
        if seg.startswith("rdf:Description ") or seg.startswith("owl:Class"):
            left = int(seg.split('rdf:about="&obo;CHEBI_')[1].split('"')[0])
        elif seg.startswith("owl:disjointWith"):
            right = int(seg.split('rdf:resource="&obo;CHEBI_')[1].split('"')[0])
            disjoint_pairs.append([left, right])

    # n-ary axioms: owl:AllDisjointClasses blocks list their members as
    # rdf:Description elements inside a bare <rdf:Description> wrapper.
    disjoint_groups = []
    for seg in plaintext.split("<rdf:Description>"):
        if "owl;AllDisjointClasses" in seg:
            classes = seg.split('rdf:about="&obo;CHEBI_')[1:]
            disjoint_groups.append([int(c.split('"')[0]) for c in classes])

    disjoint_all = disjoint_pairs + disjoint_groups
    # One disjointness (CHEBI 22729 vs 51880) is commented out in the owl
    # file; the naive splitter does not notice comment markers, so drop the
    # pair manually if it was picked up (the correct way would be to parse
    # the owl file properly).
    if [22729, 51880] in disjoint_all:
        disjoint_all.remove([22729, 51880])
    print(f"Found {len(disjoint_all)} disjoint groups")
    return disjoint_all
184+
185+
def smooth_preds(preds, label_names, chebi_graph, disjoint_groups):
    """Remove logical inconsistencies from prediction scores in-place.

    Applies three smoothing steps:
      1. implication: raise each class to the maximum over its successors
         (and itself) in the hierarchy,
      2. disjointness: within each disjoint group, cap every class except
         the per-sample maximum at 0.49 (i.e. just below a 0.5 decision
         threshold), unless it is already lower,
      3. implication again: step 2 may have re-introduced violations, so
         lower each class to the minimum over its predecessors (and itself).

    Args:
        preds: tensor of shape (n_samples, n_labels); modified in-place.
        label_names: label ID for each prediction column.
        chebi_graph: hierarchy over ``label_names`` — presumably
            ``successors`` are subclasses and ``predecessors`` superclasses
            (TODO confirm against the ChEBI extraction).
        disjoint_groups: lists of mutually disjoint label IDs; IDs not in
            ``label_names`` are ignored.

    Returns:
        The smoothed ``preds`` tensor.
    """
    # O(1) column lookup instead of repeated O(n) list.index() scans
    col_of = {label: i for i, label in enumerate(label_names)}

    preds_sum_orig = torch.sum(preds)
    print(f"Preds sum: {preds_sum_orig}")
    # step 1: eliminate implication violations by setting each prediction to
    # the maximum of its successors (succs always includes i itself, so the
    # index list is never empty)
    for i, label in enumerate(label_names):
        succs = [col_of[p] for p in chebi_graph.successors(label)] + [i]
        preds[:, i] = torch.max(preds[:, succs], dim=1).values
    print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")
    preds_sum_orig = torch.sum(preds)

    # step 2: eliminate disjointness violations: for each group of disjoint
    # classes, set all except the per-sample max to 0.49 (if not already lower)
    # NOTE: preds_bounded is a snapshot taken before any group is processed
    preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49)
    for disj_group in disjoint_groups:
        disj_cols = [col_of[g] for g in disj_group if g in col_of]
        if len(disj_cols) > 1:
            old_preds = preds[:, disj_cols]  # advanced indexing -> copy
            disj_max = torch.max(preds[:, disj_cols], dim=1)
            # cap every non-maximal member of the group, vectorised over
            # samples (replaces a per-sample, per-label Python double loop)
            for j, col in enumerate(disj_cols):
                rows = disj_max.indices != j
                preds[rows, col] = preds_bounded[rows, col]
            samples_changed = int(
                torch.sum(torch.any(preds[:, disj_cols] != old_preds, dim=1))
            )
            if samples_changed != 0:
                print(
                    f"disjointness group {[label_names[d] for d in disj_cols]} changed {samples_changed} samples"
                )
    print(
        f"Preds change after disjointness (step 2): {torch.sum(preds) - preds_sum_orig}"
    )
    preds_sum_orig = torch.sum(preds)

    # step 3: disjointness violation removal may have caused new implication
    # inconsistencies -> set each prediction to min of predecessors
    for i, label in enumerate(label_names):
        predecessors = [i] + [col_of[p] for p in chebi_graph.predecessors(label)]
        lowest_predecessors = torch.min(preds[:, predecessors], dim=1)
        preds[:, i] = lowest_predecessors.values
        # index 0 is the class itself; any other index means a predecessor
        # value was lower, i.e. this sample's prediction was changed
        for idx_idx, idx in enumerate(lowest_predecessors.indices):
            if idx > 0:
                print(
                    f"class {label}: changed prediction of sample {idx_idx} to value of class "
                    f"{label_names[predecessors[idx]]} ({preds[idx_idx, i].item():.2f})"
                )
        if torch.sum(preds) != preds_sum_orig:
            print(
                f"Preds change (step 3) for {label}: {torch.sum(preds) - preds_sum_orig}"
            )
            preds_sum_orig = torch.sum(preds)
    return preds
238+
239+
142240def analyse_run (
143241 preds ,
144242 labels ,
145243 df_hyperparams , # parameters that are the independent of the semantic loss function used
146244 labeled_data_cls = ChEBIOver100 , # use labels from this dataset for violations
147245 chebi_version = 231 ,
148246 results_path = os .path .join ("_semantic" , "eval_results.csv" ),
247+ violation_metrics : [str | list [callable ]] = "all" ,
248+ verbose_violation_output = False ,
149249):
150250 """Calculates all semantic metrics for given predictions (and supervised metrics if labels are provided),
151251 saves results to csv"""
252+ if violation_metrics == "all" :
253+ violation_metrics = [product , lukasiewicz , weak , strict , binary ]
152254 data_module_labeled = labeled_data_cls (chebi_version = chebi_version )
153255 n_labels = preds .size (1 )
154256 print (f"Found { preds .shape [0 ]} predictions ({ n_labels } classes)" )
@@ -173,7 +275,7 @@ def analyse_run(
173275 del preds_exp
174276 gc .collect ()
175277
176- for i , metric in enumerate ([ product , lukasiewicz , weak , strict , binary ] ):
278+ for i , metric in enumerate (violation_metrics ):
177279 if filter_type == "impl" :
178280 df_new .append (df_hyperparams .copy ())
179281 df_new [- 1 ]["metric" ] = metric .__name__
@@ -188,6 +290,27 @@ def analyse_run(
188290 m ["fns" ] = apply_metric (
189291 metric , l_preds , 1 - r_preds if filter_type == "impl" else r_preds
190292 )
293+ if verbose_violation_output :
294+ label_names = get_label_names (data_module_labeled )
295+ print (f"Found { torch .sum (m ['fns' ])} { filter_type } -violations" )
296+ # for k, fn_cls in enumerate(m['fns']):
297+ # if fn_cls > 0:
298+ # print(f"\tThereof, {fn_cls.item()} belong to class {label_names[k]}")
299+ if torch .sum (m ["fns" ]) != 0 :
300+ fns = metric (
301+ l_preds , 1 - r_preds if filter_type == "impl" else r_preds
302+ )
303+ print (fns .shape )
304+ for k , row in enumerate (fns ):
305+ if torch .sum (row ) != 0 :
306+ print (f"{ torch .sum (row )} violations for entity { k } " )
307+ for j , violation in enumerate (row ):
308+ if violation > 0 :
309+ print (
310+ f"\t violated ({ label_names [dl_filter_l [j ]]} -> { preds [k , dl_filter_l [j ]]:.3f} "
311+ f", { label_names [dl_filter_r [j ]]} -> { preds [k , dl_filter_r [j ]]:.3f} )"
312+ )
313+
191314 m_cls = {}
192315 for key , value in m .items ():
193316 m_cls [key ] = _sort_results_by_label (
@@ -259,14 +382,23 @@ def run_all(
259382 skip_analyse = False ,
260383 skip_preds = False ,
261384 nonwandb_runs = None ,
385+ violation_metrics = "all" ,
386+ remove_violations = False ,
262387):
263388 # evaluate a list of runs on Hazardous and ChEBIOver100 datasets
264389 if datasets is None :
265390 datasets = [(Hazardous , "all" ), (ChEBIOver100 , "test" )]
266391 timestamp = datetime .now ().strftime ("%y%m%d-%H%M" )
267392 results_path = os .path .join (
268- "_semloss_eval" , f"semloss_results_pc-dis-200k_{ timestamp } .csv"
393+ "_semloss_eval" ,
394+ f"semloss_results_pc-dis-200k_{ timestamp } { '_violations_removed' if remove_violations else '' } .csv" ,
395+ )
396+ label_names = get_label_names (ChEBIOver100 (chebi_version = chebi_version ))
397+ chebi_graph = get_chebi_graph (
398+ ChEBIOver100 (chebi_version = chebi_version ), label_names
269399 )
400+ disjoint_groups = get_disjoint_groups ()
401+
270402 api = wandb .Api ()
271403 for run_id in run_ids :
272404 try :
@@ -280,17 +412,50 @@ def run_all(
280412 "data_module" : test_on .__name__ ,
281413 "chebi_version" : chebi_version ,
282414 }
283- if not skip_preds :
284- preds , labels = load_preds_labels_from_wandb (
285- run , epoch , chebi_version , test_on , kind
415+ buffer_dir_smoothed = os .path .join (
416+ "results_buffer" ,
417+ "smoothed3step" ,
418+ f"{ run .name } _ep{ epoch } " ,
419+ f"{ test_on .__name__ } _{ kind } " ,
420+ )
421+ if remove_violations and os .path .exists (
422+ os .path .join (buffer_dir_smoothed , "preds000.pt" )
423+ ):
424+ preds = torch .load (
425+ os .path .join (buffer_dir_smoothed , "preds000.pt" ), DEVICE
286426 )
427+ labels = None
287428 else :
288- buffer_dir = os .path .join (
289- "results_buffer" ,
290- f"{ run .name } _ep{ epoch } " ,
291- f"{ test_on .__name__ } _{ kind } " ,
292- )
293- preds , labels = load_results_from_buffer (buffer_dir , device = DEVICE )
429+ if not skip_preds :
430+ preds , labels = load_preds_labels_from_wandb (
431+ run , epoch , chebi_version , test_on , kind
432+ )
433+ else :
434+ buffer_dir = os .path .join (
435+ "results_buffer" ,
436+ f"{ run .name } _ep{ epoch } " ,
437+ f"{ test_on .__name__ } _{ kind } " ,
438+ )
439+ preds , labels = load_results_from_buffer (
440+ buffer_dir , device = DEVICE
441+ )
442+ assert (
443+ preds is not None
444+ ), f"Did not find predictions in dir { buffer_dir } "
445+ if remove_violations :
446+ preds = smooth_preds (
447+ preds , label_names , chebi_graph , disjoint_groups
448+ )
449+ buffer_dir_smoothed = os .path .join (
450+ "results_buffer" ,
451+ "smoothed3step" ,
452+ f"{ run .name } _ep{ epoch } " ,
453+ f"{ test_on .__name__ } _{ kind } " ,
454+ )
455+ os .makedirs (buffer_dir_smoothed , exist_ok = True )
456+ torch .save (
457+ preds , os .path .join (buffer_dir_smoothed , "preds000.pt" )
458+ )
294459 if not skip_analyse :
295460 print (
296461 f"Calculating metrics for run { run .name } on { test_on .__name__ } ({ kind } )"
@@ -301,6 +466,8 @@ def run_all(
301466 df_hyperparams = df ,
302467 chebi_version = chebi_version ,
303468 results_path = results_path ,
469+ violation_metrics = violation_metrics ,
470+ verbose_violation_output = True ,
304471 )
305472 except Exception as e :
306473 print (f"Failed for run { run_id } : { e } " )
@@ -330,6 +497,13 @@ def run_all(
330497 preds , labels = load_results_from_buffer (
331498 buffer_dir , device = DEVICE
332499 )
500+ assert (
501+ preds is not None
502+ ), f"Did not find predictions in dir { buffer_dir } "
503+ if remove_violations :
504+ preds = smooth_preds (
505+ preds , label_names , chebi_graph , disjoint_groups
506+ )
333507 if not skip_analyse :
334508 print (
335509 f"Calculating metrics for run { run_name } on { test_on .__name__ } ({ kind } )"
@@ -340,16 +514,15 @@ def run_all(
340514 df_hyperparams = df ,
341515 chebi_version = chebi_version ,
342516 results_path = results_path ,
517+ violation_metrics = violation_metrics ,
343518 )
344519 except Exception as e :
345520 print (f"Failed for run { run_name } : { e } " )
346521 print (traceback .format_exc ())
347522
348523
349524def run_semloss_eval (mode = "eval" ):
350- non_wandb_runs = (
351- []
352- ) # ("chebi100_semprodk2_weighted_v231_pc_200k_dis_24042-2000", 195)]
525+ non_wandb_runs = []
353526 if mode == "preds" :
354527 api = wandb .Api ()
355528 runs = api .runs ("chebai/chebai" , filters = {"tags" : "eval_semloss_paper" })
@@ -375,8 +548,16 @@ def run_semloss_eval(mode="eval"):
375548 "tk15yznc" ,
376549 ]
377550 baseline = ["i4wtz1k4" , "zd020wkv" , "rc1q3t49" ]
551+ k2 = ["ng3usn0p" , "rp0wwzjv" , "8fma1q7r" ]
378552 ids = baseline
379- run_all (ids , skip_preds = True , nonwandb_runs = non_wandb_runs )
553+ run_all (
554+ ids ,
555+ skip_preds = True ,
556+ nonwandb_runs = non_wandb_runs ,
557+ datasets = [(ChEBIOver100 , "test" )],
558+ violation_metrics = [binary ],
559+ remove_violations = True ,
560+ )
380561
381562
382563if __name__ == "__main__" :
0 commit comments