Skip to content

Commit 71521df

Browse files
committed
update inconsistency removal for ensemble
1 parent 5e0b683 commit 71521df

File tree

1 file changed

+66
-40
lines changed

1 file changed

+66
-40
lines changed

chebai/result/analyse_sem.py

Lines changed: 66 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
11
import gc
2-
import sys
32
import traceback
43
from datetime import datetime
54
from typing import List, LiteralString
65

6+
import pandas as pd
77
from torchmetrics.functional.classification import (
88
multilabel_auroc,
99
multilabel_average_precision,
1010
multilabel_f1_score,
1111
)
12-
from utils import *
1312

1413
from chebai.loss.semantic import DisjointLoss
14+
from chebai.models import Electra
1515
from chebai.preprocessing.datasets.base import _DynamicDataset
1616
from chebai.preprocessing.datasets.chebi import ChEBIOver100
1717
from chebai.preprocessing.datasets.pubchem import PubChemKMeans
18+
from chebai.result.utils import *
1819

1920
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
2021

@@ -122,7 +123,7 @@ def load_preds_labels(
122123
def get_label_names(data_module):
123124
if os.path.exists(os.path.join(data_module.processed_dir_main, "classes.txt")):
124125
with open(os.path.join(data_module.processed_dir_main, "classes.txt")) as fin:
125-
return [int(line.strip()) for line in fin]
126+
return [line.strip() for line in fin]
126127
print(
127128
f"Failed to retrieve label names, {os.path.join(data_module.processed_dir_main, 'classes.txt')} not found"
128129
)
@@ -131,70 +132,97 @@ def get_label_names(data_module):
131132

132133
def get_chebi_graph(data_module, label_names):
    """Return the ChEBI class hierarchy restricted to the given labels.

    Parses the hierarchy from the raw ``chebi.obo`` file of ``data_module``
    and returns the induced subgraph over ``label_names`` (each name is
    interpreted as an integer ChEBI id). If the obo file is missing, prints
    a notice and returns ``None``.
    """
    obo_path = os.path.join(data_module.raw_dir, "chebi.obo")
    if not os.path.exists(obo_path):
        print(
            f"Failed to retrieve ChEBI graph, {os.path.join(data_module.raw_dir, 'chebi.obo')} not found"
        )
        return None
    full_graph = data_module._extract_class_hierarchy(obo_path)
    node_ids = [int(n) for n in label_names]
    return full_graph.subgraph(node_ids)
142143

143144

144-
def get_disjoint_groups(disjoint_files=None):
    """Load groups of mutually disjoint ChEBI classes.

    Args:
        disjoint_files: Iterable of file paths. Each path is either a
            headerless csv whose rows are pairs of disjoint class ids, or a
            ``chebi-disjoints.owl`` file. Defaults to
            ``["data/chebi-disjoints.owl"]``.

    Returns:
        A list of lists of integer ChEBI ids; each inner list is one
        disjointness group (pairs, plus n-ary ``AllDisjointClasses`` groups
        from owl files).

    Raises:
        NotImplementedError: If a file has an unsupported extension.
    """
    if disjoint_files is None:
        # Must be a *list*: iterating the bare path string would iterate
        # its characters one by one.
        disjoint_files = [os.path.join("data", "chebi-disjoints.owl")]
    disjoint_pairs, disjoint_groups = [], []
    for file in disjoint_files:
        ext = file.split(".")[-1]
        if ext == "csv":
            disjoint_pairs += pd.read_csv(file, header=None).values.tolist()
        elif ext == "owl":
            with open(file, "r") as f:
                plaintext = f.read()
            # Pairwise disjointness: each owl:disjointWith tag relates the
            # most recently seen class declaration (left) to the referenced
            # class (right).
            left = None
            for seg in plaintext.split("<"):
                if seg.startswith("rdf:Description ") or seg.startswith(
                    "owl:Class"
                ):
                    left = int(seg.split('rdf:about="&obo;CHEBI_')[1].split('"')[0])
                elif seg.startswith("owl:disjointWith"):
                    right = int(
                        seg.split('rdf:resource="&obo;CHEBI_')[1].split('"')[0]
                    )
                    # Append (do NOT reset the list here) so that pairs read
                    # from earlier files in the same call are preserved.
                    disjoint_pairs.append([left, right])
            # n-ary disjointness groups declared via AllDisjointClasses.
            for seg in plaintext.split("<rdf:Description>"):
                if "owl;AllDisjointClasses" in seg:
                    classes = seg.split('rdf:about="&obo;CHEBI_')[1:]
                    disjoint_groups.append([int(c.split('"')[0]) for c in classes])
        else:
            raise NotImplementedError(
                "Unsupported disjoint file format: " + file.split(".")[-1]
            )

    disjoint_all = disjoint_pairs + disjoint_groups
    # one disjointness is commented out in the owl-file
    # (the correct way would be to parse the owl file and notice the comment symbols, but for this case, it should work)
    if [22729, 51880] in disjoint_all:
        disjoint_all.remove([22729, 51880])
    return disjoint_all
170187

171188

172189
class PredictionSmoother:
173190
"""Removes implication and disjointness violations from predictions"""
174191

175-
def __init__(self, dataset):
176-
self.label_names = get_label_names(dataset)
192+
def __init__(self, dataset, label_names=None, disjoint_files=None):
193+
if label_names:
194+
self.label_names = label_names
195+
else:
196+
self.label_names = get_label_names(dataset)
177197
self.chebi_graph = get_chebi_graph(dataset, self.label_names)
178-
self.disjoint_groups = get_disjoint_groups()
198+
self.disjoint_groups = get_disjoint_groups(disjoint_files)
179199

180200
def __call__(self, preds):
181-
182201
preds_sum_orig = torch.sum(preds)
183-
print(f"Preds sum: {preds_sum_orig}")
184-
# eliminate implication violations by setting each prediction to maximum of its successors
185202
for i, label in enumerate(self.label_names):
186203
succs = [
187-
self.label_names.index(p) for p in self.chebi_graph.successors(label)
204+
self.label_names.index(str(p))
205+
for p in self.chebi_graph.successors(int(label))
188206
] + [i]
189207
if len(succs) > 0:
208+
if torch.max(preds[:, succs], dim=1).values > 0.5 and preds[:, i] < 0.5:
209+
print(
210+
f"Correcting prediction for {label} to max of subclasses {list(self.chebi_graph.successors(int(label)))}"
211+
)
212+
print(
213+
f"Original pred: {preds[:, i]}, successors: {preds[:, succs]}"
214+
)
190215
preds[:, i] = torch.max(preds[:, succs], dim=1).values
191-
print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")
216+
if torch.sum(preds) != preds_sum_orig:
217+
print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")
192218
preds_sum_orig = torch.sum(preds)
193219
# step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower)
194220
preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49)
195221
for disj_group in self.disjoint_groups:
196222
disj_group = [
197-
self.label_names.index(g) for g in disj_group if g in self.label_names
223+
self.label_names.index(str(g))
224+
for g in disj_group
225+
if g in self.label_names
198226
]
199227
if len(disj_group) > 1:
200228
old_preds = preds[:, disj_group]
@@ -211,14 +239,12 @@ def __call__(self, preds):
211239
print(
212240
f"disjointness group {[self.label_names[d] for d in disj_group]} changed {samples_changed} samples"
213241
)
214-
print(
215-
f"Preds change after disjointness (step 2): {torch.sum(preds) - preds_sum_orig}"
216-
)
217242
preds_sum_orig = torch.sum(preds)
218243
# step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors
219244
for i, label in enumerate(self.label_names):
220245
predecessors = [i] + [
221-
self.label_names.index(p) for p in self.chebi_graph.predecessors(label)
246+
self.label_names.index(str(p))
247+
for p in self.chebi_graph.predecessors(int(label))
222248
]
223249
lowest_predecessors = torch.min(preds[:, predecessors], dim=1)
224250
preds[:, i] = lowest_predecessors.values

0 commit comments

Comments
 (0)