@@ -141,6 +141,8 @@ def get_chebi_graph(data_module, label_names):
141141 chebi_graph = data_module ._extract_class_hierarchy (
142142 os .path .join (data_module .raw_dir , "chebi.obo" )
143143 )
144+ if label_names is None :
145+ return chebi_graph
144146 return chebi_graph .subgraph ([int (n ) for n in label_names ])
145147 print (
146148 f"Failed to retrieve ChEBI graph, { os .path .join (data_module .raw_dir , 'chebi.obo' )} not found"
@@ -196,39 +198,38 @@ class PredictionSmoother:
196198 """Removes implication and disjointness violations from predictions"""
197199
198200 def __init__ (self , dataset , label_names = None , disjoint_files = None ):
199- if label_names :
200- self .label_names = label_names
201- else :
202- self .label_names = get_label_names (dataset )
203- self .chebi_graph = get_chebi_graph (dataset , self .label_names )
201+ self .chebi_graph = get_chebi_graph (dataset , None )
202+ self .set_label_names (label_names )
204203 self .disjoint_groups = get_disjoint_groups (disjoint_files )
205204
205+ def set_label_names (self , label_names ):
206+ if label_names is not None :
207+ self .label_names = [int (label ) for label in label_names ]
208+ chebi_subgraph = self .chebi_graph .subgraph (self .label_names )
209+ self .label_successors = torch .zeros (
210+ (len (self .label_names ), len (self .label_names )), dtype = torch .bool
211+ )
212+ for i , label in enumerate (self .label_names ):
213+ self .label_successors [i , i ] = 1
214+ for p in chebi_subgraph .successors (label ):
215+ if p in self .label_names :
216+ self .label_successors [i , self .label_names .index (p )] = 1
217+ self .label_successors = self .label_successors .unsqueeze (0 )
218+
206219 def __call__ (self , preds ):
207220 preds_sum_orig = torch .sum (preds )
208- for i , label in enumerate (self .label_names ):
209- succs = [
210- self .label_names .index (str (p ))
211- for p in self .chebi_graph .successors (int (label ))
212- ] + [i ]
213- if len (succs ) > 0 :
214- if torch .max (preds [:, succs ], dim = 1 ).values > 0.5 and preds [:, i ] < 0.5 :
215- print (
216- f"Correcting prediction for { label } to max of subclasses { list (self .chebi_graph .successors (int (label )))} "
217- )
218- print (
219- f"Original pred: { preds [:, i ]} , successors: { preds [:, succs ]} "
220- )
221- preds [:, i ] = torch .max (preds [:, succs ], dim = 1 ).values
221+ # step 1: apply implications: for each class, set prediction to max of itself and all successors
222+ preds = preds .unsqueeze (1 )
223+ preds_masked_succ = torch .where (self .label_successors , preds , 0 )
224+ preds = preds_masked_succ .max (dim = 2 ).values
222225 if torch .sum (preds ) != preds_sum_orig :
223226 print (f"Preds change (step 1): { torch .sum (preds ) - preds_sum_orig } " )
224227 preds_sum_orig = torch .sum (preds )
225228 # step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower)
226229 preds_bounded = torch .min (preds , torch .ones_like (preds ) * 0.49 )
227230 for disj_group in self .disjoint_groups :
228231 disj_group = [
229- self .label_names .index (str (g ))
230- for g in disj_group
231- if g in self .label_names
232+ self .label_names .index (g ) for g in disj_group if g in self .label_names
232233 ]
233234 if len (disj_group ) > 1 :
234235 old_preds = preds [:, disj_group ]
@@ -245,26 +246,17 @@ def __call__(self, preds):
245246 print (
246247 f"disjointness group { [self .label_names [d ] for d in disj_group ]} changed { samples_changed } samples"
247248 )
249+ if torch .sum (preds ) != preds_sum_orig :
250+ print (f"Preds change (step 2): { torch .sum (preds ) - preds_sum_orig } " )
248251 preds_sum_orig = torch .sum (preds )
249252 # step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors
250- for i , label in enumerate (self .label_names ):
251- predecessors = [i ] + [
252- self .label_names .index (str (p ))
253- for p in self .chebi_graph .predecessors (int (label ))
254- ]
255- lowest_predecessors = torch .min (preds [:, predecessors ], dim = 1 )
256- preds [:, i ] = lowest_predecessors .values
257- for idx_idx , idx in enumerate (lowest_predecessors .indices ):
258- if idx > 0 :
259- print (
260- f"class { label } : changed prediction of sample { idx_idx } to value of class "
261- f"{ self .label_names [predecessors [idx ]]} ({ preds [idx_idx , i ].item ():.2f} )"
262- )
263- if torch .sum (preds ) != preds_sum_orig :
264- print (
265- f"Preds change (step 3) for { label } : { torch .sum (preds ) - preds_sum_orig } "
266- )
267- preds_sum_orig = torch .sum (preds )
253+ preds = preds .unsqueeze (1 )
254+ preds_masked_predec = torch .where (
255+ torch .transpose (self .label_successors , 1 , 2 ), preds , 1
256+ )
257+ preds = preds_masked_predec .min (dim = 2 ).values
258+ if torch .sum (preds ) != preds_sum_orig :
259+ print (f"Preds change (step 3): { torch .sum (preds ) - preds_sum_orig } " )
268260 return preds
269261
270262
0 commit comments