Skip to content

Commit 95d49c1

Browse files
committed
separate methods for evaluation and prediction
1 parent 76d8a79 commit 95d49c1

File tree

5 files changed

+103
-46
lines changed

5 files changed

+103
-46
lines changed

chebai/ensemble/_base.py

Lines changed: 46 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,21 @@
11
from abc import ABC, abstractmethod
22
from collections import deque
33
from pathlib import Path
4-
from typing import Any, Deque, Dict, Optional
4+
from typing import Any, Deque, Dict, Literal, Optional
55

66
import pandas as pd
77
import torch
88
from lightning import LightningModule
99

1010
from chebai.result.classification import print_metrics
1111

12-
from ._constants import MODEL_CLS_PATH, MODEL_LBL_PATH, WRAPPER_CLS_PATH
12+
from ._constants import (
13+
EVAL_OP,
14+
MODEL_CLS_PATH,
15+
MODEL_LBL_PATH,
16+
PRED_OP,
17+
WRAPPER_CLS_PATH,
18+
)
1319

1420

1521
class EnsembleBase(ABC):
@@ -22,38 +28,40 @@ class EnsembleBase(ABC):
2228
def __init__(
2329
self,
2430
model_configs: Dict[str, Dict[str, Any]],
25-
data_file_path: str,
26-
classes_file_path: str,
31+
data_processed_dir_main: str,
32+
operation: str = EVAL_OP,
2733
**kwargs: Any,
2834
) -> None:
2935
"""
3036
Initializes the ensemble model and loads configurations, labels, and sets up the environment.
3137
3238
Args:
3339
model_configs (Dict[str, Dict[str, Any]]): Dictionary of model configurations.
34-
data_file_path (str): Path to the processed data directory.
40+
data_processed_dir_main (str): Path to the processed data directory.
3541
reader_dir_name (str): Name of the directory used by the reader. Defaults to 'smiles_token'.
3642
**kwargs (Any): Additional arguments, such as 'input_dim' and '_validate_configs'.
3743
"""
38-
if bool(kwargs.get("_validate_configs", True)):
39-
self._validate_model_configs(model_configs)
44+
if bool(kwargs.get("_perform_validation_checks", True)):
45+
self._perform_validation_checks(
46+
model_configs, operation=operation, **kwargs
47+
)
4048

4149
self._model_configs: Dict[str, Dict[str, Any]] = model_configs
42-
self._data_file_path: str = data_file_path
43-
self._classes_file_path: str = classes_file_path
50+
self._data_processed_dir_main: str = data_processed_dir_main
51+
self._operation: str = operation
52+
print(f"Ensemble operation: {self._operation}")
53+
4454
self._input_dim: Optional[int] = kwargs.get("input_dim", None)
4555
self._total_data_size: int = None
4656
self._ensemble_input: list[str] | Path = self._process_input_to_ensemble(
47-
data_file_path
57+
**kwargs
4858
)
4959
print(f"Total data size (data.pkl) is {self._total_data_size}")
5060

5161
self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5262

5363
self._models: Dict[str, LightningModule] = {}
54-
self._dm_labels: Dict[str, int] = self._load_data_module_labels(
55-
classes_file_path
56-
)
64+
self._dm_labels: Dict[str, int] = self._load_data_module_labels()
5765
self._num_of_labels: int = len(self._dm_labels)
5866
print(f"Number of labes for this data is {self._num_of_labels} ")
5967

@@ -63,7 +71,9 @@ def __init__(
6371
self._model_queue: Deque[str] = deque()
6472

6573
@classmethod
66-
def _validate_model_configs(cls, model_configs: Dict[str, Dict[str, Any]]) -> None:
74+
def _perform_validation_checks(
75+
cls, model_configs: Dict[str, Dict[str, Any]], operation, **kwargs
76+
) -> None:
6777
"""
6878
Validates model configuration dictionary for required keys and uniqueness.
6979
@@ -74,6 +84,19 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict[str, Any]]) -> No
7484
AttributeError: If any model config is missing required keys.
7585
ValueError: If duplicate paths are found for model checkpoint, class, or labels.
7686
"""
87+
if operation not in ["evaluate", "predict"]:
88+
raise ValueError(
89+
f"Invalid operation '{operation}'. Must be 'evaluate' or 'predict'."
90+
)
91+
92+
if operation == "predict" and not kwargs.get("smiles_list_file_path", None):
93+
raise ValueError(
94+
"For 'predict' operation, 'smiles_list_file_path' must be provided."
95+
)
96+
97+
if not Path(kwargs.get("smiles_list_file_path")).exists():
98+
raise FileNotFoundError(f"{kwargs.get('smiles_list_file_path')}")
99+
77100
class_set, labels_set = set(), set()
78101
required_keys = {
79102
MODEL_CLS_PATH,
@@ -103,9 +126,9 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict[str, Any]]) -> No
103126
class_set.add(model_class_path)
104127
labels_set.add(model_labels_path)
105128

