Skip to content

Commit 37d46f7

Browse files
committed
add utils.print_metrics to ensemble
1 parent 65a51e0 commit 37d46f7

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

chebai/ensemble/base.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from lightning import LightningModule
1010

1111
from chebai.models import ChebaiBaseNet
12+
from chebai.result.classification import print_metrics
1213

1314

1415
class EnsembleBase(ABC):
@@ -39,7 +40,7 @@ def __init__(
3940
if kwargs.get("_validate_configs", False):
4041
self._validate_model_configs(model_configs)
4142

42-
self.device = "cuda" if torch.cuda.is_available() else "cpu"
43+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
4344
self.input_dim = kwargs.get("input_dim", None)
4445
self.num_of_labels: Optional[int] = (
4546
None # will be set by `_load_data_module_labels` method
@@ -131,14 +132,24 @@ def run_ensemble(self):
131132
self._model_queue.popleft()
132133
)
133134
pred_conf_dict = self._controller(model, model_props)
135+
del model # Model can be huge to keep it in memory, delete as no longer needed
136+
134137
self._consolidator(
135138
pred_conf_dict,
136139
model_props,
137140
true_scores=true_scores,
138141
false_scores=false_scores,
139142
)
140143

141-
self._consolidate_on_finish(true_scores=true_scores, false_scores=false_scores)
144+
final_preds = self._consolidate_on_finish(
145+
true_scores=true_scores, false_scores=false_scores
146+
)
147+
print_metrics(
148+
final_preds,
149+
self._collated_data.y,
150+
self.device,
151+
classes=list(self.dm_labels.keys()),
152+
)
142153

143154
def _load_model_and_its_props(self, model_name):
144155
"""

0 commit comments

Comments
 (0)