Skip to content

Commit ed92ac5

Browse files
committed
Each label of each model has TPV and FPV values
1 parent eb1798c commit ed92ac5

File tree

1 file changed

+55
-91
lines changed

1 file changed

+55
-91
lines changed

chebai/models/ensemble.py

Lines changed: 55 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,9 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict]):
6161
AttributeError: If required keys are missing in the configuration.
6262
ValueError: If there are duplicate model paths or class paths.
6363
"""
64-
path_set = set()
65-
class_set = set()
66-
labels_set = set()
64+
path_set, class_set, labels_set = set(), set(), set()
6765

68-
sets_ = {"path": path_set, "class": class_set, "labels": labels_set}
69-
required_keys = {"class_path", "ckpt_path"}
66+
required_keys = {"class_path", "ckpt_path", "labels_path"}
7067

7168
for model_name, config in model_configs.items():
7269
missing_keys = required_keys - config.keys()
@@ -78,37 +75,26 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict]):
7875

7976
model_path = config["ckpt_path"]
8077
class_path = config["class_path"]
78+
labels_path = config["labels_path"]
8179

8280
if model_path in path_set:
8381
raise ValueError(
84-
f"Duplicate model path detected: '{model_path}'. Each model must have a unique path."
82+
f"Duplicate model path detected: '{model_path}'. Each model must have a unique model-checkpoint path."
8583
)
8684

87-
if class_path not in class_set:
85+
if class_path in class_set:
8886
raise ValueError(
89-
f"Duplicate class path detected: '{class_path}'. Each model must have a unique path."
87+
f"Duplicate class path detected: '{class_path}'. Each model must have a unique class path."
88+
)
89+
90+
if labels_path in labels_set:
91+
raise ValueError(
92+
f"Duplicate labels path: {labels_path}. Each model must have unique labels path."
9093
)
9194

9295
path_set.add(model_path)
9396
class_set.add(class_path)
94-
95-
cls._extra_validation(model_name, config, sets_)
96-
97-
@classmethod
98-
def _extra_validation(
99-
cls, model_name: str, config: Dict[str, Any], sets_: Dict[str, set]
100-
):
101-
"""
102-
Perform extra validation on the model configuration, if necessary.
103-
104-
This method can be extended by subclasses to add additional validation logic.
105-
106-
Args:
107-
model_name (str): The name of the model.
108-
config (Dict[str, Any]): The configuration dictionary for the model.
109-
sets_ (Dict[str, set]): A dictionary of sets to track model paths, class paths, and labels.
110-
"""
111-
pass
97+
labels_path.add(labels_path)
11298

11399
def _load_ensemble_models(self):
114100
"""
@@ -122,6 +108,7 @@ def _load_ensemble_models(self):
122108
for model_name in self.model_configs:
123109
model_ckpt_path = self.model_configs[model_name]["ckpt_path"]
124110
model_class_path = self.model_configs[model_name]["class_path"]
111+
model_labels_path = self.model_configs[model_name]["labels_path"]
125112
if not os.path.exists(model_ckpt_path):
126113
raise FileNotFoundError(
127114
f"Model path '{model_ckpt_path}' for '{model_name}' does not exist."
@@ -134,10 +121,15 @@ def _load_ensemble_models(self):
134121
module = importlib.import_module(module_path)
135122
lightning_cls: LightningModule = getattr(module, class_name)
136123

137-
model = lightning_cls.load_from_checkpoint(model_ckpt_path)
124+
model = lightning_cls.load_from_checkpoint(
125+
model_ckpt_path, input_dim=self.input_dim
126+
)
138127
model.eval()
139128
model.freeze()
140129
self.models[model_name] = model
130+
self.models_configs[model_name]["labels"] = self._load_model_labels(
131+
model_labels_path
132+
)
141133

142134
except ModuleNotFoundError:
143135
print(f"Module '{module_path}' not found!")
@@ -149,21 +141,37 @@ def _load_ensemble_models(self):
149141
f"Failed to load model '{model_name}' from {model_ckpt_path}: \n {e}"
150142
)
151143

