Skip to content

Commit e8e4ec3

Browse files
committed
use class scores for smoothing, explicitly predict transitive closure for all models
1 parent 89b4812 commit e8e4ec3

File tree

4 files changed

+67
-39
lines changed

4 files changed

+67
-39
lines changed

chebifier/ensemble/base_ensemble.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55
import tqdm
66
from chebai.preprocessing.datasets.chebi import ChEBIOver50
7-
from chebai.result.analyse_sem import PredictionSmoother
7+
from chebai.result.analyse_sem import PredictionSmoother, get_chebi_graph
88

99
from chebifier.prediction_models.base_predictor import BasePredictor
1010

@@ -15,6 +15,14 @@ def __init__(self, model_configs: dict, chebi_version: int = 241):
1515
# Deferred Import: To avoid circular import error
1616
from chebifier.model_registry import MODEL_TYPES
1717

18+
self.chebi_dataset = ChEBIOver50(chebi_version=chebi_version)
19+
self.chebi_dataset._download_required_data() # download chebi if not already downloaded
20+
self.chebi_graph = get_chebi_graph(self.chebi_dataset, None)
21+
self.disjoint_files = [
22+
os.path.join("data", "disjoint_chebi.csv"),
23+
os.path.join("data", "disjoint_additional.csv"),
24+
]
25+
1826
self.models = []
1927
self.positive_prediction_threshold = 0.5
2028
for model_name, model_config in model_configs.items():
@@ -25,17 +33,12 @@ def __init__(self, model_configs: dict, chebi_version: int = 241):
2533
else:
2634
hugging_face_kwargs = {}
2735
model_instance = model_cls(
28-
model_name, **model_config, **hugging_face_kwargs
36+
model_name, **model_config, **hugging_face_kwargs, chebi_graph=self.chebi_graph
2937
)
3038
assert isinstance(model_instance, BasePredictor)
3139
self.models.append(model_instance)
3240

33-
self.chebi_dataset = ChEBIOver50(chebi_version=chebi_version)
34-
self.chebi_dataset._download_required_data() # download chebi if not already downloaded
35-
self.disjoint_files = [
36-
os.path.join("data", "disjoint_chebi.csv"),
37-
os.path.join("data", "disjoint_additional.csv"),
38-
]
41+
3942

