Skip to content

Commit 7db384a

Browse files
committed
add ensemble base to new python dir
1 parent dabe5ff commit 7db384a

File tree

2 files changed

+260
-0
lines changed

2 files changed

+260
-0
lines changed

chebai/ensemble/__init__.py

Whitespace-only changes.

chebai/ensemble/base.py

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
import importlib
2+
import json
3+
import os
4+
from abc import ABC, abstractmethod
5+
from collections import deque
6+
from typing import Deque, Dict, Optional
7+
8+
import torch
9+
from lightning import LightningModule
10+
11+
from chebai.models import ChebaiBaseNet
12+
13+
14+
class EnsembleBase(ABC):
    """
    Abstract base class for ensemble models in the Chebai framework.

    Derives from ``ABC`` (not from ``ChebaiBaseNet``). It reads the data-module
    label index from ``classes.txt``, loads each configured model from its
    checkpoint one at a time, and drives the prediction / consolidation loop.
    The individual steps of that loop (``_controller``, ``_consolidator`` and
    ``_consolidate_on_finish``) must be supplied by subclasses.

    Attributes:
        data_processed_dir_main (str): Directory where the processed data is stored.
        models (Dict[str, LightningModule]): A dictionary of loaded models
            (initialised empty; this base class does not populate it).
        model_configs (Dict[str, Dict]): Configuration dictionary for models in the ensemble.
        dm_labels (Dict[str, int]): Mapping of label names to integer indices.
        device (str): ``"cuda"`` when CUDA is available, otherwise ``"cpu"``.
        input_dim: Value of the optional ``input_dim`` keyword argument,
            forwarded to ``load_from_checkpoint``.
        num_of_labels (Optional[int]): Number of distinct labels in ``classes.txt``.
    """

    def __init__(
        self, model_configs: Dict[str, Dict], data_processed_dir_main: str, **kwargs
    ):
        """
        Initializes the ensemble state and loads the data-module labels.

        Args:
            model_configs (Dict[str, Dict]): Dictionary of model configurations.
            data_processed_dir_main (str): Path to the processed data directory.
            **kwargs: Additional arguments for initialization. Recognised keys:
                ``_validate_configs`` (bool, default False) runs
                ``_validate_model_configs`` first; ``input_dim`` (default None)
                is stored and later passed to ``load_from_checkpoint``.

        Raises:
            FileNotFoundError: If ``classes.txt`` is missing from
                ``data_processed_dir_main``.
        """
        if kwargs.get("_validate_configs", False):
            self._validate_model_configs(model_configs)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.input_dim = kwargs.get("input_dim", None)
        self.num_of_labels: Optional[int] = (
            None  # will be set by `_load_data_module_labels` method
        )
        self.data_processed_dir_main = data_processed_dir_main
        self.models: Dict[str, LightningModule] = {}
        self.model_configs = model_configs
        self.dm_labels: Dict[str, int] = {}

        self._load_data_module_labels()
        # Shape (1, num_of_labels); incremented by `_generate_model_label_props`
        # for every label a loaded model covers.
        self._num_models_per_label: torch.Tensor = torch.zeros(
            1, self.num_of_labels, device=self.device
        )
        # Names of the models still to be processed by `run_ensemble`; this base
        # class never fills it, so subclasses/callers are expected to.
        self._model_queue: Deque = deque()
        self._collated_data = None

    @classmethod
    def _validate_model_configs(cls, model_configs: Dict[str, Dict]):
        """
        Validates the model configurations to ensure required keys are present.

        Args:
            model_configs (Dict[str, Dict]): Dictionary of model configurations.

        Raises:
            AttributeError: If required keys are missing in the configuration.
            ValueError: If a checkpoint path, class path or labels path is used
                by more than one model.
        """
        path_set, class_set, labels_set = set(), set(), set()

        required_keys = {"class_path", "ckpt_path", "labels_path"}

        for model_name, config in model_configs.items():
            missing_keys = required_keys - config.keys()

            if missing_keys:
                raise AttributeError(
                    f"Missing keys {missing_keys} in model '{model_name}' configuration."
                )

            model_path = config["ckpt_path"]
            class_path = config["class_path"]
            labels_path = config["labels_path"]

            if model_path in path_set:
                raise ValueError(
                    f"Duplicate model path detected: '{model_path}'. "
                    "Each model must have a unique model-checkpoint path."
                )

            if class_path in class_set:
                raise ValueError(
                    f"Duplicate class path detected: '{class_path}'. Each model must have a unique class path."
                )

            if labels_path in labels_set:
                raise ValueError(
                    f"Duplicate labels path: {labels_path}. Each model must have unique labels path."
                )

            path_set.add(model_path)
            class_set.add(class_path)
            labels_set.add(labels_path)

    def _load_data_module_labels(self):
        """
        Loads the label mapping from the classes.txt file for loaded data.

        Each distinct, non-empty (after stripping whitespace) line becomes a
        label whose index is its order of first appearance. Sets
        ``self.num_of_labels`` to the number of labels found.

        Raises:
            FileNotFoundError: If the classes.txt file does not exist.
        """
        classes_txt_file = os.path.join(self.data_processed_dir_main, "classes.txt")
        if not os.path.exists(classes_txt_file):
            raise FileNotFoundError(f"{classes_txt_file} does not exist")

        with open(classes_txt_file, "r") as f:
            for line in f:
                label = line.strip()
                # Blank lines are not labels; counting them would inflate
                # `num_of_labels` with an empty-string class.
                if label and label not in self.dm_labels:
                    self.dm_labels[label] = len(self.dm_labels)
        self.num_of_labels = len(self.dm_labels)

    def run_ensemble(self, batch_size: int = 10) -> None:
        """
        Runs every queued model and consolidates their predictions.

        Models are taken from ``self._model_queue`` in FIFO order and loaded one
        at a time. For each model, ``_controller`` is called and its result is
        handed to ``_consolidator`` together with the shared score tensors.
        After the queue is drained, ``_consolidate_on_finish`` is called once.

        Args:
            batch_size (int): Number of rows of the zero-initialised
                ``true_scores`` / ``false_scores`` tensors, each of shape
                ``(batch_size, num_of_labels)``. Defaults to 10.
        """
        true_scores = torch.zeros(batch_size, self.num_of_labels, device=self.device)
        false_scores = torch.zeros(batch_size, self.num_of_labels, device=self.device)

        while self._model_queue:
            model, model_props = self._load_model_and_its_props(
                self._model_queue.popleft()
            )
            pred_conf_dict = self._controller(model, model_props)
            self._consolidator(
                pred_conf_dict,
                model_props,
                true_scores=true_scores,
                false_scores=false_scores,
            )

        self._consolidate_on_finish(true_scores=true_scores, false_scores=false_scores)

    def _load_model_and_its_props(self, model_name):
        """
        Loads one configured model from its checkpoint along with its label properties.

        Args:
            model_name: Key into ``self.model_configs``.

        Returns:
            A ``(model, model_label_props)`` tuple: the model in eval mode and
            frozen, and the dictionary produced by
            ``_generate_model_label_props``.

        Raises:
            FileNotFoundError: If the checkpoint path does not exist.
            ValueError: If the class path is not of the form ``module.ClassName``.
            TypeError: If the class path does not resolve to a subclass of
                ``ChebaiBaseNet``.
        """
        model_config = self.model_configs[model_name]
        model_ckpt_path = model_config["ckpt_path"]
        model_class_path = model_config["class_path"]
        model_labels_path = model_config["labels_path"]
        if not os.path.exists(model_ckpt_path):
            raise FileNotFoundError(
                f"Model path '{model_ckpt_path}' for '{model_name}' does not exist."
            )

        module_path, _, class_name = model_class_path.rpartition(".")
        if not module_path:
            raise ValueError(
                f"Class path '{model_class_path}' for '{model_name}' must be a "
                "fully qualified 'module.ClassName' path."
            )
        module = importlib.import_module(module_path)
        lightning_cls = getattr(module, class_name)
        # Explicit raises instead of `assert`: assertions are removed when
        # Python runs with -O, which would let an invalid class through.
        if not isinstance(lightning_cls, type):
            raise TypeError(f"{class_name} is not a class.")
        if not issubclass(lightning_cls, ChebaiBaseNet):
            raise TypeError(f"{class_name} must inherit from ChebaiBaseNet")

        model = lightning_cls.load_from_checkpoint(
            model_ckpt_path, input_dim=self.input_dim
        )
        model.eval()
        model.freeze()

        model_label_props = self._generate_model_label_props(
            model_name, model_labels_path
        )

        return model, model_label_props

    def _generate_model_label_props(
        self, model_name: str, labels_path: str
    ) -> Dict[str, torch.Tensor]:
        """
        Generates a mask indicating the labels handled by a model, and retrieves
        the corresponding TPV and FPV values as tensors.

        Only labels that also appear in ``self.dm_labels`` are used. Positions of
        labels the model does not cover hold ``-1`` in the TPV/FPV tensors. As a
        side effect, ``self._num_models_per_label`` is incremented for every
        covered label.

        Args:
            model_name (str): Name of the model (used in error messages).
            labels_path (str): Path to the model's labels JSON file.

        Returns:
            Dict with keys ``"mask"`` (bool tensor), ``"tpv_tensor"`` and
            ``"fpv_tensor"`` (float tensors), each of length ``num_of_labels``.

        Raises:
            FileNotFoundError: If the labels path does not exist.
            ValueError: If none of the model's labels is a data-module label.
            Exception: If a label entry fails validation.
        """
        labels_dict = self._load_model_labels(labels_path)

        model_label_indices, tpv_label_values, fpv_label_values = [], [], []
        for label, label_values in labels_dict.items():
            if label not in self.dm_labels:
                continue
            try:
                self._validate_model_labels_json_element(label_values)
            except Exception as e:
                raise Exception(
                    f"Label '{label}' has an unexpected error: {e}"
                ) from e

            model_label_indices.append(self.dm_labels[label])
            # Validation accepts anything convertible to float (e.g. "0.5"),
            # so convert here before building the tensors.
            tpv_label_values.append(float(label_values["TPV"]))
            fpv_label_values.append(float(label_values["FPV"]))

        if not model_label_indices:
            raise ValueError(f"Values are empty for labels of model {model_name}")

        index_tensor = torch.tensor(
            model_label_indices, dtype=torch.long, device=self.device
        )

        # Create masks to apply predictions only to known classes
        mask = torch.zeros(self.num_of_labels, device=self.device, dtype=torch.bool)
        mask[index_tensor] = True

        tpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self.device)
        fpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self.device)

        # Assign through the index tensor, not the boolean mask: mask assignment
        # fills positions in ascending index order, which would misplace values
        # whenever the JSON key order differs from the classes.txt order.
        tpv_tensor[index_tensor] = torch.tensor(
            tpv_label_values, dtype=torch.float, device=self.device
        )
        fpv_tensor[index_tensor] = torch.tensor(
            fpv_label_values, dtype=torch.float, device=self.device
        )
        self._num_models_per_label += mask
        return {"mask": mask, "tpv_tensor": tpv_tensor, "fpv_tensor": fpv_tensor}

    @staticmethod
    def _load_model_labels(labels_path: str) -> Dict[str, Dict[str, float]]:
        """
        Reads a model's labels JSON file.

        Args:
            labels_path (str): Path to a file whose name ends with ``.json``.

        Returns:
            The parsed JSON content.

        Raises:
            FileNotFoundError: If ``labels_path`` does not exist.
            TypeError: If ``labels_path`` does not end with ``.json``.
        """
        if not os.path.exists(labels_path):
            raise FileNotFoundError(f"{labels_path} does not exist.")

        if not labels_path.endswith(".json"):
            raise TypeError(f"{labels_path} is not a JSON file.")

        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(labels_path, "r", encoding="utf-8") as f:
            model_labels = json.load(f)
        return model_labels

    @staticmethod
    def _validate_model_labels_json_element(label_dict: Dict[str, float]):
        """
        Validates a single label entry of a model's labels JSON.

        Args:
            label_dict (Dict[str, float]): Entry expected to hold ``"TPV"`` and
                ``"FPV"`` values.

        Raises:
            AttributeError: If ``"TPV"`` or ``"FPV"`` is missing.
            ValueError: If a value is not convertible to float, or is negative.
        """
        if "TPV" not in label_dict or "FPV" not in label_dict:
            raise AttributeError("Missing keys 'TPV' and/or 'FPV'")

        # Validate 'tpv' and 'fpv' are either floats or convertible to float
        for key in ("TPV", "FPV"):
            raw_value = label_dict[key]
            try:
                value = float(raw_value)
            except (TypeError, ValueError) as err:
                raise ValueError(
                    f"'{key}' must be a float or convertible to float, but got {raw_value}"
                ) from err
            # Checked outside the `try` so this error is not swallowed and
            # replaced by the "convertible to float" message above.
            if value < 0:
                raise ValueError(f"'{key}' must be non-negative but got {value}")

    @abstractmethod
    def _controller(self, model, model_props, **kwargs):
        """
        Produces the per-model result that is passed on to ``_consolidator``.

        Called by ``run_ensemble`` once per queued model.

        Args:
            model: The loaded, frozen model.
            model_props: Dictionary returned by ``_generate_model_label_props``.
            **kwargs: Implementation-specific options.
        """
        pass

    @abstractmethod
    def _consolidator(
        self, pred_conf_dict, model_props, *, true_scores, false_scores, **kwargs
    ):
        """
        Folds one model's result into the shared score tensors.

        Called by ``run_ensemble`` once per queued model, right after
        ``_controller``.

        Args:
            pred_conf_dict: Value returned by ``_controller`` for this model.
            model_props: Dictionary returned by ``_generate_model_label_props``.
            true_scores: Tensor of shape ``(batch_size, num_of_labels)`` shared
                across all models of a run.
            false_scores: Tensor of the same shape as ``true_scores``.
            **kwargs: Implementation-specific options.
        """
        pass

    @abstractmethod
    def _consolidate_on_finish(self, *, true_scores, false_scores):
        """
        Finalises the ensemble result after every queued model was processed.

        Called by ``run_ensemble`` exactly once, after the model queue is empty.

        Args:
            true_scores: Tensor of shape ``(batch_size, num_of_labels)``.
            false_scores: Tensor of the same shape as ``true_scores``.
        """
        pass

0 commit comments

Comments
 (0)