Skip to content

Commit 76d8a79

Browse files
committed
predict method implementation for data file and list of smiles
1 parent bf3cf64 commit 76d8a79

File tree

4 files changed

+93
-80
lines changed

4 files changed

+93
-80
lines changed

chebai/ensemble/_base.py

Lines changed: 53 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,15 @@
1-
import importlib
2-
import json
3-
import os
41
from abc import ABC, abstractmethod
52
from collections import deque
6-
from typing import Any, Deque, Dict, Optional, Tuple
3+
from pathlib import Path
4+
from typing import Any, Deque, Dict, Optional
75

6+
import pandas as pd
87
import torch
98
from lightning import LightningModule
109

11-
from chebai.models import ChebaiBaseNet
12-
from chebai.preprocessing.structures import XYData
1310
from chebai.result.classification import print_metrics
1411

15-
from ._constants import (
16-
MODEL_CKPT_PATH,
17-
MODEL_CLS_PATH,
18-
MODEL_LBL_PATH,
19-
READER_CLS_PATH,
20-
WRAPPER_CLS_PATH,
21-
)
22-
from ._utils import _load_class
23-
from ._wrappers import BaseWrapper
12+
from ._constants import MODEL_CLS_PATH, MODEL_LBL_PATH, WRAPPER_CLS_PATH
2413

2514

2615
class EnsembleBase(ABC):
@@ -33,38 +22,45 @@ class EnsembleBase(ABC):
3322
def __init__(
3423
self,
3524
model_configs: Dict[str, Dict[str, Any]],
36-
data_processed_dir_main: str,
25+
data_file_path: str,
26+
classes_file_path: str,
3727
**kwargs: Any,
3828
) -> None:
3929
"""
4030
Initializes the ensemble model and loads configurations, labels, and sets up the environment.
4131
4232
Args:
4333
model_configs (Dict[str, Dict[str, Any]]): Dictionary of model configurations.
44-
data_processed_dir_main (str): Path to the processed data directory.
34+
data_file_path (str): Path to the processed data directory.
4535
reader_dir_name (str): Name of the directory used by the reader. Defaults to 'smiles_token'.
4636
**kwargs (Any): Additional arguments, such as 'input_dim' and '_validate_configs'.
4737
"""
4838
if bool(kwargs.get("_validate_configs", True)):
4939
self._validate_model_configs(model_configs)
5040

5141
self._model_configs: Dict[str, Dict[str, Any]] = model_configs
52-
self._data_processed_dir_main: str = data_processed_dir_main
42+
self._data_file_path: str = data_file_path
43+
self._classes_file_path: str = classes_file_path
5344
self._input_dim: Optional[int] = kwargs.get("input_dim", None)
54-
self._total_data_size: int = len(self._collated_data)
45+
self._total_data_size: int = None
46+
self._ensemble_input: list[str] | Path = self._process_input_to_ensemble(
47+
data_file_path
48+
)
49+
print(f"Total data size (data.pkl) is {self._total_data_size}")
5550

5651
self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5752

5853
self._models: Dict[str, LightningModule] = {}
59-
self._dm_labels: Dict[str, int] = self._load_data_module_labels()
54+
self._dm_labels: Dict[str, int] = self._load_data_module_labels(
55+
classes_file_path
56+
)
6057
self._num_of_labels: int = len(self._dm_labels)
58+
print(f"Number of labes for this data is {self._num_of_labels} ")
6159

6260
self._num_models_per_label: torch.Tensor = torch.zeros(
6361
1, self._num_of_labels, device=self._device
6462
)
6563
self._model_queue: Deque[str] = deque()
66-
self._collated_data: Optional[XYData] = None
67-
self._total_data_size: Optional[int] = None
6864

6965
@classmethod
7066
def _validate_model_configs(cls, model_configs: Dict[str, Dict[str, Any]]) -> None:
@@ -107,21 +103,43 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict[str, Any]]) -> No
107103
class_set.add(model_class_path)
108104
labels_set.add(model_labels_path)
109105