106-
def _process_input_to_ensemble(self, path: str):
107-
p = Path(path)
108-
if p.is_file():
129+
def _process_input_to_ensemble(self, **kwargs: any) -> list[str] | Path:
130+
if self._operation == PRED_OP:
131+
p = Path(kwargs["smiles_list_file_path"])
109132
smiles_list = []
110133
with open(p, "r") as f:
111134
for line in f:
@@ -116,24 +139,23 @@ def _process_input_to_ensemble(self, path: str):
116139
smiles_list.append(smiles)
117140
self._total_data_size = len(smiles_list)
118141
return smiles_list
119-
elif p.is_dir():
120-
data_pkl_path = p / "data.pkl"
142+
elif self._operation == EVAL_OP:
143+
data_pkl_path = Path(self._data_processed_dir_main) / "data.pkl"
121144
if not data_pkl_path.exists():
122145
raise FileNotFoundError()
123146
self._total_data_size = len(pd.read_pickle(data_pkl_path))
124147
return p
125148
else:
126-
raise "Invalid path"
149+
raise ValueError("Invalid operation")
127150

128-
@staticmethod
129-
def _load_data_module_labels(classes_file_path: str) -> dict[str, int]:
151+
def _load_data_module_labels(self) -> dict[str, int]:
130152
"""
131153
Loads class labels from the classes.txt file and sets internal label mapping.
132154
133155
Raises:
134156
FileNotFoundError: If the expected classes.txt file is not found.
135157
"""
136-
classes_file_path = Path(classes_file_path)
158+
classes_file_path = Path(self._data_processed_dir_main) / "classes.txt"
137159
if not classes_file_path.exists():
138160
raise FileNotFoundError(f"{classes_file_path} does not exist")
139161
print(f"Loading {classes_file_path} ....")
@@ -197,14 +219,13 @@ def _controller(
197219
Returns:
198220
Dict[str, torch.Tensor]: Predictions or confidence scores.
199221
"""
200-
pass
201222

202223
@abstractmethod
203224
def _consolidator(
204225
self,
226+
*,
205227
pred_conf_dict: Dict[str, torch.Tensor],
206228
model_props: Dict[str, torch.Tensor],
207-
*,
208229
true_scores: torch.Tensor,
209230
false_scores: torch.Tensor,
210231
**kwargs: Any,
@@ -214,7 +235,6 @@ def _consolidator(
214235
215236
Should update the provided `true_scores` and `false_scores`.
216237
"""
217-
pass
218238

219239
@abstractmethod
220240
def _consolidate_on_finish(
@@ -226,4 +246,3 @@ def _consolidate_on_finish(
226246
Returns:
227247
torch.Tensor: Final aggregated predictions.
228248
"""
229-
pass

chebai/ensemble/_constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,7 @@
66

77
READER_CLS_PATH = "reader_class_path"
88
READER_KWARGS = "reader_kwargs"
9+
10+
11+
PRED_OP = "prediction"
12+
EVAL_OP = "evaluation"

chebai/ensemble/_controller.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
1-
from abc import ABC
1+
from abc import ABC, abstractmethod
22
from collections import deque
33
from typing import Any, Deque, Dict
44

55
import torch
66
from torch import Tensor
77

8-
from chebai.models import ChebaiBaseNet
9-
108
from ._base import EnsembleBase
11-
from ._constants import WRAPPER_CLS_PATH
9+
from ._constants import EVAL_OP, PRED_OP, WRAPPER_CLS_PATH
1210
from ._utils import _load_class
1311
from ._wrappers import BaseWrapper
1412

@@ -33,6 +31,30 @@ def __init__(self, **kwargs: Any):
3331
super().__init__(**kwargs)
3432
self._kwargs = kwargs
3533

34+
@abstractmethod
35+
def _controller(self, model_name, model_input, **kwargs: Any) -> Dict[str, Tensor]:
36+
"""
37+
Performs inference with the model and extracts predictions and confidence values.
38+
39+
Args:
40+
model (ChebaiBaseNet): The model to perform inference with.
41+
model_props (Dict[str, Tensor]): Dictionary with label mask and trust scores.
42+
43+
Returns:
44+
Dict[str, Tensor]: Dictionary containing predictions and confidence scores.
45+
"""
46+
wrapped_model = self._wrap_model(model_name)
47+
if self._operation == PRED_OP:
48+
model_output, model_props = wrapped_model.predict(model_input)
49+
else:
50+
model_output, model_props = wrapped_model.evaluate(model_input)
51+
del wrapped_model # Model can be huge to keep it in memory, delete asap as no longer needed
52+
53+
pred_conf_dict = self._get_pred_conf_from_model_output(
54+
model_output, model_props["mask"]
55+
)
56+
return {"pred_conf_dict": pred_conf_dict, "model_props": model_props}
57+
3658
def _get_pred_conf_from_model_output(
3759
self, model_output: Dict[str, Tensor], model_label_mask: Tensor
3860
) -> Dict[str, Tensor]:
@@ -100,10 +122,7 @@ def _controller(self, model_name, model_input, **kwargs: Any) -> Dict[str, Tenso
100122
Returns:
101123
Dict[str, Tensor]: Dictionary containing predictions and confidence scores.
102124
"""
103-
wrapped_model = self._wrap_model(model_name)
104-
model_output, model_props = wrapped_model.predict(model_input)
105-
del wrapped_model # Model can be huge to keep it in memory, delete asap as no longer needed
106-
pred_conf_dict = self._get_pred_conf_from_model_output(
107-
model_output, model_props["mask"]
108-
)
109-
return {"pred_conf_dict": pred_conf_dict, "model_props": model_props}
125+
126+
output_dict = super()._controller(model_name, model_input, **kwargs)
127+
# Some activation condition can be applied, not in this controller, so we return the output directly
128+
return output_dict

chebai/ensemble/_wrappers/_base.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,22 @@ def name(self):
114114
return f"Wrapper({self.__class__.__name__}) for model: {self._model_name}"
115115

116116
def predict(self, x: list) -> tuple[dict, dict]:
117+
if not isinstance(x, list):
118+
raise TypeError(f"Input must be a list of SMILES strings, got {type(x)}")
117119
return self._predict_from_list_of_smiles(x), self._model_props
118120

119121
@abstractmethod
120122
def _predict_from_list_of_smiles(self, smiles_list: list) -> dict: ...
121123

122-
def evaluate(self, data_processed_dir_main: Path) -> tuple[dict, dict]:
123-
return self._evaluate_from_data_file(data_processed_dir_main), self._model_props
124+
def evaluate(
125+
self, data_processed_dir_main: Path, **kwargs: any
126+
) -> tuple[dict, dict]:
127+
if not data_processed_dir_main.is_dir():
128+
raise NotADirectoryError(f"{data_processed_dir_main} is not a directory.")
129+
return (
130+
self._evaluate_from_data_file(data_processed_dir_main, **kwargs),
131+
self._model_props,
132+
)
124133

125134
@abstractmethod
126135
def _evaluate_from_data_file(self, data_file_path: str) -> dict: ...

chebai/ensemble/_wrappers/_neural_network.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from pathlib import Path
23
from typing import Type
34

45
import torch
@@ -90,13 +91,18 @@ def _load_model_(self, input_dim: int | None) -> ChebaiBaseNet:
9091
self._model_ckpt_path, input_dim=5
9192
)
9293
except Exception as e:
93-
raise RuntimeError(f"Error loading model {self._model_name} \n Error: {e}")
94+
raise RuntimeError(
95+
f"Error loading model {self._model_name} \n Error: {e}"
96+
) from e
9497

98+
assert isinstance(
99+
model, ChebaiBaseNet
100+
), f"{model} is not a ChebaiBaseNet instance."
95101
model.eval()
96102
model.freeze()
97103
return model
98104

99-
def _predict_from_list_of_smiles(self, smiles_list) -> list:
105+
def _predict_from_list_of_smiles(self, smiles_list: list[str]) -> list:
100106
token_dicts = []
101107
could_not_parse = []
102108
index_map = dict()
@@ -131,16 +137,16 @@ def _read_smiles(self, smiles):
131137
return self._reader.to_data(dict(features=smiles, labels=None))
132138

133139
def _forward_pass(self, batch):
134-
processable_data = self._model._process_batch(
140+
processable_data = self._model._process_batch( # noqa
135141
self._collator(batch).to(self._device), 0
136142
)
137143
return self._model(processable_data, **processable_data["model_kwargs"])
138144

139-
def _predict_from_data_file(
140-
self, processed_dir_main: str, data_file_name="data.pt"
145+
def _evaluate_from_data_file(
146+
self, data_processed_dir_main: Path, data_file_name="data.pt"
141147
) -> list:
142148
data = torch.load(
143-
os.path.join(processed_dir_main, self._reader.name(), data_file_name),
149+
data_processed_dir_main / self._reader.name() / data_file_name,
144150
weights_only=False,
145151
map_location=self._device,
146152
)

0 commit comments

Comments
 (0)