Skip to content

Commit 9b1ed42

Browse files
committed
Added ability to save and load weights via args
1 parent 4ca9ccd commit 9b1ed42

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

Pilot1/Uno/uno.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,15 @@ def set_locals(self):
191191
{'name': 'growth_bins',
192192
'type': int,
193193
'default': 0,
194-
'help': 'number of bins to use when discretizing growth response'}
194+
'help': 'number of bins to use when discretizing growth response'},
195+
{'name' : 'initial_weights',
196+
'type' : str,
197+
'default': None,
198+
'help' : 'file name of initial weights'},
199+
{'name' : 'save_weights',
200+
'type': str,
201+
'default' : None,
202+
'help': 'name of file to save weights to' }
195203
]
196204

197205
required = [

Pilot1/Uno/uno_baseline_keras2.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,14 @@ def build_feature_model(input_shape, name='', dense_layers=[1000, 1000],
201201
model = Model(x_input, h, name=name)
202202
return model
203203

204+
class SimpleWeightSaver(Callback):
205+
def __init__(self, fname):
206+
self.fname = fname
207+
208+
def on_train_end(self, logs={}):
209+
self.model.save_weights(self.fname)
210+
211+
204212

205213
def build_model(loader, args, permanent_dropout=True, silent=False):
206214
input_models = {}
@@ -386,12 +394,17 @@ def warmup_scheduler(epoch):
386394
logger.info('Cross validation fold {}/{}:'.format(fold + 1, cv))
387395
cv_ext = '.cv{}'.format(fold + 1)
388396

397+
template_model = build_model(loader, args, silent=True)
398+
if args.initial_weights:
399+
logger.info("Loading weights from {}".format(args.initial_weights))
400+
template_model.load_weights(args.initial_weights)
401+
389402
if len(args.gpus) > 1:
390403
from keras.utils import multi_gpu_model
391404
gpu_count = len(args.gpus)
392-
model = multi_gpu_model(build_model(loader, args, silent=True), cpu_merge=False, gpus=gpu_count)
405+
model = multi_gpu_model(template_model, cpu_merge=False, gpus=gpu_count)
393406
else:
394-
model = build_model(loader, args, silent=True)
407+
model = template_model
395408

396409
optimizer = optimizers.deserialize({'class_name': args.optimizer, 'config': {}})
397410
base_lr = args.base_lr or K.get_value(optimizer.lr)
@@ -411,7 +424,7 @@ def warmup_scheduler(epoch):
411424
checkpointer = MultiGPUCheckpoint(prefix + cv_ext + '.model.h5', save_best_only=True)
412425
tensorboard = TensorBoard(log_dir="tb/{}{}{}".format(args.tb_prefix, ext, cv_ext))
413426
history_logger = LoggingCallback(logger.debug)
414-
427+
415428
callbacks = [candle_monitor, timeout_monitor, history_logger]
416429
if args.reduce_lr:
417430
callbacks.append(reduce_lr)
@@ -421,6 +434,8 @@ def warmup_scheduler(epoch):
421434
callbacks.append(checkpointer)
422435
if args.tb:
423436
callbacks.append(tensorboard)
437+
if args.save_weights:
438+
callbacks.append(SimpleWeightSaver(args.save_path + '/' + args.save_weights))
424439

425440
if args.use_exported_data is not None:
426441
train_gen = DataFeeder(filename=args.use_exported_data, batch_size=args.batch_size, shuffle=args.shuffle, single=args.single, agg_dose=args.agg_dose)

0 commit comments

Comments
 (0)