From bda3ed7a998d888aeb8dd1d8a77e755a2dcfad94 Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Fri, 11 Jul 2025 16:06:36 +0200
Subject: [PATCH 1/3] remove print statement that fails for batch of size >1

---
 chebai/result/analyse_sem.py | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/chebai/result/analyse_sem.py b/chebai/result/analyse_sem.py
index 0aa5de31..1ac38966 100644
--- a/chebai/result/analyse_sem.py
+++ b/chebai/result/analyse_sem.py
@@ -205,13 +205,6 @@ def __call__(self, preds):
                 for p in self.chebi_graph.successors(int(label))
             ] + [i]
             if len(succs) > 0:
-                if torch.max(preds[:, succs], dim=1).values > 0.5 and preds[:, i] < 0.5:
-                    print(
-                        f"Correcting prediction for {label} to max of subclasses {list(self.chebi_graph.successors(int(label)))}"
-                    )
-                    print(
-                        f"Original pred: {preds[:, i]}, successors: {preds[:, succs]}"
-                    )
                 preds[:, i] = torch.max(preds[:, succs], dim=1).values
         if torch.sum(preds) != preds_sum_orig:
             print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")

From d58b09728c07e61e7e5eb2f4a7000ddfda611986 Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Fri, 11 Jul 2025 18:00:18 +0200
Subject: [PATCH 2/3] precalculate graph and successor relationship for more
 efficient call-time performance

---
 chebai/result/analyse_sem.py | 65 ++++++++++++++++++------------------
 setup.py                     |  2 ++
 2 files changed, 34 insertions(+), 33 deletions(-)

diff --git a/chebai/result/analyse_sem.py b/chebai/result/analyse_sem.py
index 1ac38966..3a079373 100644
--- a/chebai/result/analyse_sem.py
+++ b/chebai/result/analyse_sem.py
@@ -135,6 +135,8 @@ def get_chebi_graph(data_module, label_names):
         chebi_graph = data_module._extract_class_hierarchy(
             os.path.join(data_module.raw_dir, "chebi.obo")
         )
+        if label_names is None:
+            return chebi_graph
         return chebi_graph.subgraph([int(n) for n in label_names])
     print(
         f"Failed to retrieve ChEBI graph, {os.path.join(data_module.raw_dir, 'chebi.obo')} not found"
@@ -190,22 +192,30 @@ class PredictionSmoother:
     """Removes implication and disjointness violations from predictions"""
 
     def __init__(self, dataset, label_names=None, disjoint_files=None):
-        if label_names:
-            self.label_names = label_names
-        else:
-            self.label_names = get_label_names(dataset)
-        self.chebi_graph = get_chebi_graph(dataset, self.label_names)
+        self.chebi_graph = get_chebi_graph(dataset, None)
+        self.set_label_names(label_names)
         self.disjoint_groups = get_disjoint_groups(disjoint_files)
 
+    def set_label_names(self, label_names):
+        if label_names is not None:
+            self.label_names = [int(label) for label in label_names]
+            chebi_subgraph = self.chebi_graph.subgraph(self.label_names)
+            self.label_successors = torch.zeros(
+                (len(self.label_names), len(self.label_names)), dtype=torch.bool
+            )
+            for i, label in enumerate(self.label_names):
+                self.label_successors[i, i] = 1
+                for p in chebi_subgraph.successors(label):
+                    if p in self.label_names:
+                        self.label_successors[i, self.label_names.index(p)] = 1
+            self.label_successors = self.label_successors.unsqueeze(0)
+
     def __call__(self, preds):
         preds_sum_orig = torch.sum(preds)
-        for i, label in enumerate(self.label_names):
-            succs = [
-                self.label_names.index(str(p))
-                for p in self.chebi_graph.successors(int(label))
-            ] + [i]
-            if len(succs) > 0:
-                preds[:, i] = torch.max(preds[:, succs], dim=1).values
+        # step 1: apply implications: for each class, set prediction to max of itself and all successors
+        preds = preds.unsqueeze(1)
+        preds_masked_succ = torch.where(self.label_successors, preds, 0)
+        preds = preds_masked_succ.max(dim=2).values
         if torch.sum(preds) != preds_sum_orig:
             print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")
             preds_sum_orig = torch.sum(preds)
@@ -213,9 +223,7 @@ def __call__(self, preds):
         preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49)
         for disj_group in self.disjoint_groups:
             disj_group = [
-                self.label_names.index(str(g))
-                for g in disj_group
-                if g in self.label_names
+                self.label_names.index(g) for g in disj_group if g in self.label_names
             ]
             if len(disj_group) > 1:
                 old_preds = preds[:, disj_group]
@@ -232,26 +240,17 @@ def __call__(self, preds):
                 print(
                     f"disjointness group {[self.label_names[d] for d in disj_group]} changed {samples_changed} samples"
                 )
+        if torch.sum(preds) != preds_sum_orig:
+            print(f"Preds change (step 2): {torch.sum(preds) - preds_sum_orig}")
         preds_sum_orig = torch.sum(preds)
         # step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors
-        for i, label in enumerate(self.label_names):
-            predecessors = [i] + [
-                self.label_names.index(str(p))
-                for p in self.chebi_graph.predecessors(int(label))
-            ]
-            lowest_predecessors = torch.min(preds[:, predecessors], dim=1)
-            preds[:, i] = lowest_predecessors.values
-            for idx_idx, idx in enumerate(lowest_predecessors.indices):
-                if idx > 0:
-                    print(
-                        f"class {label}: changed prediction of sample {idx_idx} to value of class "
-                        f"{self.label_names[predecessors[idx]]} ({preds[idx_idx, i].item():.2f})"
-                    )
-            if torch.sum(preds) != preds_sum_orig:
-                print(
-                    f"Preds change (step 3) for {label}: {torch.sum(preds) - preds_sum_orig}"
-                )
-                preds_sum_orig = torch.sum(preds)
+        preds = preds.unsqueeze(1)
+        preds_masked_predec = torch.where(
+            torch.transpose(self.label_successors, 1, 2), preds, 1
+        )
+        preds = preds_masked_predec.min(dim=2).values
+        if torch.sum(preds) != preds_sum_orig:
+            print(f"Preds change (step 3): {torch.sum(preds) - preds_sum_orig}")
         return preds
 
 
diff --git a/setup.py b/setup.py
index 54a88780..230d519e 100644
--- a/setup.py
+++ b/setup.py
@@ -12,6 +12,8 @@
     license="",
     author="MGlauer",
     author_email="martin.glauer@ovgu.de",
+    maintainer="sfluegel05",
+    maintainer_email="simon.fluegel@uni-osnabrueck.de",
     description="",
     zip_safe=False,
     python_requires=">=3.9, <3.13",

From ff6b52a99b991c04a48bd6109e907508fc9b1327 Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Mon, 14 Jul 2025 17:37:18 +0200
Subject: [PATCH 3/3] fix import

---
 chebai/loss/semantic.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py
index acb4771e..18485269 100644
--- a/chebai/loss/semantic.py
+++ b/chebai/loss/semantic.py
@@ -7,7 +7,7 @@
 import torch
 
 from chebai.loss.bce_weighted import BCEWeighted
-from chebai.preprocessing.datasets import XYBaseDataModule
+from chebai.preprocessing.datasets.base import XYBaseDataModule
 from chebai.preprocessing.datasets.chebi import ChEBIOver100, _ChEBIDataExtractor
 from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
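
Note: the snippet below is a minimal, self-contained sketch of the precomputed successor mask that PATCH 2 introduces (PredictionSmoother.set_label_names plus the vectorized steps 1 and 3 in __call__). The three toy ChEBI IDs, the one-edge hierarchy and the prediction values are made up for illustration; the real code builds the graph from chebi.obo via get_chebi_graph and reuses self.label_successors.

import networkx as nx
import torch

# toy hierarchy: 200 is a superclass of 300 (edges point from class to subclass)
label_names = [100, 200, 300]
graph = nx.DiGraph([(200, 300)])

# precompute: label_successors[0, i, j] is True iff class j is class i itself or one of its subclasses
C = len(label_names)
label_successors = torch.zeros((C, C), dtype=torch.bool)
for i, label in enumerate(label_names):
    label_successors[i, i] = True
    if label in graph:
        for succ in graph.successors(label):
            if succ in label_names:
                label_successors[i, label_names.index(succ)] = True
label_successors = label_successors.unsqueeze(0)  # (1, C, C), broadcasts over the batch

preds = torch.tensor([[0.1, 0.2, 0.9]])  # batch of 1; subclass 300 scores higher than its superclass 200

# step 1: implication smoothing, each class becomes the max over itself and its subclasses
masked = torch.where(label_successors, preds.unsqueeze(1), 0.0)
preds = masked.max(dim=2).values  # -> [[0.1, 0.9, 0.9]]

# step 3: the transposed mask selects each class and its superclasses; take the min
masked = torch.where(label_successors.transpose(1, 2), preds.unsqueeze(1), 1.0)
preds = masked.min(dim=2).values  # -> [[0.1, 0.9, 0.9]]
print(preds)

Replacing the per-class Python loops with these two masked max/min reductions is what the commit message means by "more efficient call-time performance": the boolean mask is built once in set_label_names and then reused as two batched tensor operations per call.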