-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMultipleTrainingRunner.py
More file actions
30 lines (19 loc) · 924 Bytes
/
MultipleTrainingRunner.py
File metadata and controls
30 lines (19 loc) · 924 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from typing import Dict
from DmpConfig import DmpConfig
from DmpDataset import DmpDataset
from TrainingRunner import TrainingRunner
# Module-level configuration object, instantiated once at import time.
# MultipleTrainingRunner reads config.model_name and config.fast_run from it.
config = DmpConfig()
class MultipleTrainingRunner(TrainingRunner):
    """Train and assess one model on every dataset in ``DmpDataset``.

    The model name comes from the module-level ``config.model_name``. The run
    is refused up front if that name is already among the saved models, so
    the clash is detected before any training starts.
    """

    def __call__(self) -> None:
        """Train/assess on each ``DmpDataset`` member and record test errors.

        For every dataset, ``self._train_and_assess_on_dataset(dataset)`` is
        called and the second element of the returned pair is kept as that
        dataset's test error. Unless ``config.fast_run`` is set, the collected
        ``{dataset: test_error}`` mapping is then passed to
        ``self.modelAssessment.write_results``.

        Raises:
            ValueError: if ``config.model_name`` is already a saved model name
                (raised before any training starts).
        """
        self._check_name_availability(config.model_name)

        test_errors: Dict[DmpDataset, float] = {}
        for dataset in DmpDataset:
            print(f"\n\nStarted training for dataset {dataset.value}")
            _, test_error = self._train_and_assess_on_dataset(dataset)
            test_errors[dataset] = test_error

        # NOTE(review): the scraped source lost its indentation; this call is
        # placed after the loop so results are written once for all datasets.
        # Confirm against the original file.
        if not config.fast_run:
            self.modelAssessment.write_results(test_errors)

    def _check_name_availability(self, model_name: str) -> None:
        """Raise if *model_name* is already used by a saved model.

        Args:
            model_name: Candidate name for the model about to be trained.

        Raises:
            ValueError: if ``model_name`` appears in
                ``self.modelAssessment.get_saved_model_names()``. ``ValueError``
                subclasses ``Exception``, so existing ``except Exception``
                handlers still catch it.
        """
        if model_name in self.modelAssessment.get_saved_model_names():
            raise ValueError(f"{model_name} was already used as model name")