Skip to content

Commit a20ce76

Browse files
committed
store collated label or any model in instance var
1 parent 95d49c1 commit a20ce76

File tree

5 files changed

+49
-45
lines changed

5 files changed

+49
-45
lines changed

chebai/ensemble/_base.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from abc import ABC, abstractmethod
22
from collections import deque
33
from pathlib import Path
4-
from typing import Any, Deque, Dict, Literal, Optional
4+
from typing import Any, Deque, Dict, Optional
55

66
import pandas as pd
77
import torch
@@ -38,7 +38,6 @@ def __init__(
3838
Args:
3939
model_configs (Dict[str, Dict[str, Any]]): Dictionary of model configurations.
4040
data_processed_dir_main (str): Path to the processed data directory.
41-
reader_dir_name (str): Name of the directory used by the reader. Defaults to 'smiles_token'.
4241
**kwargs (Any): Additional arguments, such as 'input_dim' and '_perform_validation_checks'.
4342
"""
4443
if bool(kwargs.get("_perform_validation_checks", True)):
@@ -51,16 +50,15 @@ def __init__(
5150
self._operation: str = operation
5251
print(f"Ensemble operation: {self._operation}")
5352

54-
self._input_dim: Optional[int] = kwargs.get("input_dim", None)
55-
self._total_data_size: int = None
53+
# This instance variable is set in the method `_process_input_to_ensemble`
54+
self._total_data_size: int | None = None
5655
self._ensemble_input: list[str] | Path = self._process_input_to_ensemble(
5756
**kwargs
5857
)
5958
print(f"Total data size (data.pkl) is {self._total_data_size}")
6059

6160
self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6261

63-
self._models: Dict[str, LightningModule] = {}
6462
self._dm_labels: Dict[str, int] = self._load_data_module_labels()
6563
self._num_of_labels: int = len(self._dm_labels)
6664
print(f"Number of labes for this data is {self._num_of_labels} ")
@@ -69,6 +67,7 @@ def __init__(
6967
1, self._num_of_labels, device=self._device
7068
)
7169
self._model_queue: Deque[str] = deque()
70+
self._collated_labels: torch.Tensor | None = None
7271

7372
@classmethod
7473
def _perform_validation_checks(
@@ -126,10 +125,10 @@ def _perform_validation_checks(
126125
class_set.add(model_class_path)
127126
labels_set.add(model_labels_path)
128127

129-
def _process_input_to_ensemble(self, **kwargs: any) -> list[str] | Path:
128+
def _process_input_to_ensemble(self, **kwargs: Any) -> list[str] | Path:
130129
if self._operation == PRED_OP:
131130
p = Path(kwargs["smiles_list_file_path"])
132-
smiles_list = []
131+
smiles_list: list[str] = []
133132
with open(p, "r") as f:
134133
for line in f:
135134
# Skip empty or whitespace-only lines
@@ -140,11 +139,14 @@ def _process_input_to_ensemble(self, **kwargs: any) -> list[str] | Path:
140139
self._total_data_size = len(smiles_list)
141140
return smiles_list
142141
elif self._operation == EVAL_OP:
143-
data_pkl_path = Path(self._data_processed_dir_main) / "data.pkl"
142+
processed_dir_path = Path(self._data_processed_dir_main)
143+
data_pkl_path = processed_dir_path / "data.pkl"
144144
if not data_pkl_path.exists():
145-
raise FileNotFoundError()
145+
raise FileNotFoundError(
146+
f"data.pkl does not exist in the {processed_dir_path} directory"
147+
)
146148
self._total_data_size = len(pd.read_pickle(data_pkl_path))
147-
return p
149+
return processed_dir_path
148150
else:
149151
raise ValueError("Invalid operation")
150152

@@ -180,6 +182,9 @@ def run_ensemble(self) -> None:
180182
self._total_data_size, self._num_of_labels, device=self._device
181183
)
182184

185+
print(
186+
f"Running {self.__class__.__name__} ensemble for {self._operation} operation..."
187+
)
183188
while self._model_queue:
184189
model_name = self._model_queue.popleft()
185190
print(f"Processing model: {model_name}")
@@ -195,16 +200,17 @@ def run_ensemble(self) -> None:
195200
false_scores=false_scores,
196201
)
197202

198-
print(f"Consolidating predictions for {self.__class__.__name__}")
199203
final_preds = self._consolidate_on_finish(
200204
true_scores=true_scores, false_scores=false_scores
201205
)
202-
print_metrics(
203-
final_preds,
204-
self._collated_data.y,
205-
self._device,
206-
classes=list(self._dm_labels.keys()),
207-
)
206+
207+
if self._operation == EVAL_OP:
208+
print_metrics(
209+
final_preds,
210+
self._collated_labels,
211+
self._device,
212+
classes=list(self._dm_labels.keys()),
213+
)
208214

209215
@abstractmethod
210216
def _controller(

chebai/ensemble/_controller.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
from abc import ABC, abstractmethod
1+
from abc import ABC
22
from collections import deque
3+
from pathlib import Path
34
from typing import Any, Deque, Dict
45

56
import torch
67
from torch import Tensor
78

89
from ._base import EnsembleBase
910
from ._constants import EVAL_OP, PRED_OP, WRAPPER_CLS_PATH
10-
from ._utils import _load_class
11+
from ._utils import load_class
1112
from ._wrappers import BaseWrapper
1213

1314

@@ -30,9 +31,16 @@ def __init__(self, **kwargs: Any):
3031
"""
3132
super().__init__(**kwargs)
3233
self._kwargs = kwargs
34+
# When the model corresponding to an activation condition is added to the queue, it is removed from this set
35+
# This is in order to avoid re-adding models that have already been processed
36+
self._model_key_set: set[str] = set(self._model_configs.keys())
3337

34-
@abstractmethod
35-
def _controller(self, model_name, model_input, **kwargs: Any) -> Dict[str, Tensor]:
38+
# Labels from any processed data.pt file for any reader
39+
self._collated_labels: torch.Tensor | None = None
40+
41+
def _controller(
42+
self, model_name: str, model_input: list[str] | Path, **kwargs: Any
43+
) -> Dict[str, Tensor]:
3644
"""
3745
Performs inference with the model and extracts predictions and confidence values.
3846
@@ -82,14 +90,17 @@ def _get_pred_conf_from_model_output(
8290

8391
def _wrap_model(self, model_name: str) -> BaseWrapper:
8492
model_config = self._model_configs[model_name]
85-
wrp_cls = _load_class(model_config[WRAPPER_CLS_PATH])
93+
wrp_cls = load_class(model_config[WRAPPER_CLS_PATH])
8694
assert issubclass(wrp_cls, BaseWrapper), ""
8795
wrapped_model = wrp_cls(
8896
model_name=model_name,
8997
model_config=model_config,
9098
dm_labels=self._dm_labels,
9199
**self._kwargs
92100
)
101+
if self._collated_labels is not None and self._operation == EVAL_OP:
102+
self._collated_labels = wrapped_model.collated_labels
103+
93104
assert isinstance(wrapped_model, BaseWrapper), ""
94105
return wrapped_model
95106

@@ -110,19 +121,3 @@ def __init__(self, **kwargs: Any):
110121
"""
111122
super().__init__(**kwargs)
112123
self._model_queue: Deque[str] = deque(list(self._model_configs.keys()))
113-
114-
def _controller(self, model_name, model_input, **kwargs: Any) -> Dict[str, Tensor]:
115-
"""
116-
Performs inference with the model and extracts predictions and confidence values.
117-
118-
Args:
119-
model (ChebaiBaseNet): The model to perform inference with.
120-
model_props (Dict[str, Tensor]): Dictionary with label mask and trust scores.
121-
122-
Returns:
123-
Dict[str, Tensor]: Dictionary containing predictions and confidence scores.
124-
"""
125-
126-
output_dict = super()._controller(model_name, model_input, **kwargs)
127-
# Some activation condition can be applied, not in this controller, so we return the output directly
128-
return output_dict

chebai/ensemble/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import importlib
22

33

4-
def _load_class(class_path):
4+
def load_class(class_path: str) -> type:
55
module_path, class_name = class_path.rsplit(".", 1)
66
module = importlib.import_module(module_path)
77
return getattr(module, class_name)

chebai/ensemble/_wrappers/_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def __init__(
2222
self._model_class_path = self._model_config[MODEL_CLS_PATH]
2323
self._model_labels_path = self._model_config[MODEL_LBL_PATH]
2424
self._model_props = self._generate_model_label_props(dm_labels=dm_labels)
25+
self.collated_labels = None
2526

2627
def _generate_model_label_props(self, dm_labels) -> dict[str, torch.Tensor]:
2728
"""

chebai/ensemble/_wrappers/_neural_network.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77

88
from chebai.models import ChebaiBaseNet
99
from chebai.preprocessing.reader import DataReader
10+
from chebai.preprocessing.structures import XYData
1011

1112
from .._constants import MODEL_CKPT_PATH, READER_CLS_PATH, READER_KWARGS
12-
from .._utils import _load_class
13+
from .._utils import load_class
1314
from ._base import BaseWrapper
1415

1516

@@ -29,10 +30,11 @@ def __init__(self, **kwargs):
2930
else dict()
3031
)
3132

32-
reader_cls: Type[DataReader] = _load_class(self._reader_class_path)
33+
reader_cls: Type[DataReader] = load_class(self._reader_class_path)
3334
assert issubclass(reader_cls, DataReader), ""
3435
self._reader = reader_cls(**self._reader_kwargs)
3536
self._collator = reader_cls.COLLATOR()
37+
self.collated_labels = None
3638
self._model: ChebaiBaseNet = self._load_model_(
3739
input_dim=kwargs.get("input_dim", None)
3840
)
@@ -80,7 +82,7 @@ def _load_model_(self, input_dim: int | None) -> ChebaiBaseNet:
8082
f"Model path '{self._model_ckpt_path}' for '{self._model_name}' does not exist."
8183
)
8284

83-
lightning_cls = _load_class(self._model_class_path)
85+
lightning_cls = load_class(self._model_class_path)
8486

8587
assert isinstance(lightning_cls, type), f"{lightning_cls} is not a class."
8688
assert issubclass(
@@ -137,9 +139,9 @@ def _read_smiles(self, smiles):
137139
return self._reader.to_data(dict(features=smiles, labels=None))
138140

139141
def _forward_pass(self, batch):
140-
processable_data = self._model._process_batch( # noqa
141-
self._collator(batch).to(self._device), 0
142-
)
142+
collated_batch: XYData = self._collator(batch).to(self._device)
143+
self.collated_labels = collated_batch.y
144+
processable_data = self._model._process_batch(collated_batch, 0) # noqa
143145
return self._model(processable_data, **processable_data["model_kwargs"])
144146

145147
def _evaluate_from_data_file(

0 commit comments

Comments
 (0)