110-
def _load_data_module_labels(self) -> dict[str, int]:
106+
def _process_input_to_ensemble(self, path: str):
107+
p = Path(path)
108+
if p.is_file():
109+
smiles_list = []
110+
with open(p, "r") as f:
111+
for line in f:
112+
# Skip empty or whitespace-only lines
113+
if line.strip():
114+
# Split on whitespace and take the first item as the SMILES
115+
smiles = line.strip().split()[0]
116+
smiles_list.append(smiles)
117+
self._total_data_size = len(smiles_list)
118+
return smiles_list
119+
elif p.is_dir():
120+
data_pkl_path = p / "data.pkl"
121+
if not data_pkl_path.exists():
122+
raise FileNotFoundError()
123+
self._total_data_size = len(pd.read_pickle(data_pkl_path))
124+
return p
125+
else:
126+
raise "Invalid path"
127+
128+
@staticmethod
129+
def _load_data_module_labels(classes_file_path: str) -> dict[str, int]:
111130
"""
112131
Loads class labels from the classes.txt file and sets internal label mapping.
113132
114133
Raises:
115134
FileNotFoundError: If the expected classes.txt file is not found.
116135
"""
117-
classes_txt_file = os.path.join(self._data_processed_dir_main, "classes.txt")
118-
print(f"Loading {classes_txt_file} ....")
119-
120-
if not os.path.exists(classes_txt_file):
121-
raise FileNotFoundError(f"{classes_txt_file} does not exist")
136+
classes_file_path = Path(classes_file_path)
137+
if not classes_file_path.exists():
138+
raise FileNotFoundError(f"{classes_file_path} does not exist")
139+
print(f"Loading {classes_file_path} ....")
122140

123141
dm_labels_dict = {}
124-
with open(classes_txt_file, "r") as f:
142+
with open(classes_file_path, "r") as f:
125143
for line in f:
126144
label = line.strip()
127145
if label not in dm_labels_dict:
@@ -132,6 +150,7 @@ def run_ensemble(self) -> None:
132150
"""
133151
Executes the full ensemble prediction pipeline, aggregating predictions and printing metrics.
134152
"""
153+
assert self._total_data_size is not None and self._num_of_labels is not None
135154
true_scores = torch.zeros(
136155
self._total_data_size, self._num_of_labels, device=self._device
137156
)
@@ -144,12 +163,12 @@ def run_ensemble(self) -> None:
144163
print(f"Processing model: {model_name}")
145164

146165
print("\t Passing model to controller to generate predictions...")
147-
pred_conf_dict, model_props = self._controller(model_name)
166+
controller_output = self._controller(model_name, self._ensemble_input)
148167

