Skip to content

Commit bdb9d06

Browse files
fix: use cache for snapshots even if refs does not exist (#1306)
* fix: use cache for snapshots even if refs does not exist If one uses `hf_hub_download` only referencing specific commits, the `refs` folder will not be created even though data will be cached via `snapshots` and `blobs`. Subsequent calls to `try_to_load_from_cache` were returning None even though the desired data was in the cache. Example: ```python # download something hf_hub_download(repo_id=repo, revision=commit_hash, filename=filepath, token=token) # returns None try_to_load_from_cache(repo_id=repo, revision=commit_hash, filename=filepath) ``` * FIX in case .no_exist + add tests Co-authored-by: Lucain Pouget <[email protected]>
1 parent 5a12851 commit bdb9d06

File tree

2 files changed

+66
-12
lines changed

2 files changed

+66
-12
lines changed

src/huggingface_hub/file_download.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,25 +1345,32 @@ def try_to_load_from_cache(
13451345
if not os.path.isdir(repo_cache):
13461346
# No cache for this model
13471347
return None
1348-
for subfolder in ["refs", "snapshots"]:
1349-
if not os.path.isdir(os.path.join(repo_cache, subfolder)):
1350-
return None
13511348

1352-
# Resolve refs (for instance to convert main to the associated commit sha)
1353-
cached_refs = os.listdir(os.path.join(repo_cache, "refs"))
1354-
if revision in cached_refs:
1355-
with open(os.path.join(repo_cache, "refs", revision)) as f:
1356-
revision = f.read()
1349+
refs_dir = os.path.join(repo_cache, "refs")
1350+
snapshots_dir = os.path.join(repo_cache, "snapshots")
1351+
no_exists_dir = os.path.join(repo_cache, ".no_exist")
13571352

1358-
if os.path.isfile(os.path.join(repo_cache, ".no_exist", revision, filename)):
1353+
# Resolve refs (for instance to convert main to the associated commit sha)
1354+
if os.path.isdir(refs_dir):
1355+
cached_refs = os.listdir(refs_dir)
1356+
if revision in cached_refs:
1357+
with open(os.path.join(refs_dir, revision)) as f:
1358+
revision = f.read()
1359+
1360+
# Check if the file's non-existence is cached (under ".no_exist")
1361+
if os.path.isfile(os.path.join(no_exists_dir, revision, filename)):
13591362
return _CACHED_NO_EXIST
13601363

1361-
cached_shas = os.listdir(os.path.join(repo_cache, "snapshots"))
1364+
# Check if the snapshots folder exists and contains the revision
1365+
if not os.path.exists(snapshots_dir):
1366+
return None
1367+
cached_shas = os.listdir(snapshots_dir)
13621368
if revision not in cached_shas:
13631369
# No cache for this revision and we won't try to return a random revision
13641370
return None
13651371

1366-
cached_file = os.path.join(repo_cache, "snapshots", revision, filename)
1372+
# Check if file exists in cache
1373+
cached_file = os.path.join(snapshots_dir, revision, filename)
13671374
return cached_file if os.path.isfile(cached_file) else None
13681375

13691376

tests/test_file_download.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def test_hf_hub_download_legacy(self):
348348
metadata = filename_to_url(filepath, legacy_cache_layout=True)
349349
self.assertEqual(metadata[1], f'"{DUMMY_MODEL_ID_PINNED_SHA1}"')
350350

351-
def test_try_to_load_from_cache(self):
351+
def test_try_to_load_from_cache_exist(self):
352352
# Make sure the file is cached
353353
filepath = hf_hub_download(DUMMY_MODEL_ID, filename=CONFIG_NAME)
354354

@@ -389,6 +389,53 @@ def test_try_to_load_from_cache_no_exist(self):
389389
# If file non-existence is not cached, returns None
390390
self.assertIsNone(try_to_load_from_cache(DUMMY_MODEL_ID, filename="dummy2"))
391391

392+
def test_try_to_load_from_cache_specific_commit_id_exist(self):
393+
"""Regression test for #1306.
394+
395+
See https://github.com/huggingface/huggingface_hub/pull/1306."""
396+
with SoftTemporaryDirectory() as cache_dir:
397+
# Cache file from specific commit id (no "refs/" folder)
398+
commit_id = HfApi().model_info(DUMMY_MODEL_ID).sha
399+
filepath = hf_hub_download(
400+
DUMMY_MODEL_ID,
401+
filename=CONFIG_NAME,
402+
revision=commit_id,
403+
cache_dir=cache_dir,
404+
)
405+
406+
# Must be able to retrieve it "offline"
407+
attempt = try_to_load_from_cache(
408+
DUMMY_MODEL_ID,
409+
filename=CONFIG_NAME,
410+
revision=commit_id,
411+
cache_dir=cache_dir,
412+
)
413+
self.assertEqual(filepath, attempt)
414+
415+
def test_try_to_load_from_cache_specific_commit_id_no_exist(self):
416+
"""Regression test for #1306.
417+
418+
See https://github.com/huggingface/huggingface_hub/pull/1306."""
419+
with SoftTemporaryDirectory() as cache_dir:
420+
# Cache the non-existence of a file at a specific commit id (no "refs/" folder)
421+
commit_id = HfApi().model_info(DUMMY_MODEL_ID).sha
422+
with self.assertRaises(EntryNotFoundError):
423+
hf_hub_download(
424+
DUMMY_MODEL_ID,
425+
filename="missing_file",
426+
revision=commit_id,
427+
cache_dir=cache_dir,
428+
)
429+
430+
# Must be able to retrieve it "offline"
431+
attempt = try_to_load_from_cache(
432+
DUMMY_MODEL_ID,
433+
filename="missing_file",
434+
revision=commit_id,
435+
cache_dir=cache_dir,
436+
)
437+
self.assertEqual(attempt, _CACHED_NO_EXIST)
438+
392439
def test_get_hf_file_metadata_basic(self) -> None:
393440
"""Test getting metadata from a file on the Hub."""
394441
url = hf_hub_url(

0 commit comments

Comments
 (0)