Skip to content

Commit fa6641a

Browse files
committed
implement checkpointing
1 parent e790951 commit fa6641a

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

examples/ADRP/adrp_baseline_keras2.py

Lines changed: 16 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -245,12 +245,6 @@ def run(params):
245245
seed = args.rng_seed
246246
candle.set_seed(seed)
247247

248-
# Construct extension to save model
249-
# ext = adrp.extension_from_parameters(params, ".keras")
250-
# params['save_path'] = './'+params['base_name']+'/'
251-
# candle.verify_path(params["save_path"])
252-
253-
# prefix = "{}{}".format(params["save_path"], ext)
254248
prefix = "{}".format(params["save_path"])
255249
logfile = params["logfile"] if params["logfile"] else prefix + "TEST.log"
256250
candle.set_up_logger(logfile, adrp.logger, params["verbose"])
@@ -259,7 +253,6 @@ def run(params):
259253
# Get default parameters for initialization and optimizer functions
260254
keras_defaults = candle.keras_default_config()
261255

262-
##
263256
X_train, Y_train, X_test, Y_test, PS, count_array = adrp.load_data(params, seed)
264257

265258
print("X_train shape:", X_train.shape)
@@ -342,12 +335,20 @@ def run(params):
342335

343336
# set up a bunch of callbacks to do work during model training..
344337

345-
checkpointer = ModelCheckpoint(
346-
filepath=params["save_path"] + "agg_adrp.autosave.model.h5",
347-
verbose=1,
348-
save_weights_only=False,
349-
save_best_only=True,
350-
)
338+
#checkpointer = ModelCheckpoint(
339+
# filepath=params["save_path"] + "agg_adrp.autosave.model.h5",
340+
# verbose=1,
341+
# save_weights_only=False,
342+
# save_best_only=True,
343+
#)
344+
initial_epoch = 0
345+
ckpt = candle.CandleCkptKeras(params, verbose=True)
346+
ckpt.set_model(model)
347+
J = ckpt.restart(model)
348+
if J is not None:
349+
initial_epoch = J["epoch"]
350+
print("restarting from ckpt: initial_epoch: %i" % initial_epoch)
351+
351352
csv_logger = CSVLogger(params["save_path"] + "agg_adrp.training.log")
352353

353354
# min_lr = params['learning_rate']*params['reduce_ratio']
@@ -456,8 +457,9 @@ def run(params):
456457
verbose=1,
457458
sample_weight=train_weight,
458459
validation_data=(X_test, Y_test, test_weight),
459-
callbacks=[checkpointer, timeout_monitor, csv_logger, reduce_lr, early_stop],
460+
callbacks=[ckpt, timeout_monitor, csv_logger, reduce_lr, early_stop],
460461
)
462+
ckpt.report_final()
461463

462464
print("Reloading saved best model")
463465
model.load_weights(params["save_path"] + "agg_adrp.autosave.model.h5")

0 commit comments

Comments
 (0)