@@ -201,10 +201,17 @@ def build_feature_model(input_shape, name='', dense_layers=[1000, 1000],
    model = Model(x_input, h, name=name)
    return model

-class SimpleWeightSaver(Callback):
+class SimpleWeightSaver(Callback):
+
    def __init__(self, fname):
        self.fname = fname

+    def set_model(self, model):
+        if isinstance(model.layers[-2], Model):
+            self.model = model.layers[-2]
+        else:
+            self.model = model
+
    def on_train_end(self, logs={}):
        self.model.save_weights(self.fname)

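The new set_model override is there for the multi-GPU path introduced below: when keras.utils.multi_gpu_model wraps the template model, the single-GPU template appears as a nested Model among the wrapper's layers (at layers[-2], just before the output-merging layer), so the callback saves the template's weights rather than the wrapper's. A minimal usage sketch, assuming two visible GPUs; the input shape, target array, and file name are illustrative and not part of this change:

```python
import numpy as np
from keras.utils import multi_gpu_model

# Single-GPU template produced by build_feature_model (shape illustrative).
template_model = build_feature_model(input_shape=(100,), name='feature')
parallel_model = multi_gpu_model(template_model, cpu_merge=False, gpus=2)
parallel_model.compile(loss='mse', optimizer='sgd')

# Keras calls saver.set_model(parallel_model) when training starts; the
# override detects the nested template at layers[-2] and saves its weights
# (not the wrapper's) in on_train_end.
saver = SimpleWeightSaver('feature.weights.h5')
parallel_model.fit(np.random.rand(64, 100), np.random.rand(64, 1000),
                   epochs=1, callbacks=[saver])
```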
@@ -402,6 +409,7 @@ def warmup_scheduler(epoch):
    if len(args.gpus) > 1:
        from keras.utils import multi_gpu_model
        gpu_count = len(args.gpus)
+        logger.info("Multi GPU with {} gpus".format(gpu_count))
        model = multi_gpu_model(template_model, cpu_merge=False, gpus=gpu_count)
    else:
        model = template_model
@@ -411,6 +419,7 @@ def warmup_scheduler(epoch):
    if args.learning_rate:
        K.set_value(optimizer.lr, args.learning_rate)

+
    model.compile(loss=args.loss, optimizer=optimizer, metrics=[mae, r2])

    # calculate trainable and non-trainable params
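As a pointer for the context line above, the trainable/non-trainable totals can be computed from the compiled model's weight variables; a minimal sketch using keras.backend.count_params (the logging format is illustrative, not part of this change):

```python
import numpy as np
from keras import backend as K

# Count parameters across the compiled model's weight variables.
trainable_count = int(np.sum([K.count_params(w) for w in set(model.trainable_weights)]))
non_trainable_count = int(np.sum([K.count_params(w) for w in set(model.non_trainable_weights)]))
logger.info('Trainable params: {:,}'.format(trainable_count))
logger.info('Non-trainable params: {:,}'.format(non_trainable_count))
```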