@@ -346,7 +346,7 @@ def test_save_on_train_end(self) -> None:
             self.assertTrue(os.path.exists(os.path.join(temp_dir, expected_path)))
 
             with self.assertLogs(level="WARNING") as log:
-                checkpoint_cb.metadata_fname = ".metadata"
+                checkpoint_cb._checkpoint_manager._metadata_fname = ".metadata"
                 # create metadata file
                 with open(os.path.join(temp_dir, expected_path, ".metadata"), "w"):
                     pass
@@ -454,99 +454,6 @@ def _test_process_group_plumbing_nccl(_: MagicMock) -> None:
         # check that a new process group was created
         tc.assertNotEqual(checkpoint_cb._process_group, dist.group.WORLD)
 
-    @patch(
-        "torchtnt.framework.callbacks.base_checkpointer.get_checkpoint_dirpaths",
-        return_value=["epoch_1_step_10", "epoch_2_step_20"],
-    )
-    def test_ckpt_dirpaths(self, _: MagicMock) -> None:
-        """
-        Tests that ckpt_dirpaths is populated correctly
-        based on if ``keep_last_n_checkpoints`` is set.
-        """
-        bc = BaseCheckpointSaver("foo")
-        self.assertEqual(bc._ckpt_dirpaths, [])
-
-        bc = BaseCheckpointSaver("foo", keep_last_n_checkpoints=10)
-        self.assertEqual(bc._ckpt_dirpaths, ["epoch_1_step_10", "epoch_2_step_20"])
-
-    def test_should_remove_checkpoint(self) -> None:
-        """
-        Tests the helper function that checks if checkpoint should be removed or not
-        """
-        bc = BaseCheckpointSaver("temp")
-
-        # keep_last_n_checkpoints is toggled off
-        self.assertFalse(bc._should_remove_checkpoint())
-
-        # not enough checkpoints are saved yet to be removed
-        bc._keep_last_n_checkpoints = 2
-        bc._ckpt_dirpaths = ["bar"]
-        self.assertFalse(bc._should_remove_checkpoint())
-
-        # enough checkpoints are there to remove
-        bc._keep_last_n_checkpoints = 2
-        bc._ckpt_dirpaths = ["foo", "bar"]
-        self.assertTrue(bc._should_remove_checkpoint())
-
-    @patch("torchtnt.framework.callbacks.base_checkpointer._delete_checkpoint")
-    def test_cleanup_surplus(self, mock_delete_checkpoint: MagicMock) -> None:
-        """
-        Tests surplus of checkpoints being cleaned up
-        """
-        state = get_dummy_train_state()
-        unit = DummyTrainUnit(input_dim=2)
-        warning_messages = []
-        with tempfile.TemporaryDirectory() as temp_dir:
-            bc = BaseCheckpointSaver(temp_dir, keep_last_n_checkpoints=1)
-            bc._ckpt_dirpaths = ["foo", "bar", "baz"]
-
-            expected_warning_msg = " ".join(
-                [
-                    f"3 checkpoints found in {temp_dir}.",
-                    f"Deleting {2} oldest",
-                    "checkpoints to enforce ``keep_last_n_checkpoints`` argument.",
-                ]
-            )
-
-            with patch(
-                "torchtnt.framework.callbacks.base_checkpointer.logging.Logger.warning",
-                warning_messages.append,
-            ):
-                bc.on_train_start(state, unit)
-            self.assertEqual(bc._ckpt_dirpaths, ["baz"])
-            self.assertEqual(warning_messages[0], expected_warning_msg)
-
-            bc = BaseCheckpointSaver(temp_dir)
-            bc._ckpt_dirpaths = ["foo", "bar", "baz"]
-
-            bc.on_train_start(state, unit)
-            self.assertEqual(bc._ckpt_dirpaths, ["foo", "bar", "baz"])
-
-    def test_keep_last_n_checkpoints(self) -> None:
-        """
-        Tests removing checkpoint directories
-        """
-        unit = DummyTrainUnit(input_dim=2)
-        state = get_dummy_train_state()
-        with tempfile.TemporaryDirectory() as temp_dir:
-            bc = BaseCheckpointSaver(
-                temp_dir,
-                save_every_n_train_steps=1,
-                keep_last_n_checkpoints=2,
-            )
-
-            # take 10 steps
-            for _ in range(10):
-                unit.train_progress.increment_step()
-                bc.on_train_step_end(state, unit)
-                # TODO remove time.sleep to avoid potential flaky test
-                time.sleep(0.1)  # sleep to ensure enough time to checkpoint
-
-            dirs = os.listdir(temp_dir)
-            self.assertEqual(len(dirs), 2)
-            self.assertIn("epoch_0_step_9", dirs)
-            self.assertIn("epoch_0_step_10", dirs)
-
     def test_keep_last_n_checkpoints_e2e(self) -> None:
         """
         Tests removing checkpoint directories e2e
@@ -581,66 +488,6 @@ def test_keep_last_n_checkpoints_e2e(self) -> None:
                 os.listdir(temp_dir),
             )
 
-    def test_does_checkpoint_exist(self) -> None:
-        with tempfile.TemporaryDirectory() as temp_dir:
-            with open(os.path.join(temp_dir, ".metadata"), "w"):
-                pass
-            bc = BaseCheckpointSaver(
-                temp_dir,
-                save_every_n_train_steps=2,
-                keep_last_n_checkpoints=1,
-            )
-            # checkpointer doesn't have a metadata_fname
-            does_checkpoint_exist = bc._does_checkpoint_exist(temp_dir)
-            self.assertFalse(does_checkpoint_exist)
-
-            # checkpointer has metadata_fname and the file exists
-            bc.metadata_fname = ".metadata"
-            does_checkpoint_exist = bc._does_checkpoint_exist(temp_dir)
-            self.assertTrue(does_checkpoint_exist)
-
-            # checkpointer has metadata_fname but the file doesn't exist
-            os.remove(os.path.join(temp_dir, ".metadata"))
-            does_checkpoint_exist = bc._does_checkpoint_exist(temp_dir)
-            self.assertFalse(does_checkpoint_exist)
-
-    def test_should_save_checkpoint(self) -> None:
-        """
-        Tests basic functionality of should_save_checkpoint
-        """
-        bc = BaseCheckpointSaver("foo")
-
-        # test default behavior
-        self.assertTrue(bc._should_save_checkpoint())
-
-        bc._ckpt_dirpaths = ["foo/epoch_0_step_1"]
-        self.assertTrue(bc._should_save_checkpoint())
-        bc._keep_last_n_checkpoints = 1
-        self.assertTrue(bc._should_save_checkpoint())
-
-        bc._ckpt_dirpaths = ["foo/epoch_0_step_1_val_loss=0.01"]
-        bc._best_checkpoint_config = BestCheckpointConfig(
-            monitored_metric="val_loss",
-            mode="min",
-        )
-        bc._keep_last_n_checkpoints = None
-        self.assertTrue(bc._should_save_checkpoint(0.02))
-        bc._keep_last_n_checkpoints = 1
-        self.assertFalse(bc._should_save_checkpoint(0.02))
-        self.assertTrue(bc._should_save_checkpoint(0.001))
-        bc._keep_last_n_checkpoints = 2
-        self.assertTrue(bc._should_save_checkpoint(0.02))
-
-        bc._best_checkpoint_config = BestCheckpointConfig(
-            monitored_metric="val_loss",
-            mode="max",
-        )
-        bc._keep_last_n_checkpoints = 1
-        self.assertTrue(bc._should_save_checkpoint(0.02))
-        self.assertFalse(bc._should_save_checkpoint(0.001))
-        bc._keep_last_n_checkpoints = 2
-        self.assertTrue(bc._should_save_checkpoint(0.001))
-
     def test_best_checkpoint_attr_missing(self) -> None:
         bcs = BaseCheckpointSaver(
             "foo",
@@ -686,21 +533,21 @@ def test_best_checkpoint_no_top_k(self) -> None:
             my_train_unit.train_loss = None
             bcs.on_train_epoch_end(state, my_train_unit)
             # none metric-value will not be updated in checkpoint dirpaths
-            self.assertEqual(bcs._ckpt_dirpaths, [])
+            self.assertEqual(bcs._checkpoint_manager._ckpt_paths, [])
             self.assertEqual(os.listdir(temp_dir), ["epoch_0_step_0"])
 
             my_train_unit.train_loss = 0.01
             bcs.on_train_epoch_end(state, my_train_unit)
             self.assertEqual(
-                bcs._ckpt_dirpaths,
+                [str(x) for x in bcs._checkpoint_manager._ckpt_paths],
                 [os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01")],
             )
 
             my_train_unit.train_loss = 0.02
             my_train_unit.train_progress.increment_epoch()
             bcs.on_train_epoch_end(state, my_train_unit)
             self.assertEqual(
-                bcs._ckpt_dirpaths,
+                [str(x) for x in bcs._checkpoint_manager._ckpt_paths],
                 (
                     [
                         os.path.join(temp_dir, "epoch_1_step_0_train_loss=0.02"),
@@ -718,7 +565,7 @@ def test_best_checkpoint_no_top_k(self) -> None:
             my_train_unit.train_progress.increment_epoch()
             bcs.on_train_epoch_end(state, my_train_unit)
             self.assertEqual(
-                bcs._ckpt_dirpaths,
+                [str(x) for x in bcs._checkpoint_manager._ckpt_paths],
                 (
                     [
                         os.path.join(temp_dir, "epoch_1_step_0_train_loss=0.02"),
@@ -752,15 +599,15 @@ def test_best_checkpoint_top_k(self) -> None:
 
             bcs.on_train_epoch_end(state, my_train_unit)
             self.assertEqual(
-                bcs._ckpt_dirpaths,
+                [str(x) for x in bcs._checkpoint_manager._ckpt_paths],
                 [os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01")],
             )
 
             my_train_unit.train_loss = 0.02
             my_train_unit.train_progress.increment_epoch()
             bcs.on_train_epoch_end(state, my_train_unit)
             self.assertEqual(
-                bcs._ckpt_dirpaths,
+                [str(x) for x in bcs._checkpoint_manager._ckpt_paths],
                 [
                     os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01"),
                 ],
@@ -770,7 +617,7 @@ def test_best_checkpoint_top_k(self) -> None:
             my_train_unit.train_progress.increment_epoch()
            bcs.on_train_epoch_end(state, my_train_unit)
             self.assertEqual(
-                bcs._ckpt_dirpaths,
+                [str(x) for x in bcs._checkpoint_manager._ckpt_paths],
                 [
                     os.path.join(temp_dir, "epoch_2_step_0_train_loss=0.001"),
                 ],
@@ -793,15 +640,15 @@ def test_best_checkpoint_top_k(self) -> None:
 
             bcs.on_train_epoch_end(state, my_train_unit)
             self.assertEqual(
-                bcs._ckpt_dirpaths,
+                [str(x) for x in bcs._checkpoint_manager._ckpt_paths],
                 [os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01")],
             )
 
             my_train_unit.train_loss = 0.02
             my_train_unit.train_progress.increment_epoch()
             bcs.on_train_epoch_end(state, my_train_unit)
             self.assertEqual(
-                bcs._ckpt_dirpaths,
+                [str(x) for x in bcs._checkpoint_manager._ckpt_paths],
                 [
                     os.path.join(temp_dir, "epoch_1_step_0_train_loss=0.02"),
                     os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01"),
@@ -812,7 +659,7 @@ def test_best_checkpoint_top_k(self) -> None:
             my_train_unit.train_progress.increment_epoch()
             bcs.on_train_epoch_end(state, my_train_unit)
             self.assertEqual(
-                bcs._ckpt_dirpaths,
+                [str(x) for x in bcs._checkpoint_manager._ckpt_paths],
                 [
                     os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01"),
                     os.path.join(temp_dir, "epoch_2_step_0_train_loss=0.001"),