@@ -448,7 +448,7 @@ def _get_checkpoint_shard_files(
448448 _check_if_shards_exist_locally (
449449 pretrained_model_name_or_path , subfolder = subfolder , original_shard_filenames = original_shard_filenames
450450 )
451- return pretrained_model_name_or_path , sharded_metadata
451+ return shards_path , sharded_metadata
452452
453453 # At this stage pretrained_model_name_or_path is a model identifier on the Hub
454454 allow_patterns = original_shard_filenames
@@ -467,35 +467,37 @@ def _get_checkpoint_shard_files(
467467 "required according to the checkpoint index."
468468 )
469469
470- try :
471- # Load from URL
472- cached_folder = snapshot_download (
473- pretrained_model_name_or_path ,
474- cache_dir = cache_dir ,
475- proxies = proxies ,
476- local_files_only = local_files_only ,
477- token = token ,
478- revision = revision ,
479- allow_patterns = allow_patterns ,
480- ignore_patterns = ignore_patterns ,
481- user_agent = user_agent ,
482- )
483- if subfolder is not None :
484- cached_folder = os .path .join (cached_folder , subfolder )
470+ try :
471+ # Load from URL
472+ cached_folder = snapshot_download (
473+ pretrained_model_name_or_path ,
474+ cache_dir = cache_dir ,
475+ proxies = proxies ,
476+ local_files_only = local_files_only ,
477+ token = token ,
478+ revision = revision ,
479+ allow_patterns = allow_patterns ,
480+ ignore_patterns = ignore_patterns ,
481+ user_agent = user_agent ,
482+ )
483+ if subfolder is not None :
484+ cached_folder = os .path .join (cached_folder , subfolder )
485485
486- # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
487- # we don't have to catch them here. We have also dealt with EntryNotFoundError.
488- except HTTPError as e :
489- raise EnvironmentError (
490- f"We couldn't connect to '{ HUGGINGFACE_CO_RESOLVE_ENDPOINT } ' to load { pretrained_model_name_or_path } . You should try"
491- " again after checking your internet connection."
492- ) from e
486+ # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
487+ # we don't have to catch them here. We have also dealt with EntryNotFoundError.
488+ except HTTPError as e :
489+ raise EnvironmentError (
490+ f"We couldn't connect to '{ HUGGINGFACE_CO_RESOLVE_ENDPOINT } ' to load { pretrained_model_name_or_path } . You should try"
491+ " again after checking your internet connection."
492+ ) from e
493493
494494 # If `local_files_only=True`, `cached_folder` may not contain all the shard files.
495- if local_files_only :
495+ elif local_files_only :
496496 _check_if_shards_exist_locally (
497497 local_dir = cache_dir , subfolder = subfolder , original_shard_filenames = original_shard_filenames
498498 )
499+ if subfolder is not None :
500+ cached_folder = os .path .join (cached_folder , subfolder )
499501
500502 return cached_folder , sharded_metadata
501503
0 commit comments