@@ -539,21 +539,18 @@ def append_checkpoint(self, ckpt: CheckpointPath) -> None:
539539 # No metric tracked, most recents goes last
540540 self ._ckpt_paths .append (ckpt )
541541
542- @rank_zero_read_and_broadcast
543542 def does_checkpoint_exist (
544- self , ckpt : CheckpointPath , process_group : Optional [dist .ProcessGroup ] = None
543+ self ,
544+ ckpt : CheckpointPath ,
545+ process_group : Optional [dist .ProcessGroup ] = None ,
545546 ) -> bool :
546547 """
547548 Checking whether a checkpoint already exists by verifying whether the optional metadata file is present in the directory.
548549 If the checkpointer doesn't have a metadata file, this function will always return False. Check is executed in rank 0, but
549550 result is broadcasted to all ranks.
550551 """
551- if not self ._metadata_fnames :
552- return False
553-
554- fs , _ = url_to_fs (self .dirpath )
555- return any (
556- _metadata_exists (fs , ckpt .path , fname ) for fname in self ._metadata_fnames
552+ return does_checkpoint_exist (
553+ ckpt .path , self ._metadata_fnames , process_group = process_group
557554 )
558555
559556 @staticmethod
@@ -596,6 +593,33 @@ def remove_checkpoint(self) -> None:
596593 )
597594
598595
596+ @rank_zero_read_and_broadcast
597+ def does_checkpoint_exist (
598+ ckpt_path : str ,
599+ metadata_fname : Union [str , List [str ]],
600+ process_group : Optional [dist .ProcessGroup ] = None ,
601+ ) -> bool :
602+ """
603+ Checking whether a checkpoint already exists by verifying whether the optional metadata file is present in the directory.
604+ Will return False if the metadata_fname is None. Check is executed in rank 0, but
605+ result is broadcasted to all ranks.
606+
607+ Args:
608+ ckpt: The checkpoint to check.
609+ metadata_fname: File to check for existence. If a list is provided, it will check that at least one of the files is present.
610+ process_group: Optional process group on which the ranks will communicate on. By default, the entire world is used.
611+ """
612+ if not metadata_fname :
613+ return False
614+ else :
615+ metadata_fnames = (
616+ [metadata_fname ] if isinstance (metadata_fname , str ) else metadata_fname
617+ )
618+
619+ fs , _ = url_to_fs (ckpt_path )
620+ return any (_metadata_exists (fs , ckpt_path , fname ) for fname in metadata_fnames )
621+
622+
599623@rank_zero_read_and_broadcast
600624def get_latest_checkpoint_path (
601625 dirpath : str ,
0 commit comments