Skip to content

Commit 2fe8a84

Browse files
committed
use the generalized prediction pipeline
1 parent eb990ca commit 2fe8a84

File tree

5 files changed

+26
-240
lines changed

5 files changed

+26
-240
lines changed

chebifier/model_registry.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
ChEBILookupPredictor,
88
ChemlogPeptidesPredictor,
99
ElectraPredictor,
10-
ResGatedPredictor,
10+
GNNPredictor,
1111
)
1212
from chebifier.prediction_models.c3p_predictor import C3PPredictor
1313
from chebifier.prediction_models.chemlog_predictor import (
@@ -26,7 +26,7 @@
2626

2727
MODEL_TYPES = {
2828
"electra": ElectraPredictor,
29-
"resgated": ResGatedPredictor,
29+
"resgated": GNNPredictor,
3030
"gat": GATPredictor,
3131
"chemlog": ChemlogAllPredictor,
3232
"chemlog_peptides": ChemlogPeptidesPredictor,
@@ -38,6 +38,6 @@
3838

3939

4040
common_keys = MODEL_TYPES.keys() & ENSEMBLES.keys()
41-
assert (
42-
not common_keys
43-
), f"Overlapping keys between MODEL_TYPES and ENSEMBLES: {common_keys}"
41+
assert not common_keys, (
42+
f"Overlapping keys between MODEL_TYPES and ENSEMBLES: {common_keys}"
43+
)

chebifier/prediction_models/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
from .chebi_lookup import ChEBILookupPredictor
44
from .chemlog_predictor import ChemlogExtraPredictor, ChemlogPeptidesPredictor
55
from .electra_predictor import ElectraPredictor
6-
from .gnn_predictor import ResGatedPredictor
6+
from .gnn_predictor import GNNPredictor
77

88
__all__ = [
99
"BasePredictor",
1010
"ChemlogPeptidesPredictor",
1111
"ElectraPredictor",
12-
"ResGatedPredictor",
12+
"GNNPredictor",
1313
"ChEBILookupPredictor",
1414
"ChemlogExtraPredictor",
1515
"C3PPredictor",

chebifier/prediction_models/electra_predictor.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
1-
from typing import TYPE_CHECKING
2-
31
import numpy as np
42

53
from .nn_predictor import NNPredictor
64

7-
if TYPE_CHECKING:
8-
from chebai.models.electra import Electra
9-
105

116
def build_graph_from_attention(att, node_labels, token_labels, threshold=0.0):
127
n_nodes = len(node_labels)
@@ -40,37 +35,22 @@ def build_graph_from_attention(att, node_labels, token_labels, threshold=0.0):
4035

4136
class ElectraPredictor(NNPredictor):
4237
def __init__(self, model_name: str, ckpt_path: str, **kwargs):
43-
from chebai.preprocessing.reader import ChemDataReader
44-
45-
super().__init__(model_name, ckpt_path, reader_cls=ChemDataReader, **kwargs)
38+
super().__init__(model_name, ckpt_path, **kwargs)
4639
print(f"Initialised Electra model {self.model_name} (device: {self.device})")
4740

48-
def init_model(self, ckpt_path: str, **kwargs) -> "Electra":
49-
from chebai.models.electra import Electra
50-
51-
model = Electra.load_from_checkpoint(
52-
ckpt_path,
53-
map_location=self.device,
54-
criterion=None,
55-
strict=False,
56-
metrics=dict(train=dict(), test=dict(), validation=dict()),
57-
pretrained_checkpoint=None,
58-
)
59-
model.eval()
60-
return model
61-
6241
def explain_smiles(self, smiles) -> dict:
6342
from chebai.preprocessing.reader import EMBEDDING_OFFSET
6443

65-
reader = self.reader_cls()
66-
token_dict = reader.to_data(dict(features=smiles, labels=None))
44+
token_dict = self._predictor._dm.reader.to_data(
45+
dict(features=smiles, labels=None)
46+
)
6747
tokens = np.array(token_dict["features"]).astype(int).tolist()
6848
result = self.calculate_results([token_dict])
6949

7050
token_labels = (
7151
["[CLR]"]
7252
+ [None for _ in range(EMBEDDING_OFFSET - 1)]
73-
+ list(reader.cache.keys())
53+
+ list(self._predictor._dm.reader.cache.keys())
7454
)
7555

7656
graphs = [
Lines changed: 2 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -1,120 +1,12 @@
1-
from typing import TYPE_CHECKING, Optional
2-
3-
import torch
4-
51
from .nn_predictor import NNPredictor
62

7-
if TYPE_CHECKING:
8-
from chebai_graph.models.gat import GATGraphPred
9-
from chebai_graph.models.resgated import ResGatedGraphPred
103

11-
12-
class ResGatedPredictor(NNPredictor):
4+
class GNNPredictor(NNPredictor):
135
def __init__(
146
self,
157
model_name: str,
168
ckpt_path: str,
17-
molecular_properties,
18-
dataset_cls: Optional[str] = None,
199
**kwargs,
2010
):
21-
from chebai_graph.preprocessing.datasets.chebi import (
22-
ChEBI50GraphProperties,
23-
GraphPropertiesMixIn,
24-
)
25-
from chebai_graph.preprocessing.properties import MolecularProperty
26-
27-
# molecular_properties is a list of class paths
28-
if molecular_properties is not None:
29-
properties = [self.load_class(prop)() for prop in molecular_properties]
30-
properties = sorted(
31-
properties, key=lambda prop: f"{prop.name}_{prop.encoder.name}"
32-
)
33-
else:
34-
properties = []
35-
for property in properties:
36-
property.encoder.eval = True
37-
self.molecular_properties = properties
38-
assert isinstance(self.molecular_properties, list) and all(
39-
isinstance(prop, MolecularProperty) for prop in self.molecular_properties
40-
)
41-
# TODO it should not be necessary to refer to the whole dataset class, disentangle dataset and molecule reading
42-
self.dataset_cls = (
43-
self.load_class(dataset_cls)
44-
if dataset_cls is not None
45-
else ChEBI50GraphProperties
46-
)
47-
self.dataset: Optional[GraphPropertiesMixIn] = self.dataset_cls(
48-
properties=molecular_properties
49-
)
50-
51-
super().__init__(
52-
model_name, ckpt_path, reader_cls=self.dataset.READER, **kwargs
53-
)
54-
11+
super().__init__(model_name, ckpt_path, **kwargs)
5512
print(f"Initialised GNN model {self.model_name} (device: {self.device})")
56-
57-
def load_class(self, class_path: str):
58-
module_path, class_name = class_path.rsplit(".", 1)
59-
module = __import__(module_path, fromlist=[class_name])
60-
return getattr(module, class_name)
61-
62-
def init_model(self, ckpt_path: str, **kwargs) -> "ResGatedGraphPred":
63-
import torch
64-
from chebai_graph.models.resgated import ResGatedGraphPred
65-
66-
model = ResGatedGraphPred.load_from_checkpoint(
67-
ckpt_path,
68-
map_location=torch.device(self.device),
69-
criterion=None,
70-
strict=False,
71-
metrics=dict(train=dict(), test=dict(), validation=dict()),
72-
pretrained_checkpoint=None,
73-
)
74-
model.eval()
75-
return model
76-
77-
def read_smiles(self, smiles):
78-
from chebai_graph.preprocessing.datasets.chebi import GraphPropAsPerNodeType
79-
80-
d = self.dataset.READER().to_data(dict(features=smiles, labels=None))
81-
property_data = d
82-
# TODO merge props into base should not be a method of a dataset (or at least static)
83-
for property in self.dataset.properties:
84-
property.encoder.eval = True
85-
property_value = self.reader.read_property(smiles, property)
86-
if property_value is None or len(property_value) == 0:
87-
encoded_value = None
88-
else:
89-
encoded_value = torch.stack(
90-
[property.encoder.encode(v) for v in property_value]
91-
)
92-
if len(encoded_value.shape) == 3:
93-
encoded_value = encoded_value.squeeze(0)
94-
property_data[property.name] = encoded_value
95-
# Augmented graphs need an additional argument
96-
if isinstance(self.dataset, GraphPropAsPerNodeType):
97-
d["features"] = self.dataset._merge_props_into_base(
98-
property_data, max_len_node_properties=self.model.gnn.in_channels
99-
)
100-
else:
101-
d["features"] = self.dataset._merge_props_into_base(property_data)
102-
return d
103-
104-
105-
class GATPredictor(ResGatedPredictor):
106-
107-
def init_model(self, ckpt_path: str, **kwargs) -> "GATGraphPred":
108-
import torch
109-
from chebai_graph.models.gat import GATGraphPred
110-
111-
model = GATGraphPred.load_from_checkpoint(
112-
ckpt_path,
113-
map_location=torch.device(self.device),
114-
criterion=None,
115-
strict=False,
116-
metrics=dict(train=dict(), test=dict(), validation=dict()),
117-
pretrained_checkpoint=None,
118-
)
119-
model.eval()
120-
return model
Lines changed: 12 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,115 +1,29 @@
1-
import numpy as np
2-
import tqdm
3-
from rdkit import Chem
1+
from abc import ABC
2+
3+
from chebai.result.prediction import Predictor
44

55
from chebifier import modelwise_smiles_lru_cache
66

77
from .base_predictor import BasePredictor
88

99

10-
class NNPredictor(BasePredictor):
10+
class NNPredictor(BasePredictor, ABC):
1111
def __init__(
1212
self,
1313
model_name: str,
1414
ckpt_path: str,
15-
reader_cls,
16-
target_labels_path: str,
1715
**kwargs,
1816
):
19-
import torch
17+
self.batch_size = kwargs.get("batch_size", None)
18+
# If batch_size is not provided, it will be set to default batch size used during training in Predictor
19+
self._predictor: Predictor = Predictor(ckpt_path, self.batch_size)
2020

2121
super().__init__(model_name, **kwargs)
22-
self.reader_cls = reader_cls
23-
self.reader = reader_cls()
24-
25-
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26-
self.model = self.init_model(ckpt_path=ckpt_path)
27-
self.target_labels = [
28-
line.strip() for line in open(target_labels_path, encoding="utf-8")
29-
]
30-
self.batch_size = kwargs.get("batch_size", 1)
31-
32-
def init_model(self, ckpt_path: str, **kwargs):
33-
raise NotImplementedError(
34-
"Model initialization must be implemented in subclasses."
35-
)
36-
37-
def calculate_results(self, batch):
38-
collator = self.reader_cls.COLLATOR()
39-
dat = self.model._process_batch(collator(batch).to(self.device), 0)
40-
return self.model(dat, **dat["model_kwargs"])
41-
42-
def batchify(self, batch):
43-
cache = []
44-
for r in batch:
45-
cache.append(r)
46-
if len(cache) >= self.batch_size:
47-
yield cache
48-
cache = []
49-
if cache:
50-
yield cache
51-
52-
def read_smiles(self, smiles):
53-
d = self.reader.to_data(dict(features=smiles, labels=None))
54-
return d
5522

5623
@modelwise_smiles_lru_cache.batch_decorator
5724
def predict_smiles_list(self, smiles_list: list[str]) -> list:
58-
"""Returns a list with the length of smiles_list, each element is either None (=failure) or a dictionary
59-
Of classes and predicted values."""
60-
import torch
61-
62-
token_dicts = []
63-
could_not_parse = []
64-
index_map = dict()
65-
for i, smiles in enumerate(smiles_list):
66-
if not smiles:
67-
print(
68-
f"Model {self.model_name} received a missing SMILES string at position {i}."
69-
)
70-
could_not_parse.append(i)
71-
continue
72-
try:
73-
d = self.read_smiles(smiles)
74-
# This is just for sanity checks
75-
rdmol = Chem.MolFromSmiles(smiles, sanitize=False)
76-
if rdmol is None:
77-
print(
78-
f"Model {self.model_name} received a SMILES string RDKit can't read at position {i}: {smiles}"
79-
)
80-
could_not_parse.append(i)
81-
continue
82-
except Exception:
83-
could_not_parse.append(i)
84-
print(
85-
f"Model {self.model_name} failed to parse a SMILES string at position {i}: {smiles}"
86-
)
87-
continue
88-
index_map[i] = len(token_dicts)
89-
token_dicts.append(d)
90-
results = []
91-
if len(token_dicts) > 0:
92-
for batch in tqdm.tqdm(
93-
self.batchify(token_dicts),
94-
desc=f"{self.model_name}",
95-
total=len(token_dicts) // self.batch_size,
96-
):
97-
result = self.calculate_results(batch)
98-
if isinstance(result, dict) and "logits" in result:
99-
result = result["logits"]
100-
results += torch.sigmoid(result).cpu().detach().tolist()
101-
results = np.stack(results, axis=0)
102-
preds = [
103-
(
104-
{
105-
self.target_labels[j]: p
106-
for j, p in enumerate(results[index_map[i]])
107-
}
108-
if i not in could_not_parse
109-
else None
110-
)
111-
for i in range(len(smiles_list))
112-
]
113-
return preds
114-
else:
115-
return [None for _ in smiles_list]
25+
"""
26+
Returns a list with the length of smiles_list, each element is
27+
either None (=failure) or a dictionary of classes and predicted values.
28+
"""
29+
return self._predictor.predict_smiles(smiles_list)

0 commit comments

Comments
 (0)