Skip to content

Commit e4f1c54

Browse files
committed
add cache
1 parent 4d918d5 commit e4f1c54

File tree

6 files changed

+79
-51
lines changed

6 files changed

+79
-51
lines changed

chebifier/ensemble/base_ensemble.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from chebai.result.analyse_sem import PredictionSmoother, get_chebi_graph
88

99
from chebifier.prediction_models.base_predictor import BasePredictor
10-
10+
from functools import lru_cache
1111

1212
class BaseEnsemble:
1313

chebifier/prediction_models/base_predictor.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import json
22
from abc import ABC
33

4+
from functools import lru_cache
5+
46

57
class BasePredictor(ABC):
68
def __init__(
@@ -22,7 +24,16 @@ def __init__(
2224
self._description = kwargs.get("description", None)
2325

2426
def predict_smiles_list(self, smiles_list: list[str]) -> dict:
25-
raise NotImplementedError
27+
# list is not hashable, so we convert it to a tuple (useful for caching)
28+
return self.predict_smiles_tuple(tuple(smiles_list))
29+
30+
@lru_cache(maxsize=100)
31+
def predict_smiles_tuple(self, smiles_tuple: tuple[str]) -> dict:
32+
raise NotImplementedError()
33+
34+
def predict_smiles(self, smiles: str) -> dict:
35+
# by default, use list-based prediction
36+
return self.predict_smiles_tuple((smiles,))[0]
2637

2738
@property
2839
def info_text(self):

chebifier/prediction_models/c3p_predictor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from functools import lru_cache
12
from typing import Optional, List
23
from pathlib import Path
34

@@ -17,8 +18,9 @@ def __init__(self, model_name: str, program_directory: Optional[Path]=None, chem
1718
self.chemical_classes = chemical_classes
1819
self.chebi_graph = kwargs.get("chebi_graph", None)
1920

20-
def predict_smiles_list(self, smiles_list: list[str]) -> list:
21-
result_list = c3p_classifier.classify(smiles_list, self.program_directory, self.chemical_classes, strict=False)
21+
@lru_cache(maxsize=100)
22+
def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
23+
result_list = c3p_classifier.classify(list(smiles_list), self.program_directory, self.chemical_classes, strict=False)
2224
result_reformatted = [dict() for _ in range(len(smiles_list))]
2325
for result in result_list:
2426
chebi_id = result.class_id.split(":")[1]

chebifier/prediction_models/chebi_lookup.py

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from functools import lru_cache
2+
from typing import Optional
3+
14
from chebifier.prediction_models import BasePredictor
25
import os
36
import networkx as nx
@@ -51,32 +54,36 @@ def build_smiles_lookup(self):
5154
print(f"Failed to parse SMILES {smiles} for ChEBI ID {chebi_id}: {e}")
5255
return smiles_lookup
5356

57+
@lru_cache(maxsize=100)
58+
def predict_smiles(self, smiles: str) -> Optional[dict]:
59+
if not smiles:
60+
return None
61+
mol = Chem.MolFromSmiles(smiles)
62+
if mol is None:
63+
return None
64+
canonical_smiles = Chem.MolToSmiles(mol)
65+
if canonical_smiles in self.lookup_table:
66+
parent_candidates = self.lookup_table[canonical_smiles]
67+
preds_i = dict()
68+
if len(parent_candidates) > 1:
69+
print(
70+
f"Multiple matches found in ChEBI for SMILES {smiles}: {', '.join(str(chebi_id) for chebi_id, _ in parent_candidates)}")
71+
for k in list(set(pp for _, p in parent_candidates for pp in p)):
72+
preds_i[str(k)] = 1
73+
elif len(parent_candidates) == 1:
74+
chebi_id, parents = parent_candidates[0]
75+
for k in parents:
76+
preds_i[str(k)] = 1
77+
else:
78+
preds_i = None
79+
return preds_i
80+
else:
81+
return None
5482

55-
def predict_smiles_list(self, smiles_list: list[str]) -> list:
83+
def predict_smiles_tuple(self, smiles_list: list[str]) -> list:
5684
predictions = []
5785
for smiles in smiles_list:
58-
if not smiles:
59-
predictions.append(None)
60-
continue
61-
mol = Chem.MolFromSmiles(smiles)
62-
if mol is None:
63-
predictions.append(None)
64-
continue
65-
canonical_smiles = Chem.MolToSmiles(mol)
66-
if canonical_smiles in self.lookup_table:
67-
parent_candidates = self.lookup_table[canonical_smiles]
68-
preds_i = dict()
69-
if len(parent_candidates) > 1:
70-
print(f"Multiple matches found in ChEBI for SMILES {smiles}: {', '.join(str(chebi_id) for chebi_id, _ in parent_candidates)}")
71-
for k in list(set(pp for _, p in parent_candidates for pp in p)):
72-
preds_i[str(k)] = 1
73-
elif len(parent_candidates) == 1:
74-
chebi_id, parents = parent_candidates[0]
75-
for k in parents:
76-
preds_i[str(k)] = 1
77-
else:
78-
preds_i = None
79-
predictions.append(preds_i)
86+
predictions.append(self.predict_smiles(smiles))
8087

8188
return predictions
8289

chebifier/prediction_models/chemlog_predictor.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
import tqdm
24
from chemlog.alg_classification.charge_classifier import get_charge_category
35
from chemlog.alg_classification.peptide_size_classifier import get_n_amino_acid_residues
@@ -10,6 +12,7 @@
1012
)
1113
from chemlog.cli import CLASSIFIERS, _smiles_to_mol, strategy_call
1214
from chemlog_extra.alg_classification.by_element_classification import XMolecularEntityClassifier, OrganoXCompoundClassifier
15+
from functools import lru_cache
1316

1417
from .base_predictor import BasePredictor
1518

@@ -48,7 +51,7 @@ def __init__(self, model_name: str, **kwargs):
4851
self.chebi_graph = kwargs.get("chebi_graph", None)
4952
self.classifier = self.CHEMLOG_CLASSIFIER()
5053

51-
def predict_smiles_list(self, smiles_list: list[str]) -> list:
54+
def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
5255
mol_list = [_smiles_to_mol(smiles) for smiles in smiles_list]
5356
res = self.classifier.classify(mol_list)
5457
if self.chebi_graph is not None:
@@ -88,30 +91,32 @@ def __init__(self, model_name: str, **kwargs):
8891
# fmt: on
8992
print(f"Initialised ChemLog model {self.model_name}")
9093

91-
def predict_smiles_list(self, smiles_list: list[str]) -> list:
94+
@lru_cache(maxsize=100)
95+
def predict_smiles(self, smiles: str) -> Optional[dict]:
96+
mol = _smiles_to_mol(smiles)
97+
if mol is None:
98+
return None
99+
pos_labels = [label for label in self.peptide_labels if label in strategy_call(
100+
self.strategy, self.classifier_instances, mol
101+
)["chebi_classes"]]
102+
if self.chebi_graph:
103+
indirect_pos_labels = [str(pr) for label in pos_labels for pr in
104+
self.chebi_graph.predecessors(int(label))]
105+
pos_labels = list(set(pos_labels + indirect_pos_labels))
106+
return {
107+
label: (
108+
1
109+
if label
110+
in pos_labels
111+
else 0
112+
)
113+
for label in self.peptide_labels + pos_labels
114+
}
115+
116+
def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
92117
results = []
93118
for i, smiles in tqdm.tqdm(enumerate(smiles_list)):
94-
mol = _smiles_to_mol(smiles)
95-
if mol is None:
96-
results.append(None)
97-
else:
98-
pos_labels = [label for label in self.peptide_labels if label in strategy_call(
99-
self.strategy, self.classifier_instances, mol
100-
)["chebi_classes"]]
101-
if self.chebi_graph:
102-
indirect_pos_labels = [str(pr) for label in pos_labels for pr in self.chebi_graph.predecessors(int(label))]
103-
pos_labels = list(set(pos_labels + indirect_pos_labels))
104-
results.append(
105-
{
106-
label: (
107-
1
108-
if label
109-
in pos_labels
110-
else 0
111-
)
112-
for label in self.peptide_labels + pos_labels
113-
}
114-
)
119+
results.append(self.predict_smiles(smiles))
115120

116121
for classifier in self.classifier_instances.values():
117122
classifier.on_finish()

chebifier/prediction_models/nn_predictor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from functools import lru_cache
2+
13
import numpy as np
24
import torch
35
import tqdm
@@ -50,7 +52,8 @@ def read_smiles(self, smiles):
5052
d = reader.to_data(dict(features=smiles, labels=None))
5153
return d
5254

53-
def predict_smiles_list(self, smiles_list) -> list:
55+
@lru_cache(maxsize=100)
56+
def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
5457
"""Returns a list with the length of smiles_list, each element is either None (=failure) or a dictionary
5558
Of classes and predicted values."""
5659
token_dicts = []

0 commit comments

Comments
 (0)