@@ -245,12 +245,6 @@ def run(params):
245
245
seed = args .rng_seed
246
246
candle .set_seed (seed )
247
247
248
- # Construct extension to save model
249
- # ext = adrp.extension_from_parameters(params, ".keras")
250
- # params['save_path'] = './'+params['base_name']+'/'
251
- # candle.verify_path(params["save_path"])
252
-
253
- # prefix = "{}{}".format(params["save_path"], ext)
254
248
prefix = "{}" .format (params ["save_path" ])
255
249
logfile = params ["logfile" ] if params ["logfile" ] else prefix + "TEST.log"
256
250
candle .set_up_logger (logfile , adrp .logger , params ["verbose" ])
@@ -259,7 +253,6 @@ def run(params):
259
253
# Get default parameters for initialization and optimizer functions
260
254
keras_defaults = candle .keras_default_config ()
261
255
262
- ##
263
256
X_train , Y_train , X_test , Y_test , PS , count_array = adrp .load_data (params , seed )
264
257
265
258
print ("X_train shape:" , X_train .shape )
@@ -342,12 +335,20 @@ def run(params):
342
335
343
336
# set up a bunch of callbacks to do work during model training..
344
337
345
- checkpointer = ModelCheckpoint (
346
- filepath = params ["save_path" ] + "agg_adrp.autosave.model.h5" ,
347
- verbose = 1 ,
348
- save_weights_only = False ,
349
- save_best_only = True ,
350
- )
338
+ #checkpointer = ModelCheckpoint(
339
+ # filepath=params["save_path"] + "agg_adrp.autosave.model.h5",
340
+ # verbose=1,
341
+ # save_weights_only=False,
342
+ # save_best_only=True,
343
+ #)
344
+ initial_epoch = 0
345
+ ckpt = candle .CandleCkptKeras (params , verbose = True )
346
+ ckpt .set_model (model )
347
+ J = ckpt .restart (model )
348
+ if J is not None :
349
+ initial_epoch = J ["epoch" ]
350
+ print ("restarting from ckpt: initial_epoch: %i" % initial_epoch )
351
+
351
352
csv_logger = CSVLogger (params ["save_path" ] + "agg_adrp.training.log" )
352
353
353
354
# min_lr = params['learning_rate']*params['reduce_ratio']
@@ -456,8 +457,9 @@ def run(params):
456
457
verbose = 1 ,
457
458
sample_weight = train_weight ,
458
459
validation_data = (X_test , Y_test , test_weight ),
459
- callbacks = [checkpointer , timeout_monitor , csv_logger , reduce_lr , early_stop ],
460
+ callbacks = [ckpt , timeout_monitor , csv_logger , reduce_lr , early_stop ],
460
461
)
462
+ ckpt .report_final ()
461
463
462
464
print ("Reloading saved best model" )
463
465
model .load_weights (params ["save_path" ] + "agg_adrp.autosave.model.h5" )
0 commit comments