Skip to content

Commit ae4abb3

Browse files
committed
Bring restart() into P1B1
1 parent 47cb2e3 commit ae4abb3

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

Pilot1/P1B1/p1b1_baseline_pytorch.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,12 @@ def fit(model, X_train, X_val, params):
196196
)
197197
ckpt = candle.CandleCkptPyTorch(params)
198198
ckpt.set_model({"model": model, "optimizer": optimizer})
199+
J = ckpt.restart(model)
200+
if J is not None:
201+
initial_epoch = J["epoch"]
202+
best_metric_last = J["best_metric_last"]
203+
params["ckpt_best_metric_last"] = best_metric_last
204+
print("initial_epoch: %i" % initial_epoch)
199205

200206
total_step = len(train_iter)
201207
loss_list = []

0 commit comments

Comments
 (0)