|
| 1 | +from typing import Optional |
| 2 | + |
1 | 3 | import tqdm |
2 | 4 | from chemlog.alg_classification.charge_classifier import get_charge_category |
3 | 5 | from chemlog.alg_classification.peptide_size_classifier import get_n_amino_acid_residues |
|
10 | 12 | ) |
11 | 13 | from chemlog.cli import CLASSIFIERS, _smiles_to_mol, strategy_call |
12 | 14 | from chemlog_extra.alg_classification.by_element_classification import XMolecularEntityClassifier, OrganoXCompoundClassifier |
| 15 | +from functools import lru_cache |
13 | 16 |
|
14 | 17 | from .base_predictor import BasePredictor |
15 | 18 |
|
@@ -48,7 +51,7 @@ def __init__(self, model_name: str, **kwargs): |
48 | 51 | self.chebi_graph = kwargs.get("chebi_graph", None) |
49 | 52 | self.classifier = self.CHEMLOG_CLASSIFIER() |
50 | 53 |
|
51 | | - def predict_smiles_list(self, smiles_list: list[str]) -> list: |
| 54 | + def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list: |
52 | 55 | mol_list = [_smiles_to_mol(smiles) for smiles in smiles_list] |
53 | 56 | res = self.classifier.classify(mol_list) |
54 | 57 | if self.chebi_graph is not None: |
@@ -88,30 +91,32 @@ def __init__(self, model_name: str, **kwargs): |
88 | 91 | # fmt: on |
89 | 92 | print(f"Initialised ChemLog model {self.model_name}") |
90 | 93 |
|
91 | | - def predict_smiles_list(self, smiles_list: list[str]) -> list: |
def predict_smiles(self, smiles: str) -> Optional[dict]:
    """Classify a single SMILES string against the peptide labels.

    Results are memoised per predictor instance (at most 100 distinct
    SMILES, least recently used evicted first).  The cache lives on the
    instance rather than in a ``functools.lru_cache`` wrapped around the
    method: an ``lru_cache`` on a method keys on ``self``, is shared by the
    whole class, and keeps every predictor alive for as long as the cache
    exists.

    Args:
        smiles: SMILES string of the molecule to classify.

    Returns:
        ``None`` if ``_smiles_to_mol`` cannot build a molecule from
        ``smiles``.  Otherwise a dict mapping every entry of
        ``self.peptide_labels`` -- plus, when ``self.chebi_graph`` is set,
        the graph predecessors of each positive label -- to ``1``
        (positive) or ``0`` (negative).  A fresh dict is returned on every
        call, so callers may mutate it without affecting cached results.
    """
    max_cached = 100  # same bound the former ``lru_cache(maxsize=100)`` used

    # Created lazily so that ``__init__`` does not need to know about it.
    cache = getattr(self, "_predict_smiles_cache", None)
    if cache is None:
        cache = self._predict_smiles_cache = {}

    if smiles in cache:
        # Re-insert to move the entry to the "most recently used" end;
        # plain dicts preserve insertion order.
        cached = cache.pop(smiles)
        cache[smiles] = cached
        return None if cached is None else dict(cached)

    mol = _smiles_to_mol(smiles)
    if mol is None:
        prediction = None
    else:
        # Run the classifiers exactly once per molecule.  Placing this call
        # inside the comprehension's condition would re-run it for every
        # peptide label.
        predicted_classes = strategy_call(
            self.strategy, self.classifier_instances, mol
        )["chebi_classes"]
        pos_labels = [
            label for label in self.peptide_labels if label in predicted_classes
        ]
        if self.chebi_graph:
            # Every graph predecessor of a positive label is positive too.
            # NOTE(review): presumably predecessors are the parent classes
            # in the ChEBI hierarchy -- confirm against the graph builder.
            indirect_pos_labels = [
                str(pr)
                for label in pos_labels
                for pr in self.chebi_graph.predecessors(int(label))
            ]
            pos_labels = pos_labels + indirect_pos_labels
        positives = set(pos_labels)  # O(1) membership tests below
        prediction = {
            label: 1 if label in positives else 0
            for label in [*self.peptide_labels, *pos_labels]
        }

    # Invalid SMILES (``None``) are cached as well, so repeated bad inputs
    # are not re-parsed.
    cache[smiles] = prediction
    if len(cache) > max_cached:
        # The first key is the least recently used one.
        del cache[next(iter(cache))]

    return None if prediction is None else dict(prediction)
| 115 | + |
| 116 | + def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list: |
92 | 117 | results = [] |
93 | 118 | for i, smiles in tqdm.tqdm(enumerate(smiles_list)): |
94 | | - mol = _smiles_to_mol(smiles) |
95 | | - if mol is None: |
96 | | - results.append(None) |
97 | | - else: |
98 | | - pos_labels = [label for label in self.peptide_labels if label in strategy_call( |
99 | | - self.strategy, self.classifier_instances, mol |
100 | | - )["chebi_classes"]] |
101 | | - if self.chebi_graph: |
102 | | - indirect_pos_labels = [str(pr) for label in pos_labels for pr in self.chebi_graph.predecessors(int(label))] |
103 | | - pos_labels = list(set(pos_labels + indirect_pos_labels)) |
104 | | - results.append( |
105 | | - { |
106 | | - label: ( |
107 | | - 1 |
108 | | - if label |
109 | | - in pos_labels |
110 | | - else 0 |
111 | | - ) |
112 | | - for label in self.peptide_labels + pos_labels |
113 | | - } |
114 | | - ) |
| 119 | + results.append(self.predict_smiles(smiles)) |
115 | 120 |
|
116 | 121 | for classifier in self.classifier_instances.values(): |
117 | 122 | classifier.on_finish() |
|
0 commit comments