Skip to content

Commit 04cd2dc

Browse files
committed
reviewer feedback.
1 parent b7af511 commit 04cd2dc

File tree

1 file changed

+9
-19
lines changed

1 file changed

+9
-19
lines changed

src/diffusers/utils/hub_utils.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
)
4545
from packaging import version
4646
from requests import HTTPError
47-
from requests.exceptions import ConnectionError
4847

4948
from .. import __version__
5049
from .constants import (
@@ -402,24 +401,6 @@ def _get_checkpoint_shard_files(
402401
allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]
403402

404403
ignore_patterns = ["*.json", "*.md"]
405-
try:
406-
temp_dir = snapshot_download(
407-
repo_id=pretrained_model_name_or_path, cache_dir=cache_dir, local_files_only=local_files_only
408-
)
409-
except ConnectionError as e:
410-
raise EnvironmentError(
411-
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
412-
" again after checking your internet connection."
413-
) from e
414-
415-
model_files_info = _get_filepaths_for_folder(temp_dir)
416-
for shard_file in original_shard_filenames:
417-
shard_file_present = any(shard_file in k for k in model_files_info)
418-
if not shard_file_present:
419-
raise EnvironmentError(
420-
f"{shards_path} does not appear to have a file named {shard_file} which is "
421-
"required according to the checkpoint index."
422-
)
423404

424405
try:
425406
# Load from URL
@@ -437,6 +418,15 @@ def _get_checkpoint_shard_files(
437418
if subfolder is not None:
438419
cached_folder = os.path.join(cached_folder, subfolder)
439420

421+
model_files_info = _get_filepaths_for_folder(cached_folder)
422+
for shard_file in original_shard_filenames:
423+
shard_file_present = any(shard_file in k for k in model_files_info)
424+
if not shard_file_present:
425+
raise EnvironmentError(
426+
f"{shards_path} does not appear to have a file named {shard_file} which is "
427+
"required according to the checkpoint index."
428+
)
429+
440430
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
441431
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
442432
except HTTPError as e:

0 commit comments

Comments
 (0)