@@ -558,24 +558,39 @@ def test_fit_loop_reset(tmp_path):
558558
559559 # we load exactly what was saved - no reset yet
560560 fit_loop .load_state_dict (mid_epoch_ckpt ["loops" ]["fit_loop" ])
561- # resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0
562- fit_loop .reset ()
563- epoch_loop .reset ()
564561
565562 assert fit_loop .restarting
566563 assert fit_loop .epoch_progress .total .ready == 1
567564 assert fit_loop .epoch_progress .total .completed == 0 # the checkpoint was saved mid epoch
568565 assert fit_loop .epoch_progress .current .ready == 1
569566 assert fit_loop .epoch_progress .current .completed == 0
570567
571- assert epoch_loop .restarting
572568 assert epoch_loop .batch_progress .total .ready == 2
573569 assert epoch_loop .batch_progress .total .processed == 2
574570 assert epoch_loop .batch_progress .total .completed == 1 # the checkpoint was saved on train_batch_end
575- assert epoch_loop .batch_progress .current .ready == 1 # currents get set to the completed value
576- assert epoch_loop .batch_progress .current .processed == 1
571+ assert epoch_loop .batch_progress .current .ready == 2 # currents get set to the completed value
572+ assert epoch_loop .batch_progress .current .processed == 2
577573 assert epoch_loop .batch_progress .current .completed == 1
578574
575+ fit_loop .reset ()
576+ epoch_loop .reset ()
577+
578+ # resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0
579+ assert fit_loop .restarting
580+ assert fit_loop .epoch_progress .total .ready == 1
581+ assert fit_loop .epoch_progress .total .completed == 0 # the checkpoint was saved mid epoch
582+ assert fit_loop .epoch_progress .current .ready == 1
583+ assert fit_loop .epoch_progress .current .completed == 0
584+
585+ # however it should increment completed batch progress, since it was saved immediately prior
586+ assert epoch_loop .restarting
587+ assert epoch_loop .batch_progress .total .ready == 2
588+ assert epoch_loop .batch_progress .total .processed == 2
589+ assert epoch_loop .batch_progress .total .completed == 2
590+ assert epoch_loop .batch_progress .current .ready == 2
591+ assert epoch_loop .batch_progress .current .processed == 2
592+ assert epoch_loop .batch_progress .current .completed == 2
593+
579594 assert optimizer_loop .restarting
580595
581596 # reset state loaded from a checkpoint from the end of an epoch
@@ -592,19 +607,21 @@ def test_fit_loop_reset(tmp_path):
592607 fit_loop .reset ()
593608 epoch_loop .reset ()
594609
610+ # resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0
595611 assert fit_loop .restarting
596612 assert fit_loop .epoch_progress .total .ready == 1
597- assert fit_loop .epoch_progress .total .completed == 0 # the checkpoint saves before the epoch completes
613+ assert fit_loop .epoch_progress .total .completed == 0
598614 assert fit_loop .epoch_progress .current .ready == 1
599615 assert fit_loop .epoch_progress .current .completed == 0
600616
617+ # however it should increment completed batch progress, since it was saved immediately prior
601618 assert epoch_loop .restarting
602619 assert epoch_loop .batch_progress .total .ready == 4
603620 assert epoch_loop .batch_progress .total .processed == 4
604- assert epoch_loop .batch_progress .total .completed == 3 # the checkpoint was saved on train_batch_end
605- assert epoch_loop .batch_progress .current .ready == 3 # currents get set to the completed value
606- assert epoch_loop .batch_progress .current .processed == 3
607- assert epoch_loop .batch_progress .current .completed == 3
621+ assert epoch_loop .batch_progress .total .completed == 4
622+ assert epoch_loop .batch_progress .current .ready == 0
623+ assert epoch_loop .batch_progress .current .processed == 0
624+ assert epoch_loop .batch_progress .current .completed == 0
608625
609626
610627def compare_state_dicts (dict1 , dict2 ):
0 commit comments