Skip to content

Commit f60b2d8

Browse files
committed
ensemble: code improvements
1 parent 7f892d9 commit f60b2d8

File tree

1 file changed, with 18 additions and 24 deletions.

chebai/models/ensemble.py

Lines changed: 18 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -19,13 +19,19 @@ def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs):
1919

2020
for model_name in self.model_configs:
2121
model_path = self.model_configs[model_name]["path"]
22-
if os.path.exists(model_path):
22+
if not os.path.exists(model_path):
23+
raise FileNotFoundError(
24+
f"Model path '{model_path}' for '{model_name}' does not exist."
25+
)
26+
27+
# Attempt to load the model to check validity
28+
try:
2329
self.models[model_name] = Electra.load_from_checkpoint(
24-
model_path, map_location="cpu"
30+
model_path, map_location=self.device
2531
)
26-
else:
27-
raise FileNotFoundError(
28-
f"Model {model_name} does not exist in the given path {model_path}"
32+
except Exception as e:
33+
raise RuntimeError(
34+
f"Failed to load model '{model_name}' from {model_path}: {e}"
2935
)
3036

3137
for model in self.models.values():
@@ -70,10 +76,6 @@ def _validate_model_configs(cls, model_configs: Dict[str, ModelConfig]):
7076
)
7177

7278
model_path = config["path"]
73-
if not os.path.exists(model_path):
74-
raise FileNotFoundError(
75-
f"Model path '{model_path}' for '{model_name}' does not exist."
76-
)
7779

7880
# if model_path in path_set:
7981
# raise ValueError(
@@ -100,14 +102,13 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
100102
confidences = {}
101103
total_logits = torch.zeros(
102104
data["labels"].shape[0], data["labels"].shape[1], device=self.device
103-
).to(self.device)
105+
)
104106

105107
for name, model in self.models.items():
106108
output = model(data)
107-
confidences[name] = torch.sigmoid(output["logits"])
108-
predictions[name] = (
109-
torch.sigmoid(output["logits"]) > 0.5
110-
).long() # Multi-label classification
109+
sigmoid_logits = torch.sigmoid(output["logits"])
110+
confidences[name] = sigmoid_logits
111+
predictions[name] = (sigmoid_logits > 0.5).long()
111112
total_logits += output["logits"]
112113

113114
return {
@@ -211,21 +212,18 @@ def _execute(
211212
def aggregate_predictions(self, predictions, confidences):
212213
"""Implements weighted voting based on trustworthiness."""
213214
batch_size, num_classes = list(predictions.values())[0].shape
214-
215215
true_scores = torch.zeros(batch_size, num_classes, device=self.device)
216216
false_scores = torch.zeros(batch_size, num_classes, device=self.device)
217217

218218
for model, preds in predictions.items():
219219
tpv = float(self.model_configs[model]["TPV"])
220220
npv = float(self.model_configs[model]["FPV"])
221-
222-
confidence = confidences[model]
223-
weight = confidence * (tpv * preds + npv * (1 - preds))
221+
weight = confidences[model] * (tpv * preds + npv * (1 - preds))
224222

225223
true_scores += weight * preds
226224
false_scores += weight * (1 - preds)
227225

228-
return (true_scores > false_scores).long() # Final class decision
226+
return (true_scores > false_scores).long()
229227

230228
def _process_for_loss(
231229
self,
@@ -264,11 +262,7 @@ def __init__(self, model_configs: Dict[str, Dict], **kwargs):
264262
self.ffn: FFN = FFN(**ffn_kwargs)
265263

266264
def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
267-
logits_list = []
268-
for name, model in self.models.items():
269-
output = model(data)
270-
logits_list.append(output["logits"])
271-
265+
logits_list = [model(data)["logits"] for model in self.models.values()]
272266
return self.ffn({"features": torch.cat(logits_list, dim=1)})
273267

274268
def _get_prediction_and_labels(

0 commit comments

Comments (0)