Skip to content

Commit f268d87

Browse files
committed
add trust calculation for c3p
1 parent c64be19 commit f268d87

File tree

1 file changed

+89
-0
lines changed

1 file changed

+89
-0
lines changed

chebifier/prediction_models/c3p_predictor.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from pathlib import Path
23
from typing import List, Optional
34

@@ -82,3 +83,91 @@ def explain_smiles(self, smiles):
8283
] + highlights
8384

8485
return {"highlights": highlights}
86+
87+
def calculate_trust(self, c3p_classes_path, output_path="c3p_trust.json"):
88+
"""Use reported confidence of C3P to calculate the trust. Use either the directly reported values or infer based on subclasses"""
89+
from c3p.classifier import PROGRAM_DIR
90+
91+
program_dir = self.program_directory or PROGRAM_DIR
92+
confusion_matrix = dict()
93+
for f in os.listdir(program_dir):
94+
if f.startswith("__"):
95+
continue
96+
with open(os.path.join(program_dir, f), encoding="utf-8") as file:
97+
txt = file.read()
98+
99+
if "__metadata__" in txt:
100+
txt = txt[txt.rindex("__metadata__") + 15 :]
101+
chebi_id = txt[
102+
txt.index("id")
103+
+ 12 : txt.index("id")
104+
+ txt[txt.index("id") :].index(",")
105+
- 1
106+
]
107+
conf = []
108+
if (
109+
chebi_id == ""
110+
or chebi_id.startswith("R")
111+
or chebi_id.startswith("oxy")
112+
):
113+
print(f, chebi_id)
114+
for name in [
115+
"num_true_positives",
116+
"num_false_positives",
117+
"num_true_negatives",
118+
"num_false_negatives",
119+
]:
120+
start_index = txt.index(name) + len(name) + 2
121+
end_index = start_index + txt[start_index:].index(",")
122+
try:
123+
number = int(txt[start_index:end_index])
124+
except ValueError:
125+
print(
126+
"Failed to read value near ",
127+
txt[start_index - 17 : end_index + 5],
128+
)
129+
number = 0
130+
conf.append(number)
131+
confusion_matrix[chebi_id] = {
132+
"TP": conf[0],
133+
"FP": conf[1],
134+
"TN": conf[2],
135+
"FN": conf[3],
136+
}
137+
else:
138+
print(f"Couldnt find metadata in {f}")
139+
140+
# for classes where c3p doesn't have a number, take the sum of the subclasses
141+
new_confusion = dict()
142+
for cls in confusion_matrix:
143+
for parent in self.chebi_graph.predecessors(cls):
144+
if parent not in confusion_matrix:
145+
if parent not in new_confusion:
146+
new_confusion[parent] = {"TP": 0, "FP": 0, "TN": 0, "FN": 0}
147+
new_confusion[parent]["TP"] += confusion_matrix[cls]["TP"]
148+
new_confusion[parent]["FP"] += confusion_matrix[cls]["FP"]
149+
new_confusion[parent]["TN"] += confusion_matrix[cls]["TN"]
150+
new_confusion[parent]["FN"] += confusion_matrix[cls]["FN"]
151+
152+
import json
153+
154+
confusion_matrix = {**confusion_matrix, **new_confusion}
155+
print(
156+
f"After adding parent classes, confusion matrix contains {len(confusion_matrix)} classes ({len(new_confusion)} indirect)"
157+
)
158+
json.dump(confusion_matrix, open(output_path, "w+"))
159+
160+
161+
if __name__ == "__main__":
162+
import os
163+
164+
from chebifier.utils import load_chebi_graph
165+
166+
chebi_graph = load_chebi_graph()
167+
predictor = C3PPredictor(
168+
"demo",
169+
program_directory=os.path.join("..", "c3p", "c3p", "programs"),
170+
chebi_graph=chebi_graph,
171+
)
172+
print(predictor.predict_smiles_list(["CO", "CO"]))
173+
# predictor.calculate_trust(os.path.join("..", "ensemble-eval", "ensemble_eval_model_preds", "c3p_classes.txt"), "c3p_trust_new.json")

0 commit comments

Comments
 (0)