Skip to content

Commit 5f4dfe0

Browse files
committed
Updated save / load weights for multi-gpu model
1 parent 9b1ed42 commit 5f4dfe0

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

Pilot1/Uno/uno_baseline_keras2.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,17 @@ def build_feature_model(input_shape, name='', dense_layers=[1000, 1000],
201201
model = Model(x_input, h, name=name)
202202
return model
203203

204-
class SimpleWeightSaver(Callback):
204+
class SimpleWeightSaver(Callback):
205+
205206
def __init__(self, fname):
206207
self.fname = fname
207208

209+
def set_model(self, model):
210+
if isinstance(model.layers[-2], Model):
211+
self.model = model.layers[-2]
212+
else:
213+
self.model = model
214+
208215
def on_train_end(self, logs={}):
209216
self.model.save_weights(self.fname)
210217

@@ -402,6 +409,7 @@ def warmup_scheduler(epoch):
402409
if len(args.gpus) > 1:
403410
from keras.utils import multi_gpu_model
404411
gpu_count = len(args.gpus)
412+
logger.info("Multi GPU with {} gpus".format(gpu_count))
405413
model = multi_gpu_model(template_model, cpu_merge=False, gpus=gpu_count)
406414
else:
407415
model = template_model
@@ -411,6 +419,7 @@ def warmup_scheduler(epoch):
411419
if args.learning_rate:
412420
K.set_value(optimizer.lr, args.learning_rate)
413421

422+
414423
model.compile(loss=args.loss, optimizer=optimizer, metrics=[mae, r2])
415424

416425
# calculate trainable and non-trainable params

0 commit comments

Comments
 (0)