Skip to content

Commit c40a7bb

Browse files
committed
Fix snapshot download when local_dir is provided. (#2592)
* Fix snapshot download when local_dir is provided * Fix tests docstring * Add comment * Fixes post-review
1 parent f7dc3fd commit c40a7bb

File tree

2 files changed

+41
-7
lines changed

2 files changed

+41
-7
lines changed

src/huggingface_hub/_snapshot_download.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,7 @@
1010
from .errors import GatedRepoError, LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
1111
from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name
1212
from .hf_api import DatasetInfo, HfApi, ModelInfo, SpaceInfo
13-
from .utils import (
14-
OfflineModeIsEnabled,
15-
filter_repo_objects,
16-
logging,
17-
validate_hf_hub_args,
18-
)
13+
from .utils import OfflineModeIsEnabled, filter_repo_objects, logging, validate_hf_hub_args
1914
from .utils import tqdm as hf_tqdm
2015

2116

@@ -191,6 +186,7 @@ def snapshot_download(
191186
# => let's look if we can find the appropriate folder in the cache:
192187
# - if the specified revision is a commit hash, look inside "snapshots".
193188
# - f the specified revision is a branch or tag, look inside "refs".
189+
# => if local_dir is not None, we will return the path to the local folder if it exists.
194190
if repo_info is None:
195191
# Try to get which commit hash corresponds to the specified revision
196192
commit_hash = None
@@ -210,7 +206,14 @@ def snapshot_download(
210206
# Snapshot folder exists => let's return it
211207
# (but we can't check if all the files are actually there)
212208
return snapshot_folder
213-
209+
# If local_dir is not None, return it if it exists and is not empty
210+
if local_dir is not None:
211+
local_dir = Path(local_dir)
212+
if local_dir.is_dir() and any(local_dir.iterdir()):
213+
logger.warning(
214+
f"Returning existing local_dir `{local_dir}` as remote repo cannot be accessed in `snapshot_download` ({api_call_error})."
215+
)
216+
return str(local_dir.resolve())
214217
# If we couldn't find the appropriate folder on disk, raise an error.
215218
if local_files_only:
216219
raise LocalEntryNotFoundError(

tests/test_snapshot_download.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,37 @@ def test_download_model_local_only(self):
143143
)
144144
self.assertTrue(self.first_commit_hash in storage_folder) # has expected revision
145145

146+
# Test with local_dir
147+
with SoftTemporaryDirectory() as tmpdir:
148+
# first download folder to local_dir
149+
snapshot_download(self.repo_id, local_dir=tmpdir)
150+
# now load from local_dir
151+
storage_folder = snapshot_download(self.repo_id, local_dir=tmpdir, local_files_only=True)
152+
self.assertEquals(str(tmpdir), storage_folder)
153+
154+
def test_download_model_to_local_dir_with_offline_mode(self):
155+
"""Test that an already downloaded folder is returned when there is a connection error"""
156+
# first download folder to local_dir
157+
with SoftTemporaryDirectory() as tmpdir:
158+
snapshot_download(self.repo_id, local_dir=tmpdir)
159+
# Check that the folder is returned when there is a connection error
160+
for offline_mode in OfflineSimulationMode:
161+
with offline(mode=offline_mode):
162+
storage_folder = snapshot_download(self.repo_id, local_dir=tmpdir)
163+
self.assertEquals(str(tmpdir), storage_folder)
164+
165+
def test_download_model_offline_mode_not_in_local_dir(self):
166+
"""Test when connection error but local_dir is empty."""
167+
with SoftTemporaryDirectory() as tmpdir:
168+
with self.assertRaises(LocalEntryNotFoundError):
169+
snapshot_download(self.repo_id, local_dir=tmpdir, local_files_only=True)
170+
171+
for offline_mode in OfflineSimulationMode:
172+
with offline(mode=offline_mode):
173+
with SoftTemporaryDirectory() as tmpdir:
174+
with self.assertRaises(LocalEntryNotFoundError):
175+
snapshot_download(self.repo_id, local_dir=tmpdir)
176+
146177
def test_download_model_offline_mode_not_cached(self):
147178
"""Test when connection error but cache is empty."""
148179
with SoftTemporaryDirectory() as tmpdir:

0 commit comments

Comments
 (0)