Skip to content

Commit 43c113a

Browse files
committed
add new inconsistency resolution strategies
1 parent 219b73f commit 43c113a

File tree

3 files changed

+61
-30
lines changed

3 files changed

+61
-30
lines changed

chebifier/ensemble/base_ensemble.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def __init__(
7979
self.chebi_graph,
8080
label_names=None,
8181
disjoint_files=self.disjoint_files,
82+
verbose=self.verbose_output,
8283
)
8384
else:
8485
self.smoother = None
@@ -191,14 +192,13 @@ def consolidate_predictions(
191192
def apply_inconsistency_resolution(
192193
self, net_score, class_names, has_valid_predictions
193194
):
194-
# todo - this could be more elegant
195195
# Smooth predictions
196196
start_time = time.perf_counter()
197197
if self.smoother is not None:
198198
self.smoother.set_label_names(class_names)
199199
smooth_net_score = self.smoother(net_score)
200200
class_decisions = (
201-
smooth_net_score > 0.5
201+
smooth_net_score > 0
202202
) & has_valid_predictions # Shape: (num_smiles, num_classes)
203203
else:
204204
class_decisions = (

chebifier/ensemble/weighted_majority_ensemble.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ def __init__(
99
self, config_path=None, weighting_strength=0.5, weighting_exponent=1.0, **kwargs
1010
):
1111
"""WMV ensemble that weights models based on their class-wise positive / negative predictive values. For each class, the weight is calculated as:
12-
weight = weighting_strength * PPV + (1 - weighting_strength)
12+
weight = (weighting_strength * PPV + (1 - weighting_strength)) ** weighting_exponent
1313
where PPV is the class-specific positive predictive value of the model on the validation set
1414
or (if the prediction is negative):
15-
weight = weighting_strength * NPV + (1 - weighting_strength)
15+
weight = (weighting_strength * NPV + (1 - weighting_strength)) ** weighting_exponent
1616
where NPV is the class-specific negative predictive value of the model on the validation set.
1717
"""
1818
super().__init__(config_path, **kwargs)
@@ -60,7 +60,7 @@ def __init__(
6060
self, config_path=None, weighting_strength=0.5, weighting_exponent=1.0, **kwargs
6161
):
6262
"""WMV ensemble that weights models based on their class-wise F1 scores. For each class, the weight is calculated as:
63-
weight = model_weight * (weighting_strength * F1 + (1 - weighting_strength))
63+
weight = model_weight * (weighting_strength * F1 + (1 - weighting_strength)) ** weighting_exponent
6464
where F1 is the class-specific F1 score ("trust") of the model on the validation set.
6565
"""
6666
super().__init__(config_path, **kwargs)

chebifier/inconsistency_resolution.py

Lines changed: 56 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,13 @@ def get_disjoint_groups(disjoint_files):
5656
class PredictionSmoother:
5757
"""Removes implication and disjointness violations from predictions"""
5858

59-
def __init__(self, chebi_graph, label_names=None, disjoint_files=None):
59+
def __init__(
60+
self, chebi_graph, label_names=None, disjoint_files=None, verbose=False
61+
):
6062
self.chebi_graph = chebi_graph
6163
self.set_label_names(label_names)
6264
self.disjoint_groups = get_disjoint_groups(disjoint_files)
65+
self.verbose = verbose
6366

6467
def set_label_names(self, label_names):
6568
if label_names is not None:
@@ -75,43 +78,26 @@ def set_label_names(self, label_names):
7578
self.label_successors[i, self.label_names.index(p)] = 1
7679
self.label_successors = self.label_successors.unsqueeze(0)
7780

78-
def __call__(self, preds):
79-
if preds.shape[1] == 0:
80-
# no labels predicted
81-
return preds
82-
# preds shape: (n_samples, n_labels)
83-
preds_sum_orig = torch.sum(preds)
84-
# step 1: apply implications: for each class, set prediction to max of itself and all successors
81+
def resolve_subsumption_violations(self, preds):
8582
preds = preds.unsqueeze(1)
8683
preds_masked_succ = torch.where(self.label_successors, preds, 0)
8784
# preds_masked_succ shape: (n_samples, n_labels, n_labels)
85+
return preds_masked_succ.max(dim=2).values
8886

89-
preds = preds_masked_succ.max(dim=2).values
90-
if torch.sum(preds) != preds_sum_orig:
91-
print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")
87+
def resolve_disjointness_violations(self, preds):
9288
preds_sum_orig = torch.sum(preds)
93-
# step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower)
94-
preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49)
89+
9590
for disj_group in self.disjoint_groups:
9691
disj_group = [
9792
self.label_names.index(g) for g in disj_group if g in self.label_names
9893
]
9994
if len(disj_group) > 1:
100-
old_preds = preds[:, disj_group]
10195
disj_max = torch.max(preds[:, disj_group], dim=1)
10296
for i, row in enumerate(preds):
10397
for l_ in range(len(preds[i])):
10498
if l_ in disj_group and l_ != disj_group[disj_max.indices[i]]:
105-
preds[i, l_] = preds_bounded[i, l_]
106-
samples_changed = 0
107-
for i, row in enumerate(preds[:, disj_group]):
108-
if any(r != o for r, o in zip(row, old_preds[i])):
109-
samples_changed += 1
110-
if samples_changed != 0:
111-
print(
112-
f"disjointness group {[self.label_names[d] for d in disj_group]} changed {samples_changed} samples"
113-
)
114-
if torch.sum(preds) != preds_sum_orig:
99+
preds[i, l_] = 0
100+
if self.verbose and torch.sum(preds) != preds_sum_orig:
115101
print(f"Preds change (step 2): {torch.sum(preds) - preds_sum_orig}")
116102
preds_sum_orig = torch.sum(preds)
117103
# step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors
@@ -120,6 +106,51 @@ def __call__(self, preds):
120106
torch.transpose(self.label_successors, 1, 2), preds, 1
121107
)
122108
preds = preds_masked_predec.min(dim=2).values
123-
if torch.sum(preds) != preds_sum_orig:
109+
if self.verbose and torch.sum(preds) != preds_sum_orig:
124110
print(f"Preds change (step 3): {torch.sum(preds) - preds_sum_orig}")
125111
return preds
112+
113+
def __call__(self, preds):
114+
if preds.shape[1] == 0:
115+
# no labels predicted
116+
return preds
117+
# preds shape: (n_samples, n_labels)
118+
preds_sum_orig = torch.sum(preds)
119+
# step 1: apply implications: for each class, set prediction to max of itself and all successors
120+
preds = self.resolve_subsumption_violations(preds)
121+
122+
if self.verbose and torch.sum(preds) != preds_sum_orig:
123+
print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")
124+
# step 2: eliminate disjointness violations: for each group of disjoint classes, set all predictions except the maximum to 0
125+
preds = self.resolve_disjointness_violations(preds)
126+
return preds
127+
128+
129+
class PessimisticPredictionSmoother(PredictionSmoother):
130+
"""Always assumes the positive prediction is wrong (in case of implication violations)"""
131+
132+
def resolve_subsumption_violations(self, preds):
133+
preds = preds.unsqueeze(1)
134+
preds_masked_predec = torch.where(
135+
torch.transpose(self.label_successors, 1, 2), preds, 1
136+
)
137+
preds = preds_masked_predec.min(dim=2).values
138+
return preds
139+
140+
141+
class ScoreBasedPredictionSmoother(PredictionSmoother):
142+
"""Removes implication violations from predictions based on net scores: for A subclassOf B where score(A) > score(B), either set score(B) = max(score(B), score(A))
143+
if abs(score(A)) > abs(score(B)), or otherwise set score(A) = min(score(A), score(B)).
144+
"""
145+
146+
def resolve_subsumption_violations(self, preds):
147+
preds = preds.unsqueeze(1)
148+
preds_masked_succ = torch.where(self.label_successors, preds, 0)
149+
preds_optimistic = preds_masked_succ.max(dim=2).values
150+
preds_masked_predec = torch.where(
151+
torch.transpose(self.label_successors, 1, 2), preds, 1
152+
)
153+
preds_pessimistic = preds_masked_predec.min(dim=2).values
154+
# take the one with the higher absolute value — NOTE(review): `preds_optimistic - preds_pessimistic > 0` compares signed values, not absolute values; since preds_optimistic >= preds_pessimistic by construction, this picks the optimistic value whenever the two differ — confirm intended
155+
preds_direction = preds_optimistic - preds_pessimistic > 0
156+
return torch.where(preds_direction, preds_optimistic, preds_pessimistic)

0 commit comments

Comments
 (0)