Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chebai/loss/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch

from chebai.loss.bce_weighted import BCEWeighted
from chebai.preprocessing.datasets import XYBaseDataModule
from chebai.preprocessing.datasets.base import XYBaseDataModule
from chebai.preprocessing.datasets.chebi import ChEBIOver100, _ChEBIDataExtractor
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed

Expand Down
72 changes: 32 additions & 40 deletions chebai/result/analyse_sem.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def get_chebi_graph(data_module, label_names):
chebi_graph = data_module._extract_class_hierarchy(
os.path.join(data_module.raw_dir, "chebi.obo")
)
if label_names is None:
return chebi_graph
return chebi_graph.subgraph([int(n) for n in label_names])
print(
f"Failed to retrieve ChEBI graph, {os.path.join(data_module.raw_dir, 'chebi.obo')} not found"
Expand Down Expand Up @@ -196,39 +198,38 @@ class PredictionSmoother:
"""Removes implication and disjointness violations from predictions"""

def __init__(self, dataset, label_names=None, disjoint_files=None):
if label_names:
self.label_names = label_names
else:
self.label_names = get_label_names(dataset)
self.chebi_graph = get_chebi_graph(dataset, self.label_names)
self.chebi_graph = get_chebi_graph(dataset, None)
self.set_label_names(label_names)
self.disjoint_groups = get_disjoint_groups(disjoint_files)

def set_label_names(self, label_names):
if label_names is not None:
self.label_names = [int(label) for label in label_names]
chebi_subgraph = self.chebi_graph.subgraph(self.label_names)
self.label_successors = torch.zeros(
(len(self.label_names), len(self.label_names)), dtype=torch.bool
)
for i, label in enumerate(self.label_names):
self.label_successors[i, i] = 1
for p in chebi_subgraph.successors(label):
if p in self.label_names:
self.label_successors[i, self.label_names.index(p)] = 1
self.label_successors = self.label_successors.unsqueeze(0)

def __call__(self, preds):
preds_sum_orig = torch.sum(preds)
for i, label in enumerate(self.label_names):
succs = [
self.label_names.index(str(p))
for p in self.chebi_graph.successors(int(label))
] + [i]
if len(succs) > 0:
if torch.max(preds[:, succs], dim=1).values > 0.5 and preds[:, i] < 0.5:
print(
f"Correcting prediction for {label} to max of subclasses {list(self.chebi_graph.successors(int(label)))}"
)
print(
f"Original pred: {preds[:, i]}, successors: {preds[:, succs]}"
)
preds[:, i] = torch.max(preds[:, succs], dim=1).values
# step 1: apply implications: for each class, set prediction to max of itself and all successors
preds = preds.unsqueeze(1)
preds_masked_succ = torch.where(self.label_successors, preds, 0)
preds = preds_masked_succ.max(dim=2).values
if torch.sum(preds) != preds_sum_orig:
print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")
preds_sum_orig = torch.sum(preds)
# step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower)
preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49)
for disj_group in self.disjoint_groups:
disj_group = [
self.label_names.index(str(g))
for g in disj_group
if g in self.label_names
self.label_names.index(g) for g in disj_group if g in self.label_names
]
if len(disj_group) > 1:
old_preds = preds[:, disj_group]
Expand All @@ -245,26 +246,17 @@ def __call__(self, preds):
print(
f"disjointness group {[self.label_names[d] for d in disj_group]} changed {samples_changed} samples"
)
if torch.sum(preds) != preds_sum_orig:
print(f"Preds change (step 2): {torch.sum(preds) - preds_sum_orig}")
preds_sum_orig = torch.sum(preds)
# step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors
for i, label in enumerate(self.label_names):
predecessors = [i] + [
self.label_names.index(str(p))
for p in self.chebi_graph.predecessors(int(label))
]
lowest_predecessors = torch.min(preds[:, predecessors], dim=1)
preds[:, i] = lowest_predecessors.values
for idx_idx, idx in enumerate(lowest_predecessors.indices):
if idx > 0:
print(
f"class {label}: changed prediction of sample {idx_idx} to value of class "
f"{self.label_names[predecessors[idx]]} ({preds[idx_idx, i].item():.2f})"
)
if torch.sum(preds) != preds_sum_orig:
print(
f"Preds change (step 3) for {label}: {torch.sum(preds) - preds_sum_orig}"
)
preds_sum_orig = torch.sum(preds)
preds = preds.unsqueeze(1)
preds_masked_predec = torch.where(
torch.transpose(self.label_successors, 1, 2), preds, 1
)
preds = preds_masked_predec.min(dim=2).values
if torch.sum(preds) != preds_sum_orig:
print(f"Preds change (step 3): {torch.sum(preds) - preds_sum_orig}")
return preds


Expand Down