11import gc
2- import os
3- import sys
42import traceback
53from datetime import datetime
6- from typing import List , LiteralString , Optional , Tuple
4+ from typing import List , LiteralString
75
8- import torch
9- import wandb
6+ import pandas as pd
107from torchmetrics .functional .classification import (
118 multilabel_auroc ,
129 multilabel_average_precision ,
1310 multilabel_f1_score ,
1411)
15- from utils import evaluate_model , get_checkpoint_from_wandb , load_results_from_buffer
1612
1713from chebai .loss .semantic import DisjointLoss
1814from chebai .models import Electra
1915from chebai .preprocessing .datasets .base import _DynamicDataset
2016from chebai .preprocessing .datasets .chebi import ChEBIOver100
21-
22- # from chebai.preprocessing.datasets.pubchem import PubChemKMeans
17+ from chebai . preprocessing . datasets . pubchem import PubChemKMeans
18+ from chebai .result . utils import *
2319
2420DEVICE = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
2521
@@ -127,7 +123,7 @@ def load_preds_labels(
127123def get_label_names (data_module ):
128124 if os .path .exists (os .path .join (data_module .processed_dir_main , "classes.txt" )):
129125 with open (os .path .join (data_module .processed_dir_main , "classes.txt" )) as fin :
130- return [int ( line .strip () ) for line in fin ]
126+ return [line .strip () for line in fin ]
131127 print (
132128 f"Failed to retrieve label names, { os .path .join (data_module .processed_dir_main , 'classes.txt' )} not found"
133129 )
@@ -136,69 +132,97 @@ def get_label_names(data_module):
136132
137133def get_chebi_graph (data_module , label_names ):
138134 if os .path .exists (os .path .join (data_module .raw_dir , "chebi.obo" )):
139- chebi_graph = data_module .extract_class_hierarchy (
135+ chebi_graph = data_module ._extract_class_hierarchy (
140136 os .path .join (data_module .raw_dir , "chebi.obo" )
141137 )
142- return chebi_graph .subgraph (label_names )
138+ return chebi_graph .subgraph ([ int ( n ) for n in label_names ] )
143139 print (
144140 f"Failed to retrieve ChEBI graph, { os .path .join (data_module .raw_dir , 'chebi.obo' )} not found"
145141 )
146142 return None
147143
148144
149- def get_disjoint_groups ():
150- disjoints_owl_file = os .path .join ("data" , "chebi-disjoints.owl" )
151- with open (disjoints_owl_file , "r" ) as f :
152- plaintext = f .read ()
153- segments = plaintext .split ("<" )
154- disjoint_pairs = []
155- left = None
156- for seg in segments :
157- if seg .startswith ("rdf:Description " ) or seg .startswith ("owl:Class" ):
158- left = int (seg .split ('rdf:about="&obo;CHEBI_' )[1 ].split ('"' )[0 ])
159- elif seg .startswith ("owl:disjointWith" ):
160- right = int (seg .split ('rdf:resource="&obo;CHEBI_' )[1 ].split ('"' )[0 ])
161- disjoint_pairs .append ([left , right ])
162-
163- disjoint_groups = []
164- for seg in plaintext .split ("<rdf:Description>" ):
165- if "owl;AllDisjointClasses" in seg :
166- classes = seg .split ('rdf:about="&obo;CHEBI_' )[1 :]
167- classes = [int (c .split ('"' )[0 ]) for c in classes ]
168- disjoint_groups .append (classes )
145+ def get_disjoint_groups (disjoint_files ):
146+ if disjoint_files is None :
147+ disjoint_files = [os .path .join ("data" , "chebi-disjoints.owl" )]
148+ disjoint_pairs , disjoint_groups = [], []
149+ for file in disjoint_files :
150+ if file .split ("." )[- 1 ] == "csv" :
151+ disjoint_pairs += pd .read_csv (file , header = None ).values .tolist ()
152+ elif file .split ("." )[- 1 ] == "owl" :
153+ with open (file , "r" ) as f :
154+ plaintext = f .read ()
155+ segments = plaintext .split ("<" )
156+ left = None
157+ for seg in segments :
159+ if seg .startswith ("rdf:Description " ) or seg .startswith (
160+ "owl:Class"
161+ ):
162+ left = int (seg .split ('rdf:about="&obo;CHEBI_' )[1 ].split ('"' )[0 ])
163+ elif seg .startswith ("owl:disjointWith" ):
164+ right = int (
165+ seg .split ('rdf:resource="&obo;CHEBI_' )[1 ].split ('"' )[0 ]
166+ )
167+ disjoint_pairs .append ([left , right ])
168+
169+ for seg in plaintext .split ("<rdf:Description>" ):
171+ if "owl;AllDisjointClasses" in seg :
172+ classes = seg .split ('rdf:about="&obo;CHEBI_' )[1 :]
173+ classes = [int (c .split ('"' )[0 ]) for c in classes ]
174+ disjoint_groups .append (classes )
175+ else :
176+ raise NotImplementedError (
177+ "Unsupported disjoint file format: " + file .split ("." )[- 1 ]
178+ )
179+
169180 disjoint_all = disjoint_pairs + disjoint_groups
170181 # one disjointness is commented out in the owl-file
171182 # (the correct way would be to parse the owl file and notice the comment symbols, but for this case, it should work)
172- disjoint_all .remove ([22729 , 51880 ])
173- print (f"Found { len (disjoint_all )} disjoint groups" )
183+ if [22729 , 51880 ] in disjoint_all :
184+ disjoint_all .remove ([22729 , 51880 ])
185+ # print(f"Found {len(disjoint_all)} disjoint groups")
174186 return disjoint_all
175187
176188
177189class PredictionSmoother :
178190 """Removes implication and disjointness violations from predictions"""
179191
180- def __init__ (self , dataset ):
181- self .label_names = get_label_names (dataset )
192+ def __init__ (self , dataset , label_names = None , disjoint_files = None ):
193+ if label_names :
194+ self .label_names = label_names
195+ else :
196+ self .label_names = get_label_names (dataset )
182197 self .chebi_graph = get_chebi_graph (dataset , self .label_names )
183- self .disjoint_groups = get_disjoint_groups ()
198+ self .disjoint_groups = get_disjoint_groups (disjoint_files )
184199
185200 def __call__ (self , preds ):
186201 preds_sum_orig = torch .sum (preds )
187- print (f"Preds sum: { preds_sum_orig } " )
188- # eliminate implication violations by setting each prediction to maximum of its successors
189202 for i , label in enumerate (self .label_names ):
190203 succs = [
191- self .label_names .index (p ) for p in self .chebi_graph .successors (label )
204+ self .label_names .index (str (p ))
205+ for p in self .chebi_graph .successors (int (label ))
192206 ] + [i ]
193207 if len (succs ) > 0 :
208+ if ((torch .max (preds [:, succs ], dim = 1 ).values > 0.5 ) & (preds [:, i ] < 0.5 )).any ():
209+ print (
210+ f"Correcting prediction for { label } to max of subclasses { list (self .chebi_graph .successors (int (label )))} "
211+ )
212+ print (
213+ f"Original pred: { preds [:, i ]} , successors: { preds [:, succs ]} "
214+ )
194215 preds [:, i ] = torch .max (preds [:, succs ], dim = 1 ).values
195- print (f"Preds change (step 1): { torch .sum (preds ) - preds_sum_orig } " )
216+ if torch .sum (preds ) != preds_sum_orig :
217+ print (f"Preds change (step 1): { torch .sum (preds ) - preds_sum_orig } " )
196218 preds_sum_orig = torch .sum (preds )
197219 # step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower)
198220 preds_bounded = torch .min (preds , torch .ones_like (preds ) * 0.49 )
199221 for disj_group in self .disjoint_groups :
200222 disj_group = [
201- self .label_names .index (g ) for g in disj_group if g in self .label_names
223+ self .label_names .index (str (g ))
224+ for g in disj_group
225+ if str (g ) in self .label_names
202226 ]
203227 if len (disj_group ) > 1 :
204228 old_preds = preds [:, disj_group ]
@@ -215,14 +239,12 @@ def __call__(self, preds):
215239 print (
216240 f"disjointness group { [self .label_names [d ] for d in disj_group ]} changed { samples_changed } samples"
217241 )
218- print (
219- f"Preds change after disjointness (step 2): { torch .sum (preds ) - preds_sum_orig } "
220- )
221242 preds_sum_orig = torch .sum (preds )
222243 # step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors
223244 for i , label in enumerate (self .label_names ):
224245 predecessors = [i ] + [
225- self .label_names .index (p ) for p in self .chebi_graph .predecessors (label )
246+ self .label_names .index (str (p ))
247+ for p in self .chebi_graph .predecessors (int (label ))
226248 ]
227249 lowest_predecessors = torch .min (preds [:, predecessors ], dim = 1 )
228250 preds [:, i ] = lowest_predecessors .values
0 commit comments