4043
self.smoother = PredictionSmoother(
4144
self.chebi_dataset,
@@ -54,7 +57,7 @@ def gather_predictions(self, smiles_list):
5457
if logits_for_smiles is not None:
5558
for cls in logits_for_smiles:
5659
predicted_classes.add(cls)
57-
print("Sorting predictions...")
60+
print(f"Sorting predictions from {len(model_predictions)} models...")
5861
predicted_classes = sorted(list(predicted_classes))
5962
predicted_classes_dict = {cls: i for i, cls in enumerate(predicted_classes)}
6063
ordered_logits = (
@@ -75,7 +78,7 @@ def gather_predictions(self, smiles_list):
7578

7679
return ordered_logits, predicted_classes
7780

78-
def consolidate_predictions(self, predictions, classwise_weights, **kwargs):
81+
def consolidate_predictions(self, predictions, classwise_weights, predicted_classes, **kwargs):
7982
"""
8083
Aggregates predictions from multiple models using weighted majority voting.
8184
Optimized version using tensor operations instead of for loops.
@@ -124,8 +127,17 @@ def consolidate_predictions(self, predictions, classwise_weights, **kwargs):
124127

125128
# Determine which classes to include for each SMILES
126129
net_score = positive_sum - negative_sum # Shape: (num_smiles, num_classes)
130+
131+
# Smooth predictions
132+
start_time = time.perf_counter()
133+
class_names = list(predicted_classes.keys())
134+
self.smoother.set_label_names(class_names)
135+
smooth_net_score = self.smoother(net_score)
136+
end_time = time.perf_counter()
137+
print(f"Prediction smoothing took {end_time - start_time:.2f} seconds")
138+
127139
class_decisions = (
128-
net_score > 0
140+
smooth_net_score > 0.5
129141
) & has_valid_predictions # Shape: (num_smiles, num_classes)
130142

131143
complete_failure = torch.all(~has_valid_predictions, dim=1)
@@ -139,14 +151,16 @@ def calculate_classwise_weights(self, predicted_classes):
139151
return positive_weights, negative_weights
140152

141153
def predict_smiles_list(
142-
self, smiles_list, load_preds_if_possible=True, **kwargs
154+
self, smiles_list, load_preds_if_possible=False, **kwargs
143155
) -> list:
144156
preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt"
145157
predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt"
146158
if not load_preds_if_possible or not os.path.isfile(preds_file):
147159
ordered_predictions, predicted_classes = self.gather_predictions(
148160
smiles_list
149161
)
162+
if len(predicted_classes) == 0:
163+
print(f"Warning: No classes have been predicted for the given SMILES list.")
150164
# save predictions
151165
torch.save(ordered_predictions, preds_file)
152166
with open(predicted_classes_file, "w") as f:
@@ -165,15 +179,8 @@ def predict_smiles_list(
165179

166180
classwise_weights = self.calculate_classwise_weights(predicted_classes)
167181
class_decisions, is_failure = self.consolidate_predictions(
168-
ordered_predictions, classwise_weights, **kwargs
182+
ordered_predictions, classwise_weights, predicted_classes, **kwargs
169183
)
170-
# Smooth predictions
171-
start_time = time.perf_counter()
172-
class_names = list(predicted_classes.keys())
173-
self.smoother.set_label_names(class_names)
174-
class_decisions = self.smoother(class_decisions)
175-
end_time = time.perf_counter()
176-
print(f"Prediction smoothing took {end_time - start_time:.2f} seconds")
177184

178185
class_names = list(predicted_classes.keys())
179186
class_indices = {predicted_classes[cls]: cls for cls in class_names}

chebifier/prediction_models/c3p_predictor.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@ def __init__(self, model_name: str, program_directory: Optional[Path]=None, chem
1515
super().__init__(model_name, **kwargs)
1616
self.program_directory = program_directory
1717
self.chemical_classes = chemical_classes
18+
self.chebi_graph = kwargs.get("chebi_graph", None)
1819

1920
def predict_smiles_list(self, smiles_list: list[str]) -> list:
20-
result_list = c3p_classifier.classify(smiles_list, self.program_directory, self.chemical_classes, strict=False)
21+
result_list = c3p_classifier.classify(smiles_list, self.program_directory, self.chemical_classes, strict=True)
2122
result_reformatted = [dict() for _ in range(len(smiles_list))]
2223
for result in result_list:
23-
result_reformatted[smiles_list.index(result.input_smiles)][result.class_id.split(":")[1]] = result.is_match
24-
print(f"C3P predictions for {len(smiles_list)} SMILES strings:")
25-
for i, smiles in enumerate(smiles_list):
26-
print(f"{smiles}: {result_reformatted[i]}")
24+
chebi_id = result.class_id.split(":")[1]
25+
result_reformatted[smiles_list.index(result.input_smiles)][chebi_id] = result.is_match
26+
if result.is_match and self.chebi_graph is not None:
27+
for parent in list(self.chebi_graph.predecessors(int(chebi_id))):
28+
result_reformatted[smiles_list.index(result.input_smiles)][str(parent)] = 1
2729
return result_reformatted

chebifier/prediction_models/chebi_lookup.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@ def __init__(self, model_name: str, description: str = None, chebi_version: int
1010
super().__init__(model_name, **kwargs)
1111
self._description = description or "ChEBI Lookup: If the SMILES is equivalent to a ChEBI entry, retrieve the classification of that entry."
1212
self.chebi_version = chebi_version
13+
self.chebi_graph = kwargs.get("chebi_graph", None)
14+
if self.chebi_graph is None:
15+
from chebai.preprocessing.datasets.chebi import ChEBIOver50
16+
self.chebi_dataset = ChEBIOver50(chebi_version=self.chebi_version)
17+
self.chebi_dataset._download_required_data()
18+
self.chebi_graph = self.chebi_dataset._extract_class_hierarchy(
19+
os.path.join(self.chebi_dataset.raw_dir, "chebi.obo")
20+
)
1321
self.lookup_table = self.get_smiles_lookup()
1422

1523
def get_smiles_lookup(self):
@@ -26,15 +34,8 @@ def get_smiles_lookup(self):
2634

2735

2836
def build_smiles_lookup(self):
29-
# todo test
30-
from chebai.preprocessing.datasets.chebi import ChEBIOver50
31-
self.chebi_dataset = ChEBIOver50(chebi_version=self.chebi_version)
32-
self.chebi_dataset._download_required_data()
33-
chebi_graph = self.chebi_dataset._extract_class_hierarchy(
34-
os.path.join(self.chebi_dataset.raw_dir, "chebi.obo")
35-
)
3637
smiles_lookup = dict()
37-
for chebi_id, smiles in nx.get_node_attributes(chebi_graph, "smiles").items():
38+
for chebi_id, smiles in nx.get_node_attributes(self.chebi_graph, "smiles").items():
3839
if smiles is not None:
3940
try:
4041
mol = Chem.MolFromSmiles(smiles)
@@ -45,7 +46,7 @@ def build_smiles_lookup(self):
4546
if canonical_smiles not in smiles_lookup:
4647
smiles_lookup[canonical_smiles] = []
4748
# if the canonical SMILES is already in the lookup, append "different interpretation of the SMILES"
48-
smiles_lookup[canonical_smiles].append((chebi_id, list(chebi_graph.predecessors(chebi_id))))
49+
smiles_lookup[canonical_smiles].append((chebi_id, list(self.chebi_graph.predecessors(chebi_id))))
4950
except Exception as e:
5051
print(f"Failed to parse SMILES {smiles} for ChEBI ID {chebi_id}: {e}")
5152
return smiles_lookup

chebifier/prediction_models/chemlog_predictor.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,24 @@ class ChemlogExtraPredictor(BasePredictor):
4545

4646
def __init__(self, model_name: str, **kwargs):
4747
super().__init__(model_name, **kwargs)
48+
self.chebi_graph = kwargs.get("chebi_graph", None)
4849
self.classifier = self.CHEMLOG_CLASSIFIER()
4950

5051
def predict_smiles_list(self, smiles_list: list[str]) -> list:
5152
mol_list = [_smiles_to_mol(smiles) for smiles in smiles_list]
52-
return self.classifier.classify(mol_list)
53+
res = self.classifier.classify(mol_list)
54+
if self.chebi_graph is not None:
55+
for sample in res:
56+
sample_additions = dict()
57+
for cls in sample:
58+
if sample[cls] == 1:
59+
successors = list(self.chebi_graph.predecessors(int(cls)))
60+
if successors:
61+
for succ in successors:
62+
sample_additions[str(succ)] = 1
63+
sample.update(sample_additions)
64+
return res
65+
5366

5467
class ChemlogXMolecularEntityPredictor(ChemlogExtraPredictor):
5568

@@ -63,6 +76,7 @@ class ChemlogPeptidesPredictor(BasePredictor):
6376
def __init__(self, model_name: str, **kwargs):
6477
super().__init__(model_name, **kwargs)
6578
self.strategy = "algo"
79+
self.chebi_graph = kwargs.get("chebi_graph", None)
6680
self.classifier_instances = {
6781
k: v() for k, v in CLASSIFIERS[self.strategy].items()
6882
}
@@ -81,17 +95,21 @@ def predict_smiles_list(self, smiles_list: list[str]) -> list:
8195
if mol is None:
8296
results.append(None)
8397
else:
98+
pos_labels = [label for label in self.peptide_labels if label in strategy_call(
99+
self.strategy, self.classifier_instances, mol
100+
)["chebi_classes"]]
101+
if self.chebi_graph:
102+
indirect_pos_labels = [str(pr) for label in pos_labels for pr in self.chebi_graph.predecessors(int(label))]
103+
pos_labels = list(set(pos_labels + indirect_pos_labels))
84104
results.append(
85105
{
86106
label: (
87107
1
88108
if label
89-
in strategy_call(
90-
self.strategy, self.classifier_instances, mol
91-
)["chebi_classes"]
109+
in pos_labels
92110
else 0
93111
)
94-
for label in self.peptide_labels
112+
for label in self.peptide_labels + pos_labels
95113
}
96114
)
97115

0 commit comments

Comments
 (0)