Commit 9b851c5

ensemble: update for tpv/fpv value for each label
1 parent 26f5ab4 commit 9b851c5

2 files changed (+102 -53 lines)

chebai/cli.py

Lines changed: 6 additions & 0 deletions

@@ -48,6 +48,12 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser):
             "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
         )

+        parser.link_arguments(
+            "data.processed_dir_main",
+            "model.init_args.data_processed_dir_main",
+            apply_on="instantiate",
+        )
+
     @staticmethod
     def subcommands() -> Dict[str, Set[str]]:
         """

chebai/models/ensemble.py

Lines changed: 96 additions & 53 deletions
@@ -1,4 +1,5 @@
 import importlib
+import json
 import os.path
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Optional, Tuple, Union
@@ -12,42 +13,19 @@


 class _EnsembleBase(ChebaiBaseNet, ABC):
-    def __init__(self, model_configs: Dict[str, Dict], **kwargs):
+    def __init__(
+        self, model_configs: Dict[str, Dict], data_processed_dir_main: str, **kwargs
+    ):
         super().__init__(**kwargs)
         self._validate_model_configs(model_configs)

+        self.data_processed_dir_main = data_processed_dir_main
         self.models: Dict[str, LightningModule] = {}
         self.model_configs = model_configs
+        self.dm_labels: Dict[str, int] = {}

-        for model_name in self.model_configs:
-            model_ckpt_path = self.model_configs[model_name]["ckpt_path"]
-            model_class_path = self.model_configs[model_name]["class_path"]
-            if not os.path.exists(model_ckpt_path):
-                raise FileNotFoundError(
-                    f"Model path '{model_ckpt_path}' for '{model_name}' does not exist."
-                )
-
-            class_name = model_class_path.split(".")[-1]
-            module_path = ".".join(model_class_path.split(".")[:-1])
-
-            try:
-                module = importlib.import_module(module_path)
-                lightning_cls: LightningModule = getattr(module, class_name)
-
-                model = lightning_cls.load_from_checkpoint(model_ckpt_path)
-                model.eval()
-                model.freeze()
-                self.models[model_name] = model
-
-            except ModuleNotFoundError:
-                print(f"Module '{module_path}' not found!")
-            except AttributeError:
-                print(f"Class '{class_name}' not found in '{module_path}'!")
-
-            except Exception as e:
-                raise RuntimeError(
-                    f"Failed to load model '{model_name}' from {model_ckpt_path}: \n {e}"
-                )
+        self._load_data_module_labels()
+        self._load_ensemble_models()

         # TODO: Later discuss whether this threshold should be independent of metric threshold or not ?
         # if kwargs.get("threshold") is None:
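
The constructor now takes data_processed_dir_main in addition to model_configs. Judging from the keys accessed elsewhere in this file, one model_configs entry looks roughly like the sketch below; the model name, paths, and class path are illustrative placeholders, not values from the repository:

# Illustrative model_configs entry (all values are made up).
model_configs = {
    "member_model_a": {
        "ckpt_path": "checkpoints/member_model_a.ckpt",        # loaded via load_from_checkpoint
        "class_path": "some_package.some_module.SomeLightningModel",  # dotted path resolved via importlib
        "labels_path": "configs/member_model_a_labels.json",   # per-label TPV/FPV values (see below)
    },
}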
@@ -98,6 +76,47 @@ def _extra_validation(
     ):
         pass

+    def _load_ensemble_models(self):
+        for model_name in self.model_configs:
+            model_ckpt_path = self.model_configs[model_name]["ckpt_path"]
+            model_class_path = self.model_configs[model_name]["class_path"]
+            if not os.path.exists(model_ckpt_path):
+                raise FileNotFoundError(
+                    f"Model path '{model_ckpt_path}' for '{model_name}' does not exist."
+                )
+
+            class_name = model_class_path.split(".")[-1]
+            module_path = ".".join(model_class_path.split(".")[:-1])
+
+            try:
+                module = importlib.import_module(module_path)
+                lightning_cls: LightningModule = getattr(module, class_name)
+
+                model = lightning_cls.load_from_checkpoint(model_ckpt_path)
+                model.eval()
+                model.freeze()
+                self.models[model_name] = model
+
+            except ModuleNotFoundError:
+                print(f"Module '{module_path}' not found!")
+            except AttributeError:
+                print(f"Class '{class_name}' not found in '{module_path}'!")
+
+            except Exception as e:
+                raise RuntimeError(
+                    f"Failed to load model '{model_name}' from {model_ckpt_path}: \n {e}"
+                )
+
+    def _load_data_module_labels(self):
+        classes_txt_file = os.path.join(self.data_processed_dir_main, "classes.txt")
+        if not os.path.exists(classes_txt_file):
+            raise FileNotFoundError(f"{classes_txt_file} does not exist")
+        else:
+            with open(classes_txt_file, "r") as f:
+                for line in f:
+                    if line.strip() not in self.dm_labels:
+                        self.dm_labels[line.strip()] = len(self.dm_labels)
+
     @abstractmethod
     def _get_prediction_and_labels(
         self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
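
_load_data_module_labels assumes the data module's processed directory contains a classes.txt with one label per line and assigns each label its first-occurrence index. A small self-contained sketch of that mapping; the label strings stand in for whatever the real classes.txt holds:

# Stand-in for the lines of classes.txt (illustrative labels only).
lines = ["33579", "25806", "33579"]

dm_labels = {}
for line in lines:
    if line.strip() not in dm_labels:
        dm_labels[line.strip()] = len(dm_labels)

print(dm_labels)  # {'33579': 0, '25806': 1}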
@@ -132,49 +151,73 @@ def _extra_validation(
             # )
             sets_["labels"].add(labels_path)

-            if "TPV" not in config.keys() or "FPV" not in config.keys():
-                raise AttributeError(
-                    f"Missing keys 'TPV' and/or 'FPV' in model '{model_name}' configuration."
-                )
+            with open(labels_path, "r") as f:
+                model_labels = json.load(f)

-            # Validate 'tpv' and 'fpv' are either floats or convertible to float
-            for key in ["TPV", "FPV"]:
-                try:
-                    value = float(config[key])
-                    if value < 0:
+            for label, label_dict in model_labels.items():
+
+                if "TPV" not in label_dict.keys() or "FPV" not in label_dict.keys():
+                    raise AttributeError(
+                        f"Missing keys 'TPV' and/or 'FPV' in model '{model_name}' configuration."
+                    )
+
+                # Validate 'tpv' and 'fpv' are either floats or convertible to float
+                for key in ["TPV", "FPV"]:
+                    try:
+                        value = float(label_dict[key])
+                        if value < 0:
+                            raise ValueError(
+                                f"'{key}' in model '{model_name}' and label '{label}' must be non-negative, but got {value}."
+                            )
+                    except (TypeError, ValueError):
                         raise ValueError(
-                            f"'{key}' in model '{model_name}' must be non-negative, but got {value}."
+                            f"'{key}' in model '{model_name}' and label '{label}' must be a float or convertible to float, but got {label_dict[key]}."
                         )
-                except (TypeError, ValueError):
-                    raise ValueError(
-                        f"'{key}' in model '{model_name}' must be a float or convertible to float, but got {config[key]}."
-                    )

     def _generate_model_label_mask(self):
-        labels_dict = {}
         num_models_per_label = torch.zeros(1, self.out_dim, device=self.device)
+
         for model_name, model_config in self.model_configs.items():
             labels_path = model_config["labels_path"]
             if not os.path.exists(labels_path):
                 raise FileNotFoundError(f"Labels path '{labels_path}' does not exist.")

             with open(labels_path, "r") as f:
-                labels_list = [int(line.strip()) for line in f]
+                labels_dict = json.load(f)

-            model_label_indices = []
-            for label in labels_list:
-                if label not in labels_dict:
-                    labels_dict[label] = len(labels_dict)
+            model_label_indices, tpv_label_values, fpv_label_values = [], [], []
+            for label in labels_dict.keys():
+                if label in self.dm_labels:
+                    model_label_indices.append(self.dm_labels[label])
+                    tpv_label_values.append(float(labels_dict[label]["TPV"]))
+                    fpv_label_values.append(float(labels_dict[label]["FPV"]))

-                model_label_indices.append(labels_dict[label])
+            if not all([model_label_indices, tpv_label_values, fpv_label_values]):
+                raise ValueError(f"Values are empty for labels of model {model_name}")

             # Create masks to apply predictions only to known classes
             mask = torch.zeros(self.out_dim, device=self.device, dtype=torch.bool)
             mask[
                 torch.tensor(model_label_indices, dtype=torch.int, device=self.device)
             ] = True

+            tpv_tensor = torch.full_like(
+                mask, -1, dtype=torch.float, device=self.device
+            )
+            fpv_tensor = torch.full_like(
+                mask, -1, dtype=torch.float, device=self.device
+            )
+
+            tpv_tensor[mask] = torch.tensor(
+                tpv_label_values, dtype=torch.float, device=self.device
+            )
+            fpv_tensor[mask] = torch.tensor(
+                fpv_label_values, dtype=torch.float, device=self.device
+            )
+
             self.model_configs[model_name]["labels_mask"] = mask
+            self.model_configs[model_name]["tpv_tensor"] = tpv_tensor
+            self.model_configs[model_name]["fpv_tensor"] = fpv_tensor
             num_models_per_label += mask

         self._num_models_per_label = num_models_per_label
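
With this change, the file referenced by labels_path is no longer a list of label IDs paired with one model-wide TPV/FPV in the config; it is a JSON object carrying a TPV/FPV pair per label. The sketch below shows the file shape the new code reads and a toy run of the mask/tensor construction; the label IDs, values, and out_dim are illustrative, not taken from the repository:

import json
import torch

# Assumed per-model labels file: each label has its own TPV and FPV.
labels_json = json.loads(
    '{"33579": {"TPV": 0.9, "FPV": 0.05}, "25806": {"TPV": 0.8, "FPV": 0.1}}'
)

# Toy data-module label index (normally built by _load_data_module_labels).
dm_labels = {"33579": 0, "12345": 1, "25806": 2, "67890": 3}
out_dim = len(dm_labels)

model_label_indices, tpv_label_values, fpv_label_values = [], [], []
for label, d in labels_json.items():
    if label in dm_labels:
        model_label_indices.append(dm_labels[label])
        tpv_label_values.append(float(d["TPV"]))
        fpv_label_values.append(float(d["FPV"]))

mask = torch.zeros(out_dim, dtype=torch.bool)
mask[torch.tensor(model_label_indices, dtype=torch.int)] = True

# Classes this model does not cover keep the sentinel value -1.
tpv_tensor = torch.full_like(mask, -1, dtype=torch.float)
fpv_tensor = torch.full_like(mask, -1, dtype=torch.float)
tpv_tensor[mask] = torch.tensor(tpv_label_values, dtype=torch.float)
fpv_tensor[mask] = torch.tensor(fpv_label_values, dtype=torch.float)

print(mask)        # tensor([ True, False,  True, False])
print(tpv_tensor)  # tensor([ 0.9000, -1.0000,  0.8000, -1.0000])
print(fpv_tensor)  # tensor([ 0.0500, -1.0000,  0.1000, -1.0000])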
@@ -312,8 +355,8 @@ def aggregate_predictions(self, predictions, confidences):
         false_scores = torch.zeros(batch_size, num_classes, device=self.device)

         for model, conf in confidences.items():
-            tpv = float(self.model_configs[model]["TPV"])
-            npv = float(self.model_configs[model]["FPV"])
+            tpv = self.model_configs[model]["tpv_tensor"]
+            npv = self.model_configs[model]["fpv_tensor"]

             # Determine which classes the model provides predictions for
             mask = self.model_configs[model]["labels_mask"]
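
In aggregate_predictions, tpv and npv are therefore no longer scalar floats but per-class tensors aligned with labels_mask, so whatever weighting follows applies label-wise. A shape-only illustration; the actual scoring logic is outside this hunk, and the arithmetic here is not taken from it:

import torch

batch_size, num_classes = 2, 4
conf = torch.rand(batch_size, num_classes)       # one model's confidences
mask = torch.tensor([True, False, True, False])  # classes this model covers
tpv = torch.tensor([0.90, -1.0, 0.80, -1.0])     # per-label values, -1 = uncovered

# Restricting to covered classes and multiplying broadcasts per label.
weighted = conf[:, mask] * tpv[mask]
print(weighted.shape)  # torch.Size([2, 2])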
