Skip to content

Commit cc91e1d

Browse files
HF Hub download (#137)
* HF Hub download * Update src/huggingface_hub/file_download.py Co-authored-by: Julien Chaumond <[email protected]> Co-authored-by: Julien Chaumond <[email protected]>
1 parent 11d9ab6 commit cc91e1d

File tree

3 files changed

+78
-2
lines changed

3 files changed

+78
-2
lines changed

src/huggingface_hub/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
TF2_WEIGHTS_NAME,
3030
TF_WEIGHTS_NAME,
3131
)
32-
from .file_download import cached_download, hf_hub_url
32+
from .file_download import cached_download, hf_hub_download, hf_hub_url
3333
from .hf_api import HfApi, HfFolder
3434
from .hub_mixin import ModelHubMixin
3535
from .repository import Repository

src/huggingface_hub/file_download.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,3 +468,64 @@ def _resumable_file_manager() -> "io.BufferedWriter":
468468
json.dump(meta, meta_file)
469469

470470
return cache_path
471+
472+
473+
def hf_hub_download(
    repo_id: str,
    filename: str,
    subfolder: Optional[str] = None,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    cache_dir: Union[str, Path, None] = None,
    user_agent: Union[Dict, str, None] = None,
    force_download=False,
    force_filename: Optional[str] = None,
    proxies=None,
    etag_timeout=10,
    resume_download=False,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
):
    """
    Download a file from a huggingface.co repo, going through the local cache.

    This is a convenience wrapper: the repo identifier, file name and optional
    subfolder / repo type / revision are turned into a URL with `hf_hub_url`,
    and that URL is then handed to `cached_download` together with every
    remaining option, unchanged.

    Background on the caching scheme:

    - Large files (more than a few MBs) are distributed through Cloudfront, a
      Content Delivery Network (CDN) replicated over the globe, which makes
      downloads much faster for the end user.
    - Cloudfront caches aggressively (default TTL is 24 hours). This is not a
      problem because huggingface.co uses git-based versioning and stores
      files on S3/Cloudfront in a content-addressable way (the file name is
      its hash), so a cached entry can never be stale.
    - Client-side caching in this library is keyed on the object's ETag: the
      git-sha1 for files stored in git, the sha256 for files stored in git-lfs.

    The local cache is consulted first; the file is downloaded only when it is
    not already there.

    Args:
        repo_id, filename, subfolder, repo_type, revision:
            Forwarded to `hf_hub_url` to build the URL of the file.
        library_name, library_version, cache_dir, user_agent, force_download,
        force_filename, proxies, etag_timeout, resume_download, use_auth_token,
        local_files_only:
            Forwarded as-is to `cached_download`; see that function for their
            exact meaning.

    Return:
        Local path (string) of the cached file or, if networking is off, of
        the last version of the file cached on disk.

    Raises:
        In case of non-recoverable file (non-existent or inaccessible url + no
        cache on disk).
    """
    # Everything that is not needed to build the URL is a plain pass-through
    # to `cached_download`.
    download_kwargs = {
        "library_name": library_name,
        "library_version": library_version,
        "cache_dir": cache_dir,
        "user_agent": user_agent,
        "force_download": force_download,
        "force_filename": force_filename,
        "proxies": proxies,
        "etag_timeout": etag_timeout,
        "resume_download": resume_download,
        "use_auth_token": use_auth_token,
        "local_files_only": local_files_only,
    }

    resolved_url = hf_hub_url(
        repo_id,
        filename,
        subfolder=subfolder,
        repo_type=repo_type,
        revision=revision,
    )
    return cached_download(resolved_url, **download_kwargs)

tests/test_file_download.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@
2020
PYTORCH_WEIGHTS_NAME,
2121
REPO_TYPE_DATASET,
2222
)
23-
from huggingface_hub.file_download import cached_download, filename_to_url, hf_hub_url
23+
from huggingface_hub.file_download import (
24+
cached_download,
25+
filename_to_url,
26+
hf_hub_download,
27+
hf_hub_url,
28+
)
2429

2530
from .testing_utils import (
2631
DUMMY_MODEL_ID,
@@ -146,3 +151,13 @@ def test_dataset_lfs_object(self):
146151
metadata,
147152
(url, '"95aa6a52d5d6a735563366753ca50492a658031da74f301ac5238b03966972c9"'),
148153
)
154+
155+
def test_hf_hub_download(self):
    """Check that `hf_hub_download` caches the dummy model config under its pinned ETag."""
    # Force a fresh download so the check does not rely on a pre-existing cache entry.
    cached_path = hf_hub_download(
        DUMMY_MODEL_ID,
        filename=CONFIG_NAME,
        revision=REVISION_ID_DEFAULT,
        force_download=True,
    )
    # `filename_to_url` returns the cache metadata; index 1 holds the quoted ETag.
    cached_etag = filename_to_url(cached_path)[1]
    expected_etag = f'"{DUMMY_MODEL_ID_PINNED_SHA1}"'
    self.assertEqual(cached_etag, expected_etag)

0 commit comments

Comments
 (0)