Commit 4e3c356

move smoothing logic to chebifier
1 parent 47126b8 commit 4e3c356

File tree: 3 files changed (+261 lines, −33 lines)

chebifier/ensemble/base_ensemble.py

Lines changed: 10 additions & 33 deletions
@@ -3,8 +3,8 @@
 
 import torch
 import tqdm
-from chebai.preprocessing.datasets.chebi import ChEBIOver50
-from chebai.result.analyse_sem import PredictionSmoother, get_chebi_graph
+from chebifier.inconsistency_resolution import PredictionSmoother
+from chebifier.utils import load_chebi_graph, get_disjoint_files
 
 from chebifier.check_env import check_package_installed
 from chebifier.prediction_models.base_predictor import BasePredictor
@@ -21,32 +21,8 @@ def __init__(
         # Deferred Import: To avoid circular import error
         from chebifier.model_registry import MODEL_TYPES
 
-        self.chebi_dataset = ChEBIOver50(chebi_version=chebi_version)
-        self.chebi_dataset._download_required_data()  # download chebi if not already downloaded
-        self.chebi_graph = get_chebi_graph(self.chebi_dataset, None)
-        local_disjoint_files = [
-            os.path.join("data", "disjoint_chebi.csv"),
-            os.path.join("data", "disjoint_additional.csv"),
-        ]
-        self.disjoint_files = []
-        for file in local_disjoint_files:
-            if os.path.isfile(file):
-                self.disjoint_files.append(file)
-            else:
-                print(
-                    f"Disjoint axiom file {file} not found. Loading from huggingface instead..."
-                )
-                from chebifier.hugging_face import download_model_files
-
-                self.disjoint_files.append(
-                    download_model_files(
-                        {
-                            "repo_id": "chebai/chebifier",
-                            "repo_type": "dataset",
-                            "files": {"disjoint_file": os.path.basename(file)},
-                        }
-                    )["disjoint_file"]
-                )
+        self.chebi_graph = load_chebi_graph()
+        self.disjoint_files = get_disjoint_files()
 
         self.models = []
         self.positive_prediction_threshold = 0.5
@@ -72,7 +48,7 @@ def __init__(
 
         if resolve_inconsistencies:
             self.smoother = PredictionSmoother(
-                self.chebi_dataset,
+                self.chebi_graph,
                 label_names=None,
                 disjoint_files=self.disjoint_files,
            )
@@ -203,10 +179,11 @@ def predict_smiles_list(
                 "Warning: No classes have been predicted for the given SMILES list."
             )
             # save predictions
-            torch.save(ordered_predictions, preds_file)
-            with open(predicted_classes_file, "w") as f:
-                for cls in predicted_classes:
-                    f.write(f"{cls}\n")
+            if load_preds_if_possible:
+                torch.save(ordered_predictions, preds_file)
+                with open(predicted_classes_file, "w") as f:
+                    for cls in predicted_classes:
+                        f.write(f"{cls}\n")
             predicted_classes = {cls: i for i, cls in enumerate(predicted_classes)}
         else:
             print(
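
After this change, the ensemble setup boils down to two helper calls plus the smoother. A minimal sketch of the new wiring, using only names visible in this diff (label_names is passed as None here and presumably filled in later via PredictionSmoother.set_label_names):

    from chebifier.inconsistency_resolution import PredictionSmoother
    from chebifier.utils import load_chebi_graph, get_disjoint_files

    # ChEBI hierarchy as a networkx graph, from a local pickle or Hugging Face
    chebi_graph = load_chebi_graph()
    # disjointness axiom CSVs, from data/ or downloaded from Hugging Face
    disjoint_files = get_disjoint_files()
    smoother = PredictionSmoother(
        chebi_graph, label_names=None, disjoint_files=disjoint_files
    )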

chebifier/inconsistency_resolution.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
+import csv
+import os
+import torch
+from pathlib import Path
+
+def get_disjoint_groups(disjoint_files):
+    if disjoint_files is None:
+        disjoint_files = [os.path.join("data", "chebi-disjoints.owl")]  # as a list, so the loop below iterates over paths rather than characters
+    disjoint_pairs, disjoint_groups = [], []
+    for file in disjoint_files:
+        if isinstance(file, Path):
+            file = str(file)
+        if file.endswith(".csv"):
+            with open(file, "r") as f:
+                reader = csv.reader(f)
+                disjoint_pairs += [line for line in reader]
+        elif file.endswith(".owl"):
+            with open(file, "r") as f:
+                plaintext = f.read()
+                segments = plaintext.split("<")
+                disjoint_pairs = []
+                left = None
+                for seg in segments:
+                    if seg.startswith("rdf:Description ") or seg.startswith(
+                        "owl:Class"
+                    ):
+                        left = int(seg.split('rdf:about="&obo;CHEBI_')[1].split('"')[0])
+                    elif seg.startswith("owl:disjointWith"):
+                        right = int(
+                            seg.split('rdf:resource="&obo;CHEBI_')[1].split('"')[0]
+                        )
+                        disjoint_pairs.append([left, right])
+
+            disjoint_groups = []
+            for seg in plaintext.split("<rdf:Description>"):
+                if "owl;AllDisjointClasses" in seg:
+                    classes = seg.split('rdf:about="&obo;CHEBI_')[1:]
+                    classes = [int(c.split('"')[0]) for c in classes]
+                    disjoint_groups.append(classes)
+        else:
+            raise NotImplementedError(
+                "Unsupported disjoint file format: " + file.split(".")[-1]
+            )
+
+    disjoint_all = disjoint_pairs + disjoint_groups
+    # one disjointness is commented out in the owl-file
+    # (the correct way would be to parse the owl file and notice the comment symbols, but for this case, it should work)
+    if [22729, 51880] in disjoint_all:
+        disjoint_all.remove([22729, 51880])
+    # print(f"Found {len(disjoint_all)} disjoint groups")
+    return disjoint_all
+
+
+class PredictionSmoother:
+    """Removes implication and disjointness violations from predictions"""
+
+    def __init__(self, chebi_graph, label_names=None, disjoint_files=None):
+        self.chebi_graph = chebi_graph
+        self.set_label_names(label_names)
+        self.disjoint_groups = get_disjoint_groups(disjoint_files)
+
+    def set_label_names(self, label_names):
+        if label_names is not None:
+            self.label_names = label_names
+            chebi_subgraph = self.chebi_graph.subgraph(self.label_names)
+            self.label_successors = torch.zeros(
+                (len(self.label_names), len(self.label_names)), dtype=torch.bool
+            )
+            for i, label in enumerate(self.label_names):
+                self.label_successors[i, i] = 1
+                for p in chebi_subgraph.successors(label):
+                    if p in self.label_names:
+                        self.label_successors[i, self.label_names.index(p)] = 1
+            self.label_successors = self.label_successors.unsqueeze(0)
+
+    def __call__(self, preds):
+        if preds.shape[1] == 0:
+            # no labels predicted
+            return preds
+        # preds shape: (n_samples, n_labels)
+        preds_sum_orig = torch.sum(preds)
+        # step 1: apply implications: for each class, set prediction to max of itself and all successors
+        preds = preds.unsqueeze(1)
+        preds_masked_succ = torch.where(self.label_successors, preds, 0)
+        # preds_masked_succ shape: (n_samples, n_labels, n_labels)
+
+        preds = preds_masked_succ.max(dim=2).values
+        if torch.sum(preds) != preds_sum_orig:
+            print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")
+        preds_sum_orig = torch.sum(preds)
+        # step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower)
+        preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49)
+        for disj_group in self.disjoint_groups:
+            disj_group = [
+                self.label_names.index(g) for g in disj_group if g in self.label_names
+            ]
+            if len(disj_group) > 1:
+                old_preds = preds[:, disj_group]
+                disj_max = torch.max(preds[:, disj_group], dim=1)
+                for i, row in enumerate(preds):
+                    for l_ in range(len(preds[i])):
+                        if l_ in disj_group and l_ != disj_group[disj_max.indices[i]]:
+                            preds[i, l_] = preds_bounded[i, l_]
+                samples_changed = 0
+                for i, row in enumerate(preds[:, disj_group]):
+                    if any(r != o for r, o in zip(row, old_preds[i])):
+                        samples_changed += 1
+                if samples_changed != 0:
+                    print(
+                        f"disjointness group {[self.label_names[d] for d in disj_group]} changed {samples_changed} samples"
+                    )
+        if torch.sum(preds) != preds_sum_orig:
+            print(f"Preds change (step 2): {torch.sum(preds) - preds_sum_orig}")
+        preds_sum_orig = torch.sum(preds)
+        # step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors
+        preds = preds.unsqueeze(1)
+        preds_masked_predec = torch.where(
+            torch.transpose(self.label_successors, 1, 2), preds, 1
+        )
+        preds = preds_masked_predec.min(dim=2).values
+        if torch.sum(preds) != preds_sum_orig:
+            print(f"Preds change (step 3): {torch.sum(preds) - preds_sum_orig}")
+        return preds
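
To see the three smoothing steps end to end, here is a small self-contained sketch; the two-level hierarchy, the scores, and the disjointness pair are invented for illustration, and class IDs are strings as in the CSV-based workflow:

    import os
    import tempfile

    import networkx as nx
    import torch

    from chebifier.inconsistency_resolution import PredictionSmoother

    # Toy hierarchy with edges parent -> child, transitively closed like the
    # graphs produced by chebifier.utils: class "1" subsumes "2" and "3".
    graph = nx.DiGraph()
    graph.add_edges_from([("1", "2"), ("1", "3")])

    # Hypothetical disjointness axiom: "2" and "3" exclude each other.
    with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as f:
        f.write("2,3\n")
        disjoint_csv = f.name

    smoother = PredictionSmoother(
        graph, label_names=["1", "2", "3"], disjoint_files=[disjoint_csv]
    )
    preds = torch.tensor([[0.2, 0.9, 0.8]])  # child "2" outscores its parent "1"
    print(smoother(preds))  # step 1 lifts "1" to 0.9, step 2 caps "3" at 0.49
    os.remove(disjoint_csv)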

chebifier/utils.py

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
+import os
+
+import networkx as nx
+import requests
+import fastobo
+from chebifier.hugging_face import download_model_files
+import pickle
+
+def load_chebi_graph(filename=None):
+    """Load ChEBI graph from Hugging Face (if filename is None) or local file"""
+    if filename is None:
+        print("Loading ChEBI graph from Hugging Face...")
+        file = download_model_files(
+            {
+                "repo_id": "chebai/chebifier",
+                "repo_type": "dataset",
+                "files": {"f": "chebi_graph.pkl"},
+            }
+        )["f"]
+    else:
+        print(f"Loading ChEBI graph from local {filename}...")
+        file = filename
+    return pickle.load(open(file, "rb"))
+
+def term_callback(doc):
+    """Similar to the chebai function, but reduced to the necessary fields. Also, ChEBI IDs are strings"""
+    parents = []
+    name = None
+    smiles = None
+    for clause in doc:
+        if isinstance(clause, fastobo.term.PropertyValueClause):
+            t = clause.property_value
+            if str(t.relation) == "http://purl.obolibrary.org/obo/chebi/smiles":
+                assert smiles is None
+                smiles = t.value
+        # in older chebi versions, smiles strings are synonyms
+        # e.g. synonym: "[F-].[Na+]" RELATED SMILES [ChEBI]
+        elif isinstance(clause, fastobo.term.SynonymClause):
+            if "SMILES" in clause.raw_value():
+                assert smiles is None
+                smiles = clause.raw_value().split('"')[1]
+        elif isinstance(clause, fastobo.term.IsAClause):
+            chebi_id = str(clause.term)
+            chebi_id = chebi_id[chebi_id.index(":") + 1:]
+            parents.append(chebi_id)
+        elif isinstance(clause, fastobo.term.NameClause):
+            name = str(clause.name)
+
+        if isinstance(clause, fastobo.term.IsObsoleteClause):
+            if clause.obsolete:
+                # the term is marked as obsolete, so skip this document
+                return False
+    chebi_id = str(doc.id)
+    chebi_id = chebi_id[chebi_id.index(":") + 1:]
+    return {
+        "id": chebi_id,
+        "parents": parents,
+        "name": name,
+        "smiles": smiles,
+    }
+
+def build_chebi_graph(chebi_version=241):
+    """Creates a networkx graph for the ChEBI hierarchy. Usually, you don't want to call this function directly, but rather use the `load_chebi_graph` function."""
+    chebi_path = os.path.join("data", f"chebi_v{chebi_version}", "chebi.obo")
+    os.makedirs(os.path.join("data", f"chebi_v{chebi_version}"), exist_ok=True)
+    if not os.path.exists(chebi_path):
+        url = f"http://purl.obolibrary.org/obo/chebi/{chebi_version}/chebi.obo"
+        r = requests.get(url, allow_redirects=True)
+        open(chebi_path, "wb").write(r.content)
+    with open(chebi_path, encoding="utf-8") as chebi:
+        chebi = "\n".join(line for line in chebi if not line.startswith("xref:"))
+
+    elements = []
+    for term_doc in fastobo.loads(chebi):
+        if (
+            term_doc
+            and isinstance(term_doc.id, fastobo.id.PrefixedIdent)
+            and term_doc.id.prefix == "CHEBI"
+        ):
+            term_dict = term_callback(term_doc)
+            if term_dict:
+                elements.append(term_dict)
+
+    g = nx.DiGraph()
+    for n in elements:
+        g.add_node(n["id"], **n)
+
+    # Only take the edges which connect the existing nodes, to avoid internal creation of obsolete nodes
+    # https://github.com/ChEB-AI/python-chebai/pull/55#issuecomment-2386654142
+    g.add_edges_from(
+        [(p, q["id"]) for q in elements for p in q["parents"] if g.has_node(p)]
+    )
+    return nx.transitive_closure_dag(g)
+
+
+def get_disjoint_files():
+    """Gets local disjointness files if they are present in the right location, otherwise downloads them from Hugging Face."""
+    local_disjoint_files = [
+        os.path.join("data", "disjoint_chebi.csv"),
+        os.path.join("data", "disjoint_additional.csv"),
+    ]
+    disjoint_files = []
+    for file in local_disjoint_files:
+        if os.path.isfile(file):
+            disjoint_files.append(file)
+        else:
+            print(
+                f"Disjoint axiom file {file} not found. Loading from huggingface instead..."
+            )
+
+            disjoint_files.append(
+                download_model_files(
+                    {
+                        "repo_id": "chebai/chebifier",
+                        "repo_type": "dataset",
+                        "files": {"disjoint_file": os.path.basename(file)},
+                    }
+                )["disjoint_file"]
+            )
+    return disjoint_files
+
+
+if __name__ == "__main__":
+    # chebi_graph = build_chebi_graph(chebi_version=241)
+    # save the graph to a file
+    # pickle.dump(chebi_graph, open("chebi_graph.pkl", "wb"))
+    chebi_graph = load_chebi_graph()
+    print(chebi_graph)
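
The commented-out lines in the __main__ block hint at how the pickle consumed by load_chebi_graph is produced in the first place. A sketch of that one-off build-and-cache step (the file name chebi_graph.pkl is taken from the comments; passing the local path back to load_chebi_graph afterwards is an assumption):

    import pickle

    from chebifier.utils import build_chebi_graph, load_chebi_graph

    # One-off: parse chebi.obo for release 241 (downloaded into data/ if absent)
    # and cache the transitively closed hierarchy locally.
    graph = build_chebi_graph(chebi_version=241)
    with open("chebi_graph.pkl", "wb") as f:
        pickle.dump(graph, f)

    # Later runs can skip the rebuild and load the pickle directly.
    graph = load_chebi_graph("chebi_graph.pkl")
    print(graph.number_of_nodes(), "ChEBI classes")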
