11import gc
2- import sys
32import traceback
43from datetime import datetime
54from typing import List , LiteralString
65
6+ import pandas as pd
77from torchmetrics .functional .classification import (
88 multilabel_auroc ,
99 multilabel_average_precision ,
1010 multilabel_f1_score ,
1111)
12- from utils import *
1312
1413from chebai .loss .semantic import DisjointLoss
14+ from chebai .models import Electra
1515from chebai .preprocessing .datasets .base import _DynamicDataset
1616from chebai .preprocessing .datasets .chebi import ChEBIOver100
1717from chebai .preprocessing .datasets .pubchem import PubChemKMeans
18+ from chebai .result .utils import *
1819
1920DEVICE = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
2021
@@ -122,7 +123,7 @@ def load_preds_labels(
122123def get_label_names (data_module ):
123124 if os .path .exists (os .path .join (data_module .processed_dir_main , "classes.txt" )):
124125 with open (os .path .join (data_module .processed_dir_main , "classes.txt" )) as fin :
125- return [int ( line .strip () ) for line in fin ]
126+ return [line .strip () for line in fin ]
126127 print (
127128 f"Failed to retrieve label names, { os .path .join (data_module .processed_dir_main , 'classes.txt' )} not found"
128129 )
@@ -131,70 +132,97 @@ def get_label_names(data_module):
131132
def get_chebi_graph(data_module, label_names):
    """Load the ChEBI class hierarchy restricted to the given labels.

    Args:
        data_module: Object providing ``raw_dir`` (the directory expected to
            contain ``chebi.obo``) and ``_extract_class_hierarchy(path)``.
        label_names: Iterable of ChEBI identifiers, as ints or numeric
            strings. They are converted to ``int`` before the lookup because
            the hierarchy is queried with integer node ids.

    Returns:
        The result of ``subgraph(...)`` on the extracted hierarchy, or
        ``None`` (after printing a message) if ``chebi.obo`` is missing or no
        label names are available.
    """
    chebi_path = os.path.join(data_module.raw_dir, "chebi.obo")
    if not os.path.exists(chebi_path):
        print(f"Failed to retrieve ChEBI graph, {chebi_path} not found")
        return None
    if label_names is None:
        # Label loading failed upstream; without labels there is nothing to
        # restrict the hierarchy to, so report it instead of raising TypeError.
        print("Failed to retrieve ChEBI graph, no label names available")
        return None
    chebi_graph = data_module._extract_class_hierarchy(chebi_path)
    return chebi_graph.subgraph([int(n) for n in label_names])
142143
143144
def _parse_disjoint_owl(plaintext):
    """Extract disjoint pairs and disjoint groups from OWL/XML text."""
    pairs = []
    left = None
    for seg in plaintext.split("<"):
        if seg.startswith(("rdf:Description ", "owl:Class")):
            left = int(seg.split('rdf:about="&obo;CHEBI_')[1].split('"')[0])
        elif seg.startswith("owl:disjointWith"):
            right = int(seg.split('rdf:resource="&obo;CHEBI_')[1].split('"')[0])
            pairs.append([left, right])

    groups = []
    for seg in plaintext.split("<rdf:Description>"):
        if "owl;AllDisjointClasses" in seg:
            classes = seg.split('rdf:about="&obo;CHEBI_')[1:]
            groups.append([int(c.split('"')[0]) for c in classes])
    return pairs, groups


def get_disjoint_groups(disjoint_files=None):
    """Collect disjointness axioms between ChEBI classes from one or more files.

    Args:
        disjoint_files: A path or an iterable of paths. ``.csv`` files are
            read without a header, each row being one group of disjoint
            class ids. ``.owl`` files are scanned textually for
            ``owl:disjointWith`` pairs and ``AllDisjointClasses`` groups.
            Defaults to ``data/chebi-disjoints.owl``.

    Returns:
        A list of lists of ChEBI ids: all pairs (in file order) followed by
        all groups (in file order).

    Raises:
        NotImplementedError: If a file has an extension other than ``csv``
            or ``owl``.
    """
    if disjoint_files is None:
        disjoint_files = [os.path.join("data", "chebi-disjoints.owl")]
    elif isinstance(disjoint_files, (str, os.PathLike)):
        # A bare path must be wrapped, otherwise the loop below would
        # iterate over its individual characters.
        disjoint_files = [disjoint_files]

    disjoint_pairs, disjoint_groups = [], []
    for file in disjoint_files:
        file = os.fspath(file)
        extension = file.split(".")[-1]
        if extension == "csv":
            disjoint_pairs += pd.read_csv(file, header=None).values.tolist()
        elif extension == "owl":
            with open(file, "r", encoding="utf-8") as f:
                plaintext = f.read()
            # Extend the accumulators rather than rebinding them, so results
            # from files processed earlier are kept.
            owl_pairs, owl_groups = _parse_disjoint_owl(plaintext)
            disjoint_pairs += owl_pairs
            disjoint_groups += owl_groups
        else:
            raise NotImplementedError(
                "Unsupported disjoint file format: " + extension
            )

    disjoint_all = disjoint_pairs + disjoint_groups
    # One disjointness axiom is commented out in the owl file. The textual
    # scan above does not recognise XML comments, so drop that pair by hand.
    if [22729, 51880] in disjoint_all:
        disjoint_all.remove([22729, 51880])
    return disjoint_all
170187
171188
172189class PredictionSmoother :
173190 """Removes implication and disjointness violations from predictions"""
174191
175- def __init__ (self , dataset ):
176- self .label_names = get_label_names (dataset )
192+ def __init__ (self , dataset , label_names = None , disjoint_files = None ):
193+ if label_names :
194+ self .label_names = label_names
195+ else :
196+ self .label_names = get_label_names (dataset )
177197 self .chebi_graph = get_chebi_graph (dataset , self .label_names )
178- self .disjoint_groups = get_disjoint_groups ()
198+ self .disjoint_groups = get_disjoint_groups (disjoint_files )
179199
180200 def __call__ (self , preds ):
181-
182201 preds_sum_orig = torch .sum (preds )
183- print (f"Preds sum: { preds_sum_orig } " )
184- # eliminate implication violations by setting each prediction to maximum of its successors
185202 for i , label in enumerate (self .label_names ):
186203 succs = [
187- self .label_names .index (p ) for p in self .chebi_graph .successors (label )
204+ self .label_names .index (str (p ))
205+ for p in self .chebi_graph .successors (int (label ))
188206 ] + [i ]
189207 if len (succs ) > 0 :
208+ if torch .max (preds [:, succs ], dim = 1 ).values > 0.5 and preds [:, i ] < 0.5 :
209+ print (
210+ f"Correcting prediction for { label } to max of subclasses { list (self .chebi_graph .successors (int (label )))} "
211+ )
212+ print (
213+ f"Original pred: { preds [:, i ]} , successors: { preds [:, succs ]} "
214+ )
190215 preds [:, i ] = torch .max (preds [:, succs ], dim = 1 ).values
191- print (f"Preds change (step 1): { torch .sum (preds ) - preds_sum_orig } " )
216+ if torch .sum (preds ) != preds_sum_orig :
217+ print (f"Preds change (step 1): { torch .sum (preds ) - preds_sum_orig } " )
192218 preds_sum_orig = torch .sum (preds )
193219 # step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower)
194220 preds_bounded = torch .min (preds , torch .ones_like (preds ) * 0.49 )
195221 for disj_group in self .disjoint_groups :
196222 disj_group = [
197- self .label_names .index (g ) for g in disj_group if g in self .label_names
223+ self .label_names .index (str (g ))
224+ for g in disj_group
225+ if g in self .label_names
198226 ]
199227 if len (disj_group ) > 1 :
200228 old_preds = preds [:, disj_group ]
@@ -211,14 +239,12 @@ def __call__(self, preds):
211239 print (
212240 f"disjointness group { [self .label_names [d ] for d in disj_group ]} changed { samples_changed } samples"
213241 )
214- print (
215- f"Preds change after disjointness (step 2): { torch .sum (preds ) - preds_sum_orig } "
216- )
217242 preds_sum_orig = torch .sum (preds )
218243 # step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors
219244 for i , label in enumerate (self .label_names ):
220245 predecessors = [i ] + [
221- self .label_names .index (p ) for p in self .chebi_graph .predecessors (label )
246+ self .label_names .index (str (p ))
247+ for p in self .chebi_graph .predecessors (int (label ))
222248 ]
223249 lowest_predecessors = torch .min (preds [:, predecessors ], dim = 1 )
224250 preds [:, i ] = lowest_predecessors .values
0 commit comments