149168
print("\t Passing predictions to consolidator for aggregation...")
150169
self._consolidator(
151-
pred_conf_dict,
152-
model_props,
170+
pred_conf_dict=controller_output["pred_conf_dict"],
171+
model_props=controller_output["model_props"],
153172
true_scores=true_scores,
154173
false_scores=false_scores,
155174
)
@@ -168,8 +187,8 @@ def run_ensemble(self) -> None:
168187
@abstractmethod
169188
def _controller(
170189
self,
171-
model: LightningModule,
172-
model_props: Dict[str, torch.Tensor],
190+
model_name: str,
191+
model_input: list[str] | Path,
173192
**kwargs: Any,
174193
) -> Dict[str, torch.Tensor]:
175194
"""

chebai/ensemble/_controller.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
1-
import os.path
21
from abc import ABC
32
from collections import deque
43
from typing import Any, Deque, Dict
54

65
import torch
7-
from lightning import LightningModule
86
from torch import Tensor
97

108
from chebai.models import ChebaiBaseNet
11-
from chebai.preprocessing.collate import RaggedCollator
129

1310
from ._base import EnsembleBase
1411
from ._constants import WRAPPER_CLS_PATH
@@ -72,7 +69,6 @@ def _wrap_model(self, model_name: str) -> BaseWrapper:
7269
**self._kwargs
7370
)
7471
assert isinstance(wrapped_model, BaseWrapper), ""
75-
# del wrapped_model # Model can be huge to keep it in memory, delete as no longer needed
7672
return wrapped_model
7773

7874

@@ -93,7 +89,7 @@ def __init__(self, **kwargs: Any):
9389
super().__init__(**kwargs)
9490
self._model_queue: Deque[str] = deque(list(self._model_configs.keys()))
9591

96-
def _controller(self, model_name, **kwargs: Any) -> Dict[str, Tensor]:
92+
def _controller(self, model_name, model_input, **kwargs: Any) -> Dict[str, Tensor]:
9793
"""
9894
Performs inference with the model and extracts predictions and confidence values.
9995
@@ -105,4 +101,9 @@ def _controller(self, model_name, **kwargs: Any) -> Dict[str, Tensor]:
105101
Dict[str, Tensor]: Dictionary containing predictions and confidence scores.
106102
"""
107103
wrapped_model = self._wrap_model(model_name)
108-
return self._get_pred_conf_from_model_output(model_output, model_props["mask"])
104+
model_output, model_props = wrapped_model.predict(model_input)
105+
del wrapped_model # Model can be huge to keep it in memory, delete asap as no longer needed
106+
pred_conf_dict = self._get_pred_conf_from_model_output(
107+
model_output, model_props["mask"]
108+
)
109+
return {"pred_conf_dict": pred_conf_dict, "model_props": model_props}

chebai/ensemble/_wrappers/_base.py

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
import importlib
21
import json
32
import os
43
from abc import ABC, abstractmethod
5-
from typing import overload
4+
from pathlib import Path
65

76
import torch
87

@@ -22,10 +21,9 @@ def __init__(
2221
self._model_name = model_name
2322
self._model_class_path = self._model_config[MODEL_CLS_PATH]
2423
self._model_labels_path = self._model_config[MODEL_LBL_PATH]
25-
self._dm_labels: dict[str, int] = dm_labels
26-
self._model_props = self._generate_model_label_props()
24+
self._model_props = self._generate_model_label_props(dm_labels=dm_labels)
2725

28-
def _generate_model_label_props(self) -> dict[str, torch.Tensor]:
26+
def _generate_model_label_props(self, dm_labels) -> dict[str, torch.Tensor]:
2927
"""
3028
Generates label mask and confidence tensors (TPV, FPV) for a model.
3129
@@ -38,13 +36,15 @@ def _generate_model_label_props(self) -> dict[str, torch.Tensor]:
3836
model_label_indices, tpv_label_values, fpv_label_values = [], [], []
3937

4038
for label, props in labels_dict.items():
41-
if label in self._dm_labels:
39+
if label in dm_labels:
4240
try:
4341
self._validate_model_labels_json_element(labels_dict[label])
4442
except Exception as e:
45-
raise Exception(f"Label '{label}' has an unexpected error") from e
43+
raise Exception(
44+
f"Label '{label}' has an unexpected error \n Error: {e}"
45+
)
4646

47-
model_label_indices.append(self._dm_labels[label])
47+
model_label_indices.append(dm_labels[label])
4848
tpv_label_values.append(props["TPV"])
4949
fpv_label_values.append(props["FPV"])
5050

@@ -54,7 +54,7 @@ def _generate_model_label_props(self) -> dict[str, torch.Tensor]:
5454
)
5555

5656
# Create masks to apply predictions only to known classes
57-
mask = torch.zeros(len(self._dm_labels), dtype=torch.bool, device=self._device)
57+
mask = torch.zeros(len(dm_labels), dtype=torch.bool, device=self._device)
5858
mask[torch.tensor(model_label_indices, device=self._device)] = True
5959

6060
tpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self._device)
@@ -113,26 +113,14 @@ def _validate_model_labels_json_element(label_dict: dict[str, float]) -> None:
113113
def name(self):
114114
return f"Wrapper({self.__class__.__name__}) for model: {self._model_name}"
115115

116-
@overload
117-
def predict(self, smiles_list: list) -> tuple[dict, dict]:
118-
pass
119-
120-
@overload
121-
def predict(self, data_file_path: str) -> tuple[dict, dict]:
122-
pass
123-
124-
def predict(self, x: list | str) -> tuple[dict, dict]:
125-
if isinstance(x, list):
126-
return self._predict_from_list_of_smiles(x), self._model_props
127-
elif isinstance(x, str):
128-
return self._predict_from_data_file(x), self._model_props
129-
else:
130-
raise TypeError(f"Type {type(x)} is not supported.")
116+
def predict(self, x: list) -> tuple[dict, dict]:
117+
return self._predict_from_list_of_smiles(x), self._model_props
131118

