Skip to content

Commit 7d0b5a1

Browse files
committed
Fix loop reset test
1 parent cbb9fb5 commit 7d0b5a1

File tree

1 file changed

+28
-11
lines changed

1 file changed

+28
-11
lines changed

tests/tests_pytorch/loops/test_loops.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -558,24 +558,39 @@ def test_fit_loop_reset(tmp_path):
558558

559559
# we load exactly what was saved - no reset yet
560560
fit_loop.load_state_dict(mid_epoch_ckpt["loops"]["fit_loop"])
561-
# resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0
562-
fit_loop.reset()
563-
epoch_loop.reset()
564561

565562
assert fit_loop.restarting
566563
assert fit_loop.epoch_progress.total.ready == 1
567564
assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint was saved mid epoch
568565
assert fit_loop.epoch_progress.current.ready == 1
569566
assert fit_loop.epoch_progress.current.completed == 0
570567

571-
assert epoch_loop.restarting
572568
assert epoch_loop.batch_progress.total.ready == 2
573569
assert epoch_loop.batch_progress.total.processed == 2
574570
assert epoch_loop.batch_progress.total.completed == 1 # the checkpoint was saved on train_batch_end
575-
assert epoch_loop.batch_progress.current.ready == 1 # currents get set to the completed value
576-
assert epoch_loop.batch_progress.current.processed == 1
571+
assert epoch_loop.batch_progress.current.ready == 2 # currents get set to the completed value
572+
assert epoch_loop.batch_progress.current.processed == 2
577573
assert epoch_loop.batch_progress.current.completed == 1
578574

575+
fit_loop.reset()
576+
epoch_loop.reset()
577+
578+
# resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0
579+
assert fit_loop.restarting
580+
assert fit_loop.epoch_progress.total.ready == 1
581+
assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint was saved mid epoch
582+
assert fit_loop.epoch_progress.current.ready == 1
583+
assert fit_loop.epoch_progress.current.completed == 0
584+
585+
# however it should increment completed batch progress, since it was saved immediately prior
586+
assert epoch_loop.restarting
587+
assert epoch_loop.batch_progress.total.ready == 2
588+
assert epoch_loop.batch_progress.total.processed == 2
589+
assert epoch_loop.batch_progress.total.completed == 2
590+
assert epoch_loop.batch_progress.current.ready == 2
591+
assert epoch_loop.batch_progress.current.processed == 2
592+
assert epoch_loop.batch_progress.current.completed == 2
593+
579594
assert optimizer_loop.restarting
580595

581596
# reset state loaded from a checkpoint from the end of an epoch
@@ -592,19 +607,21 @@ def test_fit_loop_reset(tmp_path):
592607
fit_loop.reset()
593608
epoch_loop.reset()
594609

610+
# resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0
595611
assert fit_loop.restarting
596612
assert fit_loop.epoch_progress.total.ready == 1
597-
assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint saves before the epoch completes
613+
assert fit_loop.epoch_progress.total.completed == 0
598614
assert fit_loop.epoch_progress.current.ready == 1
599615
assert fit_loop.epoch_progress.current.completed == 0
600616

617+
# however it should increment completed batch progress, since it was saved immediately prior
601618
assert epoch_loop.restarting
602619
assert epoch_loop.batch_progress.total.ready == 4
603620
assert epoch_loop.batch_progress.total.processed == 4
604-
assert epoch_loop.batch_progress.total.completed == 3 # the checkpoint was saved on train_batch_end
605-
assert epoch_loop.batch_progress.current.ready == 3 # currents get set to the completed value
606-
assert epoch_loop.batch_progress.current.processed == 3
607-
assert epoch_loop.batch_progress.current.completed == 3
621+
assert epoch_loop.batch_progress.total.completed == 4
622+
assert epoch_loop.batch_progress.current.ready == 0
623+
assert epoch_loop.batch_progress.current.processed == 0
624+
assert epoch_loop.batch_progress.current.completed == 0
608625

609626

610627
def compare_state_dicts(dict1, dict2):

0 commit comments

Comments
 (0)