-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrandom_parameter_search.py
More file actions
125 lines (101 loc) · 4.47 KB
/
random_parameter_search.py
File metadata and controls
125 lines (101 loc) · 4.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import copy
import json
import logging
import os
import subprocess
import sys
from argparse import ArgumentParser
from pathlib import Path
from random import sample

import yaml
from tqdm import tqdm
def _choose_from(options):
    """Build a sampler that returns one uniformly drawn element of ``options`` per call."""
    return lambda: sample(options, 1)[0]


# Random search space. The top-level keys are sections that
# get_adapted_config() merges into the same-named sections of the YAML
# default config; each entry maps a parameter name to a zero-argument
# sampler that draws one candidate value.
variations_config = {
    'network_config': {
        "learning_rate": _choose_from([0.01, 0.05, 0.1]),
        "loss_type": _choose_from(['huber_loss', 'mean_l1_loss', 'mean_l2_loss']),
        "bilinear": _choose_from([True, False]),
        "lr_patience": _choose_from([1, 3]),
        "output_activation": _choose_from(['relu', 'none']),
        "initial_channels": _choose_from([8, 16, 32]),
        "skip_connections": _choose_from([False, True]),
    },
    'dataset_config': {
        "depth_difference_threshold": _choose_from([0, 1, 2, 3]),
        "scale_images": _choose_from([0.5, 1]),
        'add_region_mask_to_input': _choose_from([False, True]),
    },
}
def _batch_size_for(scale_images, initial_channels):
    """Return the batch size used for a given image scale and network width.

    Half-resolution images (``scale_images == 0.5``) get larger batches than
    any other scale. Channel counts other than 8 or 16 fall back to the
    smallest batch size of that scale. The concrete numbers presumably keep a
    run within GPU memory — TODO confirm.
    """
    if scale_images == 0.5:
        sizes, fallback = {8: 70, 16: 30}, 16
    else:  # full resolution (1.0)
        sizes, fallback = {8: 35, 16: 16}, 8
    return sizes.get(initial_channels, fallback)


def generate_all_adaptions(num_adaptions: int, config=None, max_duplicate_draws: int = 100_000):
    """Draw ``num_adaptions`` distinct random adaptions from the search space.

    Args:
        num_adaptions: number of unique adaptions to return (``<= 0`` gives ``[]``).
        config: search space shaped ``{section: {key: zero-arg sampler}}``.
            Defaults to the module-level ``variations_config``. It must provide
            ``network_config.initial_channels`` and ``dataset_config.scale_images``,
            which determine the injected ``batch_size``.
        max_duplicate_draws: give up after this many *consecutive* draws that
            only reproduced already-collected adaptions.

    Returns:
        List of nested dicts ``{section: {key: value}}``. Each
        ``network_config`` additionally carries a ``batch_size``.

    Raises:
        ValueError: if ``max_duplicate_draws`` consecutive draws were all
            duplicates. In practice this means the search space holds fewer than
            ``num_adaptions`` distinct combinations. Previously this case
            looped forever.
    """
    search_space = variations_config if config is None else config
    adaptions = []
    duplicate_streak = 0
    while len(adaptions) < num_adaptions:
        # Draw every parameter in declaration order.
        adaption = {
            section: {key: draw() for key, draw in params.items()}
            for section, params in search_space.items()
        }
        network = adaption['network_config']
        network['batch_size'] = _batch_size_for(
            adaption['dataset_config']['scale_images'], network['initial_channels'])

        if adaption in adaptions:
            duplicate_streak += 1
            if duplicate_streak >= max_duplicate_draws:
                raise ValueError(
                    f"only {len(adaptions)} of {num_adaptions} requested adaptions could be drawn: "
                    f"{max_duplicate_draws} consecutive draws were all duplicates "
                    f"(search space too small?)")
            continue
        duplicate_streak = 0
        adaptions.append(adaption)
    return adaptions
def get_adapted_config(cfg, adaption):
    """Return a deep copy of ``cfg`` with the sections of ``adaption`` merged in.

    Only sections that already exist in ``cfg`` *and* hold a dict are updated
    (a shallow ``dict.update`` per section). Sections of ``adaption`` that are
    missing from ``cfg``, or that map to a non-dict value there, are silently
    ignored. ``cfg`` itself is never modified.
    """
    merged = copy.deepcopy(cfg)
    for section, section_values in merged.items():
        if isinstance(section_values, dict) and section in adaption:
            section_values.update(adaption[section])
    return merged
