Skip to content

Commit c48bfd2

Browse files
committed
update base for wrapper
1 parent f812cd7 commit c48bfd2

File tree

1 file changed

+30
-31
lines changed

1 file changed

+30
-31
lines changed

chebai/ensemble/_base.py

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,20 @@
77

88
import torch
99
from lightning import LightningModule
10-
from lightning_utilities.core.rank_zero import rank_zero_info
1110

1211
from chebai.models import ChebaiBaseNet
1312
from chebai.preprocessing.structures import XYData
1413
from chebai.result.classification import print_metrics
1514

16-
from ._constants import *
15+
from ._constants import (
16+
MODEL_CKPT_PATH,
17+
MODEL_CLS_PATH,
18+
MODEL_LBL_PATH,
19+
READER_CLS_PATH,
20+
WRAPPER_CLS_PATH,
21+
)
22+
from ._utils import _load_class
23+
from ._wrappers import BaseWrapper
1724

1825

1926
class EnsembleBase(ABC):
@@ -41,18 +48,17 @@ def __init__(
4148
if bool(kwargs.get("_validate_configs", True)):
4249
self._validate_model_configs(model_configs)
4350

44-
self.model_configs: Dict[str, Dict[str, Any]] = model_configs
45-
self.data_processed_dir_main: str = data_processed_dir_main
46-
self.input_dim: Optional[int] = kwargs.get("input_dim", None)
51+
self._model_configs: Dict[str, Dict[str, Any]] = model_configs
52+
self._data_processed_dir_main: str = data_processed_dir_main
53+
self._input_dim: Optional[int] = kwargs.get("input_dim", None)
54+
self._total_data_size: int = len(self._collated_data)
4755

4856
self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49-
self._num_of_labels: Optional[int] = (
50-
None # will be set by `_load_data_module_labels` method
51-
)
57+
5258
self._models: Dict[str, LightningModule] = {}
53-
self._dm_labels: Dict[str, int] = {}
59+
self._dm_labels: Dict[str, int] = self._load_data_module_labels()
60+
self._num_of_labels: int = len(self._dm_labels)
5461

55-
self._load_data_module_labels()
5662
self._num_models_per_label: torch.Tensor = torch.zeros(
5763
1, self._num_of_labels, device=self._device
5864
)
@@ -72,13 +78,11 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict[str, Any]]) -> No
7278
AttributeError: If any model config is missing required keys.
7379
ValueError: If duplicate paths are found for model checkpoint, class, or labels.
7480
"""
75-
path_set, class_set, labels_set = set(), set(), set()
81+
class_set, labels_set = set(), set()
7682
required_keys = {
77-
MODEL_CKPT_PATH,
7883
MODEL_CLS_PATH,
7984
MODEL_LBL_PATH,
8085
WRAPPER_CLS_PATH,
81-
READER_CLS_PATH,
8286
}
8387

8488
for model_name, config in model_configs.items():
@@ -88,44 +92,41 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict[str, Any]]) -> No
8892
f"Missing keys {missing_keys} in model '{model_name}' configuration."
8993
)
9094

91-
model_ckpt_path, model_class_path, model_labels_path = (
92-
config[MODEL_CKPT_PATH],
95+
model_class_path, model_labels_path = (
9396
config[MODEL_CLS_PATH],
9497
config[MODEL_LBL_PATH],
9598
)
9699

97-
if model_ckpt_path in path_set:
98-
raise ValueError(f"Duplicate model path detected: '{model_ckpt_path}'.")
99100
if model_class_path in class_set:
100101
raise ValueError(
101102
f"Duplicate class path detected: '{model_class_path}'."
102103
)
103104
if model_labels_path in labels_set:
104105
raise ValueError(f"Duplicate labels path: {model_labels_path}.")
105106

106-
path_set.add(model_ckpt_path)
107107
class_set.add(model_class_path)
108108
labels_set.add(model_labels_path)
109109

110-
def _load_data_module_labels(self) -> None:
110+
def _load_data_module_labels(self) -> dict[str, int]:
111111
"""
112112
Loads class labels from the classes.txt file and sets internal label mapping.
113113
114114
Raises:
115115
FileNotFoundError: If the expected classes.txt file is not found.
116116
"""
117-
classes_txt_file = os.path.join(self.data_processed_dir_main, "classes.txt")
118-
rank_zero_info(f"Loading {classes_txt_file} ....")
117+
classes_txt_file = os.path.join(self._data_processed_dir_main, "classes.txt")
118+
print(f"Loading {classes_txt_file} ....")
119119

120120
if not os.path.exists(classes_txt_file):
121121
raise FileNotFoundError(f"{classes_txt_file} does not exist")
122122

123+
dm_labels_dict = {}
123124
with open(classes_txt_file, "r") as f:
124125
for line in f:
125126
label = line.strip()
126-
if label not in self._dm_labels:
127-
self._dm_labels[label] = len(self._dm_labels)
128-
self._num_of_labels = len(self._dm_labels)
127+
if label not in dm_labels_dict:
128+
dm_labels_dict[label] = len(dm_labels_dict)
129+
return dm_labels_dict
129130

130131
def run_ensemble(self) -> None:
131132
"""
@@ -140,22 +141,20 @@ def run_ensemble(self) -> None:
140141

141142
while self._model_queue:
142143
model_name = self._model_queue.popleft()
143-
rank_zero_info(f"Processing model: {model_name}")
144-
model, model_props = self._load_model_and_its_props(model_name)
144+
print(f"Processing model: {model_name}")
145145

146-
rank_zero_info("\t Passing model to controller to generate predictions...")
147-
pred_conf_dict = self._controller(model, model_props)
148-
del model # Model can be huge to keep it in memory, delete as no longer needed
146+
print("\t Passing model to controller to generate predictions...")
147+
pred_conf_dict, model_props = self._controller(model_name)
149148

150-
rank_zero_info("\t Passing predictions to consolidator for aggregation...")
149+
print("\t Passing predictions to consolidator for aggregation...")
151150
self._consolidator(
152151
pred_conf_dict,
153152
model_props,
154153
true_scores=true_scores,
155154
false_scores=false_scores,
156155
)
157156

158-
rank_zero_info(f"Consolidating predictions for {self.__class__.__name__}")
157+
print(f"Consolidating predictions for {self.__class__.__name__}")
159158
final_preds = self._consolidate_on_finish(
160159
true_scores=true_scores, false_scores=false_scores
161160
)

0 commit comments

Comments
 (0)