152-
def _load_data_module_labels(self):
153-
"""
154-
Loads the label mapping from the classes.txt file for loaded data.
144+
@staticmethod
145+
def _load_model_labels(labels_path: str) -> Dict[str, float]:
146+
if not os.path.exists(labels_path):
147+
raise FileNotFoundError(f"{labels_path} does not exist.")
155148

156-
Raises:
157-
FileNotFoundError: If the classes.txt file does not exist.
158-
"""
159-
classes_txt_file = os.path.join(self.data_processed_dir_main, "classes.txt")
160-
if not os.path.exists(classes_txt_file):
161-
raise FileNotFoundError(f"{classes_txt_file} does not exist")
162-
else:
163-
with open(classes_txt_file, "r") as f:
164-
for line in f:
165-
if line.strip() not in self.dm_labels:
166-
self.dm_labels[line.strip()] = len(self.dm_labels)
149+
if not labels_path.endswith(".json"):
150+
raise TypeError(f"{labels_path} is not a JSON file.")
151+
152+
with open(labels_path, "r") as f:
153+
model_labels = json.load(f)
154+
155+
labels_dict = {}
156+
for label, label_dict in model_labels.items():
157+
msg = f"for label {label}"
158+
if "TPV" not in label_dict.keys() or "FPV" not in label_dict.keys():
159+
raise AttributeError(f"Missing keys 'TPV' and/or 'FPV' {msg}")
160+
161+
# Validate that 'TPV' and 'FPV' are non-negative and convertible to float
162+
for key in ["TPV", "FPV"]:
163+
try:
164+
value = float(label_dict[key])
165+
if value < 0:
166+
raise ValueError(
167+
f"'{key}' must be non-negative but got {value} {msg}"
168+
)
169+
except (TypeError, ValueError):
170+
raise ValueError(
171+
f"'{key}' must be a float or convertible to float, but got {label_dict[key]} {msg}"
172+
)
173+
labels_dict[label][key] = value
174+
return labels_dict
167175

168176
@abstractmethod
169177
def _get_prediction_and_labels(
@@ -182,6 +190,12 @@ def _get_prediction_and_labels(
182190
"""
183191
pass
184192

193+
def controller(self):
194+
pass
195+
196+
def consolidator(self):
197+
pass
198+
185199

186200
class ChebiEnsemble(_EnsembleBase):
187201
"""
@@ -212,56 +226,6 @@ def __init__(self, model_configs: Dict[str, Dict], **kwargs):
212226
self._num_models_per_label: Optional[torch.Tensor] = None
213227
self._generate_model_label_mask()
214228

215-
@classmethod
216-
def _extra_validation(
217-
cls, model_name: str, config: Dict[str, Any], sets_: Dict[str, set]
218-
):
219-
"""
220-
Additional validation for the ensemble model configuration.
221-
222-
Args:
223-
model_name (str): The model name.
224-
config (Dict[str, Any]): The configuration dictionary.
225-
sets_ (Dict[str, set]): The set of paths for labels.
226-
227-
Raises:
228-
AttributeError: If the 'labels_path' key is missing.
229-
ValueError: If the 'labels_path' contains duplicate entries or certain are not convertible to float.
230-
"""
231-
if "labels_path" not in config:
232-
raise AttributeError("Missing 'labels_path' key in config!")
233-
234-
labels_path = config["labels_path"]
235-
if labels_path not in sets_["labels"]:
236-
raise ValueError(
237-
f"Duplicate labels path detected: '{labels_path}'. Each model must have a unique path."
238-
)
239-
240-
sets_["labels"].add(labels_path)
241-
242-
with open(labels_path, "r") as f:
243-
model_labels = json.load(f)
244-
245-
for label, label_dict in model_labels.items():
246-
247-
if "TPV" not in label_dict.keys() or "FPV" not in label_dict.keys():
248-
raise AttributeError(
249-
f"Missing keys 'TPV' and/or 'FPV' in model '{model_name}' configuration."
250-
)
251-
252-
# Validate 'tpv' and 'fpv' are either floats or convertible to float
253-
for key in ["TPV", "FPV"]:
254-
try:
255-
value = float(label_dict[key])
256-
if value < 0:
257-
raise ValueError(
258-
f"'{key}' in model '{model_name}' and label '{label}' must be non-negative, but got {value}."
259-
)
260-
except (TypeError, ValueError):
261-
raise ValueError(
262-
f"'{key}' in model '{model_name}' and label '{label}' must be a float or convertible to float, but got {label_dict[key]}."
263-
)
264-
265229
def _generate_model_label_mask(self):
266230
"""
267231
Generates a mask indicating the labels handled by each model, and retrieves the corresponding TPV and FPV values

0 commit comments

Comments
 (0)