@@ -56,10 +56,13 @@ def get_disjoint_groups(disjoint_files):
class PredictionSmoother:
    """Removes implication and disjointness violations from predictions"""

    def __init__(
        self, chebi_graph, label_names=None, disjoint_files=None, verbose=False
    ):
        # Ontology graph; presumably consumed by set_label_names to derive the
        # subsumption (successor) relation between labels — verify against caller.
        self.chebi_graph = chebi_graph
        # Builds self.label_names and the self.label_successors mask used by the
        # smoothing steps; must run after chebi_graph is assigned.
        self.set_label_names(label_names)
        # Groups of mutually disjoint classes parsed from the given files.
        self.disjoint_groups = get_disjoint_groups(disjoint_files)
        # When True, each smoothing step prints how much it changed the predictions.
        self.verbose = verbose
6467 def set_label_names (self , label_names ):
6568 if label_names is not None :
@@ -75,43 +78,26 @@ def set_label_names(self, label_names):
7578 self .label_successors [i , self .label_names .index (p )] = 1
7679 self .label_successors = self .label_successors .unsqueeze (0 )
7780
78- def __call__ (self , preds ):
79- if preds .shape [1 ] == 0 :
80- # no labels predicted
81- return preds
82- # preds shape: (n_samples, n_labels)
83- preds_sum_orig = torch .sum (preds )
84- # step 1: apply implications: for each class, set prediction to max of itself and all successors
81+ def resolve_subsumption_violations (self , preds ):
8582 preds = preds .unsqueeze (1 )
8683 preds_masked_succ = torch .where (self .label_successors , preds , 0 )
8784 # preds_masked_succ shape: (n_samples, n_labels, n_labels)
85+ return preds_masked_succ .max (dim = 2 ).values
8886
89- preds = preds_masked_succ .max (dim = 2 ).values
90- if torch .sum (preds ) != preds_sum_orig :
91- print (f"Preds change (step 1): { torch .sum (preds ) - preds_sum_orig } " )
87+ def resolve_disjointness_violations (self , preds ):
9288 preds_sum_orig = torch .sum (preds )
93- # step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower)
94- preds_bounded = torch .min (preds , torch .ones_like (preds ) * 0.49 )
89+
9590 for disj_group in self .disjoint_groups :
9691 disj_group = [
9792 self .label_names .index (g ) for g in disj_group if g in self .label_names
9893 ]
9994 if len (disj_group ) > 1 :
100- old_preds = preds [:, disj_group ]
10195 disj_max = torch .max (preds [:, disj_group ], dim = 1 )
10296 for i , row in enumerate (preds ):
10397 for l_ in range (len (preds [i ])):
10498 if l_ in disj_group and l_ != disj_group [disj_max .indices [i ]]:
105- preds [i , l_ ] = preds_bounded [i , l_ ]
106- samples_changed = 0
107- for i , row in enumerate (preds [:, disj_group ]):
108- if any (r != o for r , o in zip (row , old_preds [i ])):
109- samples_changed += 1
110- if samples_changed != 0 :
111- print (
112- f"disjointness group { [self .label_names [d ] for d in disj_group ]} changed { samples_changed } samples"
113- )
114- if torch .sum (preds ) != preds_sum_orig :
99+ preds [i , l_ ] = 0
100+ if self .verbose and torch .sum (preds ) != preds_sum_orig :
115101 print (f"Preds change (step 2): { torch .sum (preds ) - preds_sum_orig } " )
116102 preds_sum_orig = torch .sum (preds )
117103 # step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors
@@ -120,6 +106,51 @@ def __call__(self, preds):
120106 torch .transpose (self .label_successors , 1 , 2 ), preds , 1
121107 )
122108 preds = preds_masked_predec .min (dim = 2 ).values
123- if torch .sum (preds ) != preds_sum_orig :
109+ if self . verbose and torch .sum (preds ) != preds_sum_orig :
124110 print (f"Preds change (step 3): { torch .sum (preds ) - preds_sum_orig } " )
125111 return preds
112+
113+ def __call__ (self , preds ):
114+ if preds .shape [1 ] == 0 :
115+ # no labels predicted
116+ return preds
117+ # preds shape: (n_samples, n_labels)
118+ preds_sum_orig = torch .sum (preds )
119+ # step 1: apply implications: for each class, set prediction to max of itself and all successors
120+ preds = self .resolve_subsumption_violations (preds )
121+
122+ if self .verbose and torch .sum (preds ) != preds_sum_orig :
123+ print (f"Preds change (step 1): { torch .sum (preds ) - preds_sum_orig } " )
124+ # step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower)
125+ preds = self .resolve_disjointness_violations (preds )
126+ return preds
127+
128+
class PessimisticPredictionSmoother(PredictionSmoother):
    """Always assumes the positive prediction is wrong (in case of implication violations)"""

    def resolve_subsumption_violations(self, preds):
        """Implication smoothing that lowers each class to the minimum over
        its predecessors instead of raising it to the maximum over its
        successors (the base-class behavior)."""
        # Broadcast preds against the transposed successor mask, i.e. the
        # predecessor relation; non-predecessor entries are filled with 1 so
        # they never win the min.
        expanded = preds.unsqueeze(1)
        predecessor_mask = torch.transpose(self.label_successors, 1, 2)
        predecessor_scores = torch.where(predecessor_mask, expanded, 1)
        return predecessor_scores.min(dim=2).values
139+
140+
class ScoreBasedPredictionSmoother(PredictionSmoother):
    """Removes implication violations from predictions based on net scores: for A subclassOf B where score(A) > score(B), either set score(B) = max(score(B), score(A))
    if abs(score(A)) > abs(score(B)) or set score(A) = min(score(A), score(B)) otherwise.
    """

    def resolve_subsumption_violations(self, preds):
        # (n_samples, n_labels) -> (n_samples, 1, n_labels) for broadcasting.
        preds = preds.unsqueeze(1)
        # Optimistic repair: raise each class to the max over its successors.
        preds_masked_succ = torch.where(self.label_successors, preds, 0)
        preds_optimistic = preds_masked_succ.max(dim=2).values
        # Pessimistic repair: lower each class to the min over its predecessors.
        preds_masked_predec = torch.where(
            torch.transpose(self.label_successors, 1, 2), preds, 1
        )
        preds_pessimistic = preds_masked_predec.min(dim=2).values
        # Take the repair with the higher absolute value, as the class
        # docstring describes.
        # BUG FIX: the previous comparison `preds_optimistic - preds_pessimistic > 0`
        # was true whenever the two repairs differed at all (assuming each class
        # counts as its own successor, optimistic >= original >= pessimistic
        # elementwise), so the pessimistic repair was never selected and this
        # class behaved identically to the base class.
        preds_direction = preds_optimistic.abs() > preds_pessimistic.abs()
        return torch.where(preds_direction, preds_optimistic, preds_pessimistic)
0 commit comments