132119
@abstractmethod
133-
def _predict_from_list_of_smiles(self, smiles_list: list) -> dict:
134-
pass
120+
def _predict_from_list_of_smiles(self, smiles_list: list) -> dict: ...
121+
122+
def evaluate(self, data_processed_dir_main: Path) -> tuple[dict, dict]:
123+
return self._evaluate_from_data_file(data_processed_dir_main), self._model_props
135124

136125
@abstractmethod
137-
def _predict_from_data_file(self, data_file_path: str) -> dict:
138-
pass
126+
def _evaluate_from_data_file(self, data_file_path: str) -> dict: ...

chebai/ensemble/_wrappers/_neural_network.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import Optional, Type
2+
from typing import Type
33

44
import torch
55
from rdkit import Chem
@@ -15,7 +15,9 @@
1515
class NNWrapper(BaseWrapper):
1616

1717
def __init__(self, **kwargs):
18-
self._validate_model_configs(**kwargs)
18+
self._validate_model_configs(
19+
model_config=kwargs["model_config"], model_name=kwargs["model_name"]
20+
)
1921
super().__init__(**kwargs)
2022

2123
self._model_ckpt_path = self._model_config[MODEL_CKPT_PATH]
@@ -30,11 +32,15 @@ def __init__(self, **kwargs):
3032
assert issubclass(reader_cls, DataReader), ""
3133
self._reader = reader_cls(**self._reader_kwargs)
3234
self._collator = reader_cls.COLLATOR()
33-
self._model: ChebaiBaseNet = self._load_model_()
35+
self._model: ChebaiBaseNet = self._load_model_(
36+
input_dim=kwargs.get("input_dim", None)
37+
)
3438

3539
@classmethod
3640
def _validate_model_configs(
37-
cls, model_config: dict[str, str], model_name: str
41+
cls,
42+
model_config: dict[str, str],
43+
model_name: str,
3844
) -> None:
3945
"""
4046
Validates model configuration dictionary for required keys and uniqueness.
@@ -57,12 +63,12 @@ def _validate_model_configs(
5763
f"Missing keys {missing_keys} in model '{model_name}' configuration."
5864
)
5965

60-
def _load_model_(self) -> ChebaiBaseNet:
66+
def _load_model_(self, input_dim: int | None) -> ChebaiBaseNet:
6167
"""
6268
Loads a model checkpoint and its label-related properties.
6369
6470
Args:
65-
model_name (str): Name of the model to load.
71+
input_dim (int): Name of the model to load.
6672
6773
Returns:
6874
Tuple[LightningModule, Dict[str, torch.Tensor]]: The model and its label properties.
@@ -73,22 +79,21 @@ def _load_model_(self) -> ChebaiBaseNet:
7379
f"Model path '{self._model_ckpt_path}' for '{self._model_name}' does not exist."
7480
)
7581

76-
lightning_cls = self._load_class(self._model_class_path)
82+
lightning_cls = _load_class(self._model_class_path)
7783

7884
assert isinstance(lightning_cls, type), f"{lightning_cls} is not a class."
7985
assert issubclass(
8086
lightning_cls, ChebaiBaseNet
8187
), f"{lightning_cls} must inherit from ChebaiBaseNet"
82-
8388
try:
8489
model = lightning_cls.load_from_checkpoint(
85-
self._model_ckpt_path, input_dim=self.input_dim
90+
self._model_ckpt_path, input_dim=5
8691
)
87-
model.eval()
88-
model.freeze()
8992
except Exception as e:
90-
raise RuntimeError(f"Error loading model {self._model_name}") from e
93+
raise RuntimeError(f"Error loading model {self._model_name} \n Error: {e}")
9194

95+
model.eval()
96+
model.freeze()
9297
return model
9398

9499
def _predict_from_list_of_smiles(self, smiles_list) -> list:

0 commit comments

Comments
 (0)