Skip to content

Commit 1c528a4

Browse files
committed
up
1 parent 04cd2dc commit 1c528a4

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

src/diffusers/utils/hub_utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
ModelCardData,
3131
create_repo,
3232
hf_hub_download,
33+
model_info,
3334
snapshot_download,
3435
upload_folder,
3536
)
@@ -402,6 +403,23 @@ def _get_checkpoint_shard_files(
402403

403404
ignore_patterns = ["*.json", "*.md"]
404405

406+
# If the repo doesn't have the required shards, error out early even before downloading anything.
407+
if not local_files_only:
408+
try:
409+
model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
410+
for shard_file in original_shard_filenames:
411+
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
412+
if not shard_file_present:
413+
raise EnvironmentError(
414+
f"{shards_path} does not appear to have a file named {shard_file} which is "
415+
"required according to the checkpoint index."
416+
)
417+
except ConnectionError as e:
418+
raise EnvironmentError(
419+
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
420+
" again after checking your internet connection."
421+
) from e
422+
405423
try:
406424
# Load from URL
407425
cached_folder = snapshot_download(

0 commit comments

Comments
 (0)