Skip to content

Commit 89b4812

Browse files
committed
add c3p integration
1 parent ecb48ff commit 89b4812

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

chebifier/model_registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
ResGatedPredictor,
1010
ChEBILookupPredictor
1111
)
12+
from chebifier.prediction_models.c3p_predictor import C3PPredictor
1213
from chebifier.prediction_models.chemlog_predictor import ChemlogXMolecularEntityPredictor, ChemlogOrganoXCompoundPredictor
1314

1415
ENSEMBLES = {
@@ -25,6 +26,7 @@
2526
"chebi_lookup": ChEBILookupPredictor,
2627
"chemlog_element": ChemlogXMolecularEntityPredictor,
2728
"chemlog_organox": ChemlogOrganoXCompoundPredictor,
29+
"c3p": C3PPredictor
2830
}
2931

3032

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from typing import Optional, List
2+
from pathlib import Path
3+
4+
from c3p import classifier as c3p_classifier
5+
6+
from chebifier.prediction_models import BasePredictor
7+
8+
9+
class C3PPredictor(BasePredictor):
10+
"""
11+
Wrapper for C3P (url).
12+
"""
13+
14+
def __init__(self, model_name: str, program_directory: Optional[Path]=None, chemical_classes: Optional[List[str]]=None, **kwargs):
15+
super().__init__(model_name, **kwargs)
16+
self.program_directory = program_directory
17+
self.chemical_classes = chemical_classes
18+
19+
def predict_smiles_list(self, smiles_list: list[str]) -> list:
20+
result_list = c3p_classifier.classify(smiles_list, self.program_directory, self.chemical_classes, strict=False)
21+
result_reformatted = [dict() for _ in range(len(smiles_list))]
22+
for result in result_list:
23+
result_reformatted[smiles_list.index(result.input_smiles)][result.class_id.split(":")[1]] = result.is_match
24+
print(f"C3P predictions for {len(smiles_list)} SMILES strings:")
25+
for i, smiles in enumerate(smiles_list):
26+
print(f"{smiles}: {result_reformatted[i]}")
27+
return result_reformatted

0 commit comments

Comments
 (0)