|
9 | 9 | from lightning import LightningModule |
10 | 10 |
|
11 | 11 | from chebai.models import ChebaiBaseNet |
| 12 | +from chebai.result.classification import print_metrics |
12 | 13 |
|
13 | 14 |
|
14 | 15 | class EnsembleBase(ABC): |
@@ -39,7 +40,7 @@ def __init__( |
39 | 40 | if kwargs.get("_validate_configs", False): |
40 | 41 | self._validate_model_configs(model_configs) |
41 | 42 |
|
42 | | - self.device = "cuda" if torch.cuda.is_available() else "cpu" |
| 43 | + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
43 | 44 | self.input_dim = kwargs.get("input_dim", None) |
44 | 45 | self.num_of_labels: Optional[int] = ( |
45 | 46 | None # will be set by `_load_data_module_labels` method |
@@ -131,14 +132,24 @@ def run_ensemble(self): |
131 | 132 | self._model_queue.popleft() |
132 | 133 | ) |
133 | 134 | pred_conf_dict = self._controller(model, model_props) |
| 135 | + del model # Model can be huge to keep it in memory, delete as no longer needed |
| 136 | + |
134 | 137 | self._consolidator( |
135 | 138 | pred_conf_dict, |
136 | 139 | model_props, |
137 | 140 | true_scores=true_scores, |
138 | 141 | false_scores=false_scores, |
139 | 142 | ) |
140 | 143 |
|
141 | | - self._consolidate_on_finish(true_scores=true_scores, false_scores=false_scores) |
| 144 | + final_preds = self._consolidate_on_finish( |
| 145 | + true_scores=true_scores, false_scores=false_scores |
| 146 | + ) |
| 147 | + print_metrics( |
| 148 | + final_preds, |
| 149 | + self._collated_data.y, |
| 150 | + self.device, |
| 151 | + classes=list(self.dm_labels.keys()), |
| 152 | + ) |
142 | 153 |
|
143 | 154 | def _load_model_and_its_props(self, model_name): |
144 | 155 | """ |
|
0 commit comments