@@ -201,6 +201,14 @@ def build_feature_model(input_shape, name='', dense_layers=[1000, 1000],
201
201
model = Model (x_input , h , name = name )
202
202
return model
203
203
204
+ class SimpleWeightSaver (Callback ):
205
+ def __init__ (self , fname ):
206
+ self .fname = fname
207
+
208
+ def on_train_end (self , logs = {}):
209
+ self .model .save_weights (self .fname )
210
+
211
+
204
212
205
213
def build_model (loader , args , permanent_dropout = True , silent = False ):
206
214
input_models = {}
@@ -386,12 +394,17 @@ def warmup_scheduler(epoch):
386
394
logger .info ('Cross validation fold {}/{}:' .format (fold + 1 , cv ))
387
395
cv_ext = '.cv{}' .format (fold + 1 )
388
396
397
+ template_model = build_model (loader , args , silent = True )
398
+ if args .initial_weights :
399
+ logger .info ("Loading weights from {}" .format (args .initial_weights ))
400
+ template_model .load_weights (args .initial_weights )
401
+
389
402
if len (args .gpus ) > 1 :
390
403
from keras .utils import multi_gpu_model
391
404
gpu_count = len (args .gpus )
392
- model = multi_gpu_model (build_model ( loader , args , silent = True ) , cpu_merge = False , gpus = gpu_count )
405
+ model = multi_gpu_model (template_model , cpu_merge = False , gpus = gpu_count )
393
406
else :
394
- model = build_model ( loader , args , silent = True )
407
+ model = template_model
395
408
396
409
optimizer = optimizers .deserialize ({'class_name' : args .optimizer , 'config' : {}})
397
410
base_lr = args .base_lr or K .get_value (optimizer .lr )
@@ -411,7 +424,7 @@ def warmup_scheduler(epoch):
411
424
checkpointer = MultiGPUCheckpoint (prefix + cv_ext + '.model.h5' , save_best_only = True )
412
425
tensorboard = TensorBoard (log_dir = "tb/{}{}{}" .format (args .tb_prefix , ext , cv_ext ))
413
426
history_logger = LoggingCallback (logger .debug )
414
-
427
+
415
428
callbacks = [candle_monitor , timeout_monitor , history_logger ]
416
429
if args .reduce_lr :
417
430
callbacks .append (reduce_lr )
@@ -421,6 +434,8 @@ def warmup_scheduler(epoch):
421
434
callbacks .append (checkpointer )
422
435
if args .tb :
423
436
callbacks .append (tensorboard )
437
+ if args .save_weights :
438
+ callbacks .append (SimpleWeightSaver (args .save_path + '/' + args .save_weights ))
424
439
425
440
if args .use_exported_data is not None :
426
441
train_gen = DataFeeder (filename = args .use_exported_data , batch_size = args .batch_size , shuffle = args .shuffle , single = args .single , agg_dose = args .agg_dose )
0 commit comments