Commit 79b191c

Merge branch 'release_01' of https://github.com/ECP-Candle/Benchmarks into release_01

2 parents f5360c8 + 0c7eaa4
2 files changed: +38 -7 lines changed

common/keras_utils.py

Lines changed: 10 additions & 7 deletions

@@ -56,20 +56,21 @@ def get_function(name):
 
 
 def build_optimizer(type, lr, kerasDefaults):
-    """ Set the optimizer to the appropriate Keras optimizer function
-        based on the input string and learning rate. Other required values
+    """ Set the optimizer to the appropriate Keras optimizer function
+        based on the input string and learning rate. Other required values
         are set to the Keras default values
 
     Parameters
     ----------
     type : string
         String to choose the optimizer
+
        Options recognized: 'sgd', 'rmsprop', 'adagrad', adadelta', 'adam'
        See the Keras documentation for a full description of the options
 
-    Return
+    Returns
     ----------
-    Returns the appropriate Keras optimizer function
+    The appropriate Keras optimizer function
     """
 
     if type == 'sgd':
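
For reference, the dispatch that this docstring documents can be approximated as below. This is a minimal sketch, not the body of build_optimizer from the commit (only its first branch, if type == 'sgd':, is visible in the hunk). It assumes the Keras 2 optimizer API with lr as a constructor argument, that the five option strings map one-to-one onto keras.optimizers classes, and that every hyperparameter other than the learning rate is left at its Keras default instead of being read from kerasDefaults as the real function does.

from keras import optimizers

# Hypothetical simplification of build_optimizer: map the option string to a
# Keras optimizer, passing only the learning rate and keeping other defaults.
def build_optimizer_sketch(type, lr, kerasDefaults):
    if type == 'sgd':
        return optimizers.SGD(lr=lr)
    elif type == 'rmsprop':
        return optimizers.RMSprop(lr=lr)
    elif type == 'adagrad':
        return optimizers.Adagrad(lr=lr)
    elif type == 'adadelta':
        return optimizers.Adadelta(lr=lr)
    elif type == 'adam':
        return optimizers.Adam(lr=lr)
    raise ValueError('Unrecognized optimizer: {}'.format(type))
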
@@ -132,13 +133,15 @@ def build_initializer(type, kerasDefaults, seed=None, constant=0.):
     ----------
     type : string
         String to choose the initializer
+
         Options recognized: 'constant', 'uniform', 'normal',
-        'glorot_uniform', 'lecun_uniform', 'he_normal'
+        'glorot_uniform', 'lecun_uniform', 'he_normal'
+
         See the Keras documentation for a full description of the options
 
-    Return
+    Returns
     ----------
-    Returns the appropriate Keras initializer function
+    The appropriate Keras initializer function
     """
 
     if type == 'constant':
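
The initializer helper follows the same pattern. The sketch below is a hedged approximation, assuming the Keras 2 keras.initializers module; it ignores any distribution limits the real build_initializer may pull from kerasDefaults and uses the constant argument only for the 'constant' case.

from keras import initializers

# Hypothetical simplification of build_initializer: translate the option string
# into a Keras initializer, seeding the random ones.
def build_initializer_sketch(type, kerasDefaults, seed=None, constant=0.):
    if type == 'constant':
        return initializers.Constant(value=constant)
    elif type == 'uniform':
        return initializers.RandomUniform(seed=seed)
    elif type == 'normal':
        return initializers.RandomNormal(seed=seed)
    elif type == 'glorot_uniform':
        return initializers.glorot_uniform(seed=seed)
    elif type == 'lecun_uniform':
        return initializers.lecun_uniform(seed=seed)
    elif type == 'he_normal':
        return initializers.he_normal(seed=seed)
    raise ValueError('Unrecognized initializer: {}'.format(type))
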

common/solr_keras.py

Lines changed: 28 additions & 0 deletions

@@ -10,7 +10,16 @@
 
 
 def compute_trainable_params(model):
+    """ Extract number of parameters from the given Keras model
 
+    Parameters
+    -----------
+    model : Keras model
+
+    Return
+    ----------
+    python dictionary that contains trainable_params, non_trainable_params and total_params
+    """
     trainable_count = int(
         np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
     non_trainable_count = int(
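
The rest of compute_trainable_params is cut off in this hunk, but the new docstring pins down its return value. Below is a sketch of the whole helper under that reading: it keeps the count expression shown above, assumes the non-trainable count mirrors it over model.non_trainable_weights, and packages the results into the dictionary keys named in the docstring.

import numpy as np
from keras import backend as K

# Sketch of compute_trainable_params as described by the added docstring:
# count trainable and non-trainable weights, then return both plus their sum.
def compute_trainable_params_sketch(model):
    trainable_count = int(
        np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
    non_trainable_count = int(
        np.sum([K.count_params(p) for p in set(model.non_trainable_weights)]))
    return {'trainable_params': trainable_count,
            'non_trainable_params': non_trainable_count,
            'total_params': trainable_count + non_trainable_count}
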
@@ -136,13 +145,32 @@ def save(self):
         file_run_json.write(json.dumps(self.log_messages, indent=4, separators=(',', ': ')))
 
 class TerminateOnTimeOut(Callback):
+    """ This class implements timeout on model training. When the script reaches timeout,
+    this class sets model.stop_training = True
+    """
     def __init__(self, timeout_in_sec = 10):
+        """Initialize TerminateOnTimeOut class.
+
+        Parameters
+        -----------
+        timeout_in_sec : int
+            seconds to timeout
+        """
+
         super(TerminateOnTimeOut, self).__init__()
         self.run_timestamp = None
         self.timeout_in_sec = timeout_in_sec
+
+
     def on_train_begin(self, logs={}):
+        """ Start clock to calculate timeout
+        """
         self.run_timestamp = datetime.now()
+
+
     def on_epoch_end(self, epoch, logs={}):
+        """ On every epoch end, check whether it exceeded timeout and terminate training if necessary
+        """
         run_end = datetime.now()
         run_duration = run_end - self.run_timestamp
         run_in_sec = run_duration.total_seconds() #/ (60 * 60)
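
The hunk stops just after the elapsed time is computed, but the class docstring states what happens next. The following is a hedged sketch of the complete callback under that description, assuming on_epoch_end simply compares the elapsed seconds against timeout_in_sec; it is not the exact implementation from the commit.

from datetime import datetime
from keras.callbacks import Callback

class TerminateOnTimeOutSketch(Callback):
    """Stop training once the elapsed wall-clock time exceeds timeout_in_sec."""

    def __init__(self, timeout_in_sec=10):
        super(TerminateOnTimeOutSketch, self).__init__()
        self.run_timestamp = None
        self.timeout_in_sec = timeout_in_sec

    def on_train_begin(self, logs={}):
        self.run_timestamp = datetime.now()      # start the clock

    def on_epoch_end(self, epoch, logs={}):
        run_in_sec = (datetime.now() - self.run_timestamp).total_seconds()
        if run_in_sec >= self.timeout_in_sec:
            self.model.stop_training = True      # as the class docstring describes

A caller would pass it like any other Keras callback, e.g. model.fit(x, y, callbacks=[TerminateOnTimeOutSketch(timeout_in_sec=3600)]).
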
