diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index acb4771e..18485269 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -7,7 +7,7 @@ import torch from chebai.loss.bce_weighted import BCEWeighted -from chebai.preprocessing.datasets import XYBaseDataModule +from chebai.preprocessing.datasets.base import XYBaseDataModule from chebai.preprocessing.datasets.chebi import ChEBIOver100, _ChEBIDataExtractor from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed diff --git a/chebai/result/analyse_sem.py b/chebai/result/analyse_sem.py index ccf120f8..19276a66 100644 --- a/chebai/result/analyse_sem.py +++ b/chebai/result/analyse_sem.py @@ -141,6 +141,8 @@ def get_chebi_graph(data_module, label_names): chebi_graph = data_module._extract_class_hierarchy( os.path.join(data_module.raw_dir, "chebi.obo") ) + if label_names is None: + return chebi_graph return chebi_graph.subgraph([int(n) for n in label_names]) print( f"Failed to retrieve ChEBI graph, {os.path.join(data_module.raw_dir, 'chebi.obo')} not found" @@ -196,29 +198,30 @@ class PredictionSmoother: """Removes implication and disjointness violations from predictions""" def __init__(self, dataset, label_names=None, disjoint_files=None): - if label_names: - self.label_names = label_names - else: - self.label_names = get_label_names(dataset) - self.chebi_graph = get_chebi_graph(dataset, self.label_names) + self.chebi_graph = get_chebi_graph(dataset, None) + self.set_label_names(label_names) self.disjoint_groups = get_disjoint_groups(disjoint_files) + def set_label_names(self, label_names): + if label_names is not None: + self.label_names = [int(label) for label in label_names] + chebi_subgraph = self.chebi_graph.subgraph(self.label_names) + self.label_successors = torch.zeros( + (len(self.label_names), len(self.label_names)), dtype=torch.bool + ) + for i, label in enumerate(self.label_names): + self.label_successors[i, i] = 1 + for p in chebi_subgraph.successors(label): + if p in self.label_names: + self.label_successors[i, self.label_names.index(p)] = 1 + self.label_successors = self.label_successors.unsqueeze(0) + def __call__(self, preds): preds_sum_orig = torch.sum(preds) - for i, label in enumerate(self.label_names): - succs = [ - self.label_names.index(str(p)) - for p in self.chebi_graph.successors(int(label)) - ] + [i] - if len(succs) > 0: - if torch.max(preds[:, succs], dim=1).values > 0.5 and preds[:, i] < 0.5: - print( - f"Correcting prediction for {label} to max of subclasses {list(self.chebi_graph.successors(int(label)))}" - ) - print( - f"Original pred: {preds[:, i]}, successors: {preds[:, succs]}" - ) - preds[:, i] = torch.max(preds[:, succs], dim=1).values + # step 1: apply implications: for each class, set prediction to max of itself and all successors + preds = preds.unsqueeze(1) + preds_masked_succ = torch.where(self.label_successors, preds, 0) + preds = preds_masked_succ.max(dim=2).values if torch.sum(preds) != preds_sum_orig: print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}") preds_sum_orig = torch.sum(preds) @@ -226,9 +229,7 @@ def __call__(self, preds): preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49) for disj_group in self.disjoint_groups: disj_group = [ - self.label_names.index(str(g)) - for g in disj_group - if g in self.label_names + self.label_names.index(g) for g in disj_group if g in self.label_names ] if len(disj_group) > 1: old_preds = preds[:, disj_group] @@ -245,26 +246,17 @@ def __call__(self, preds): print( f"disjointness group {[self.label_names[d] for d in disj_group]} changed {samples_changed} samples" ) + if torch.sum(preds) != preds_sum_orig: + print(f"Preds change (step 2): {torch.sum(preds) - preds_sum_orig}") preds_sum_orig = torch.sum(preds) # step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors - for i, label in enumerate(self.label_names): - predecessors = [i] + [ - self.label_names.index(str(p)) - for p in self.chebi_graph.predecessors(int(label)) - ] - lowest_predecessors = torch.min(preds[:, predecessors], dim=1) - preds[:, i] = lowest_predecessors.values - for idx_idx, idx in enumerate(lowest_predecessors.indices): - if idx > 0: - print( - f"class {label}: changed prediction of sample {idx_idx} to value of class " - f"{self.label_names[predecessors[idx]]} ({preds[idx_idx, i].item():.2f})" - ) - if torch.sum(preds) != preds_sum_orig: - print( - f"Preds change (step 3) for {label}: {torch.sum(preds) - preds_sum_orig}" - ) - preds_sum_orig = torch.sum(preds) + preds = preds.unsqueeze(1) + preds_masked_predec = torch.where( + torch.transpose(self.label_successors, 1, 2), preds, 1 + ) + preds = preds_masked_predec.min(dim=2).values + if torch.sum(preds) != preds_sum_orig: + print(f"Preds change (step 3): {torch.sum(preds) - preds_sum_orig}") return preds