7
7
import json
8
8
import os
9
9
import p1b1
10
- import numpy as np
11
-
12
- DATA_TYPES = {type (np .float16 ): 'f16' , type (np .float32 ): 'f32' , type (np .float64 ): 'f64' }
13
-
14
- def write_params (params , hyper_parameter_map ):
15
- parent_dir = hyper_parameter_map ['instance_directory' ] if 'instance_directory' in hyper_parameter_map else '.'
16
- f = "{}/parameters_p1b1.txt" .format (parent_dir )
17
- with open (f , "w" ) as f_out :
18
- f_out .write ("[parameters]\n " )
19
- for k ,v in params .items ():
20
- if type (v ) in DATA_TYPES :
21
- v = DATA_TYPES [type (v )]
22
- if isinstance (v , basestring ):
23
- v = "'{}'" .format (v )
24
- f_out .write ("{}={}\n " .format (k , v ))
25
-
26
- def is_numeric (val ):
27
- try :
28
- float (val )
29
- return True
30
- except ValueError :
31
- return False
32
-
33
- def format_params (hyper_parameter_map ):
34
- for k ,v in hyper_parameter_map .items ():
35
- vals = str (v ).split (" " )
36
- if len (vals ) > 1 and is_numeric (vals [0 ]):
37
- # assume this should be a list
38
- if "." in vals [0 ]:
39
- hyper_parameter_map [k ] = [float (x ) for x in vals ]
40
- else :
41
- hyper_parameter_map [k ] = [int (x ) for x in vals ]
42
-
10
+ import runner_utils
43
11
44
12
def run (hyper_parameter_map ):
45
13
framework = hyper_parameter_map ['framework' ]
@@ -56,17 +24,15 @@ def run(hyper_parameter_map):
56
24
raise ValueError ("Invalid framework: {}" .format (framework ))
57
25
58
26
# params is python dictionary
59
- sys .argv = ['fail here' , '--epochs' , '54321' ]
60
27
params = pkg .initialize_parameters ()
61
- format_params (hyper_parameter_map )
28
+ runner_utils . format_params (hyper_parameter_map )
62
29
63
30
for k ,v in hyper_parameter_map .items ():
64
31
#if not k in params:
65
32
# raise Exception("Parameter '{}' not found in set of valid arguments".format(k))
66
33
params [k ] = v
67
34
68
35
print (params )
69
- write_params (params , hyper_parameter_map )
70
36
history = pkg .run (params )
71
37
72
38
if framework is 'keras' :
@@ -81,27 +47,3 @@ def run(hyper_parameter_map):
81
47
# use the last validation_loss as the value to minimize
82
48
val_loss = history .history ['val_loss' ]
83
49
return val_loss [- 1 ]
84
-
85
- def write_output (result , instance_directory ):
86
- with open ('{}/result.txt' .format (instance_directory ), 'w' ) as f_out :
87
- f_out .write ("{}\n " .format (result ))
88
-
89
- def init (param_file , instance_directory ):
90
- with open (param_file ) as f_in :
91
- hyper_parameter_map = json .load (f_in )
92
-
93
- hyper_parameter_map ['framework' ] = 'keras'
94
- hyper_parameter_map ['save' ] = '{}/output' .format (instance_directory )
95
- hyper_parameter_map ['instance_directory' ] = instance_directory
96
-
97
- return hyper_parameter_map
98
-
99
- if __name__ == '__main__' :
100
- print ('p1b1_runner main ' + str (argv ))
101
- param_file = sys .argv [1 ]
102
- instance_directory = sys .argv [2 ]
103
- hyper_parameter_map = init (param_file , instance_directory )
104
- # clear sys.argv so that argparse doesn't object
105
- sys .argv = ['p1b1_runner' ]
106
- result = run (hyper_parameter_map )
107
- write_output (result , instance_directory )
0 commit comments