# Great-grandparent of this file's path, i.e. two directories above the one
# containing this script; assumed to be the repository root, since
# MAIN_SCRIPT is resolved under it as src/trainers/...
ROOT_DIR = Path(__file__).parent.parent.parent
# Training entry point launched once per adaption.
MAIN_SCRIPT = ROOT_DIR.joinpath("src", "trainers", "train_models.py")
def _load_or_create_adaptions(adaptions_file: Path, num_configurations: int) -> list:
    """Return the pending adaptions, generating and persisting them on first run."""
    if adaptions_file.exists():
        with open(adaptions_file, 'r') as f:
            return json.load(f)

    adaptions = generate_all_adaptions(num_configurations)
    adaptions_file.parent.mkdir(parents=True, exist_ok=True)
    # Progress file: successfully trained adaptions are removed from it as we go.
    with open(adaptions_file, 'w') as f:
        json.dump(adaptions, f)
    # Untouched copy of every generated adaption.
    with open(adaptions_file.parent / "backup_adaptations.json", 'w') as f:
        json.dump(adaptions, f)
    return adaptions


def main(args):
    """Train one model per pending adaption of the default config.

    On the first run (``args.adaptions_file`` does not exist),
    ``args.num_configurations`` adaptions are generated. They are written both to
    that progress file and to ``backup_adaptations.json`` next to it.

    Each adaption is merged into the YAML default config and written to a
    temporary config file under ``ROOT_DIR``. That file is then passed to
    ``MAIN_SCRIPT`` in a subprocess.

    Adaptions whose run exits with code 0 are removed from the progress file.
    Failed ones stay in it, so re-running the script resumes and retries them.

    Args:
        args: parsed CLI namespace with ``default_config``,
            ``num_configurations`` and ``adaptions_file``. Relative paths are
            resolved against ``ROOT_DIR``.
    """
    default_config = ROOT_DIR / args.default_config
    adaptions_file = ROOT_DIR / args.adaptions_file
    tmp_config = ROOT_DIR / f"tmp_config_for_{adaptions_file.stem}.yml"

    adaptions = _load_or_create_adaptions(adaptions_file, args.num_configurations)
    with open(default_config, 'r') as f:
        cfg = yaml.safe_load(f)

    remaining = copy.deepcopy(adaptions)
    failed = 0
    try:
        for adaption in tqdm(adaptions, desc='processed adaptions'):
            adapted_config = get_adapted_config(cfg, adaption)
            with open(tmp_config, 'w') as f:
                yaml.safe_dump(adapted_config, f)
            # sys.executable: train with the interpreter/venv running this
            # script, not whatever "python" happens to be first on PATH.
            p = subprocess.run([sys.executable, str(MAIN_SCRIPT), tmp_config.as_posix()],
                               cwd=os.getcwd())
            if p.returncode == 0:
                remaining.remove(adaption)
                # Persist progress after every success so the search can resume.
                with open(adaptions_file, 'w') as f:
                    json.dump(remaining, f)
            else:
                # WARNING rather than INFO: the default root logger drops INFO,
                # which previously hid failed trainings entirely.
                failed += 1
                logging.warning("exception for training %s", adaption)
                logging.warning("continuing with next adaption")
        logging.info("processed all adaptions")
        if failed:
            logging.warning("%d adaptions failed; they remain in %s for a later re-run",
                            failed, adaptions_file)
    finally:
        # The temp file only exists if at least one adaption was processed;
        # unconditional unlink() raised FileNotFoundError on an empty resume.
        if tmp_config.exists():
            tmp_config.unlink()
if __name__ == "__main__":
    # Without configuration the root logger stays at WARNING and every
    # logging.info() progress message emitted by main() would be discarded.
    logging.basicConfig(level=logging.INFO)
    # Named "parser" rather than "argparse" to avoid shadowing the module name.
    parser = ArgumentParser(
        description="Train one model per randomly sampled adaption of a default config.")
    parser.add_argument("default_config", type=Path, help="default config to use")
    parser.add_argument("--num-configurations", type=int, default=100,
                        help="number of adaptions to generate on the first run")
    parser.add_argument("--adaptions-file", type=Path, default="adaptions.json",
                        help='where the adaptions to iterate through are saved')
    main(parser.parse_args())