Commit 305f613

Expose more hf.co API surface, and ability to download a snapshot of a repo to a local folder (#25)
* [ci] Slim down the matrix quite a bit
* Expose more hf.co API surface:
  - Expose the `filter` param on the list of models, to only list models with a specific tag (method renamed to `list_models`; don't think it's too breaking, feedback welcome)
  - Expose a `model_info` method that gives you slightly more detailed info on a specific model, at a specific revision
* Add `snapshot_download`: download a whole snapshot of a repo's files at the specified revision
* Docs, tests, lockfiles & metadata
* In `force_filename` mode, do not even store the metadata file
* Update README.md

Co-authored-by: Julien Chaumond <[email protected]>
Co-authored-by: Lysandre <[email protected]>
Co-authored-by: Lysandre Debut <[email protected]>
1 parent 1b94216 commit 305f613
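
Taken together, the new API surface can be exercised roughly as follows. This is a minimal sketch, not code from the commit; the repo id and tag below are placeholder values.

```python
from huggingface_hub import HfApi, snapshot_download

api = HfApi()

# New `filter` param: list only the models carrying a given tag.
models = api.list_models(filter="text-classification")

# New `model_info` method: detailed info on one model, optionally at a revision.
info = api.model_info("namespace/repository", revision="main")
print(info.sha, [f.rfilename for f in info.siblings])

# New `snapshot_download`: fetch every file of the repo at that revision
# into one local folder and return its path.
local_folder = snapshot_download("namespace/repository", revision="main")
```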

File tree: 10 files changed, +258 -44 lines

.github/workflows/python-tests.yml

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.6", "3.7", "3.8", "3.9"]
+        python-version: ["3.6", "3.9"]

     steps:
       - uses: actions/checkout@v2

README.md

Lines changed: 10 additions & 0 deletions

@@ -63,6 +63,16 @@ Parameters:

 Check out the source code for all possible params (we'll create a real doc page in the future).

+### Bonus: `snapshot_download`
+
+`snapshot_download()` downloads all the files from the remote repository at the specified revision,
+stores them to disk (in a versioning-aware way) and returns the local folder path.
+
+Parameters:
+- a `repo_id` in the format `namespace/repository`
+- a `revision` on which the repository will be downloaded
+- a `cache_dir` which you can specify if you want to control where on disk the files are cached.
+
 <br>

 ## Publish models to the huggingface.co hub
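
As a quick illustration of the new README section (the repo id below is a placeholder, not a real model):

```python
from huggingface_hub import snapshot_download

# Downloads every file of the repo at the given revision and returns
# the local folder path where the snapshot was stored.
folder = snapshot_download(
    "namespace/repository",
    revision="main",
    cache_dir=None,  # defaults to the standard huggingface_hub cache
)
```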

src/huggingface_hub/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -32,3 +32,4 @@
 from .hf_api import HfApi, HfFolder
 from .hub_mixin import ModelHubMixin
 from .repository import Repository
+from .snapshot_download import snapshot_download

src/huggingface_hub/file_download.py

Lines changed: 15 additions & 7 deletions

@@ -281,6 +281,7 @@ def cached_download(
     cache_dir: Union[str, Path, None] = None,
     user_agent: Union[Dict, str, None] = None,
     force_download=False,
+    force_filename: Optional[str] = None,
     proxies=None,
     etag_timeout=10,
     resume_download=False,
@@ -360,7 +361,9 @@ def cached_download(
             # etag is None
             pass

-    filename = url_to_filename(url, etag)
+    filename = (
+        force_filename if force_filename is not None else url_to_filename(url, etag)
+    )

     # get cache path to put the file
     cache_path = os.path.join(cache_dir, filename)
@@ -378,7 +381,11 @@ def cached_download(
             )
             if not file.endswith(".json") and not file.endswith(".lock")
         ]
-        if len(matching_files) > 0 and not force_download:
+        if (
+            len(matching_files) > 0
+            and not force_download
+            and force_filename is None
+        ):
             return os.path.join(cache_dir, matching_files[-1])
         else:
             # If files cannot be found and local_files_only=True,
@@ -444,10 +451,11 @@ def _resumable_file_manager() -> "io.BufferedWriter":
     logger.info("storing %s in cache at %s", url, cache_path)
     os.replace(temp_file.name, cache_path)

-    logger.info("creating metadata file for %s", cache_path)
-    meta = {"url": url, "etag": etag}
-    meta_path = cache_path + ".json"
-    with open(meta_path, "w") as meta_file:
-        json.dump(meta, meta_file)
+    if force_filename is None:
+        logger.info("creating metadata file for %s", cache_path)
+        meta = {"url": url, "etag": etag}
+        meta_path = cache_path + ".json"
+        with open(meta_path, "w") as meta_file:
+            json.dump(meta, meta_file)

     return cache_path
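
The net effect of `force_filename`, sketched below (assuming a reachable `url`; the repo id is a placeholder): the file is stored under exactly the given name, the cached-copy short-circuit is bypassed, and no `.json` metadata sidecar is written next to it.

```python
from huggingface_hub.file_download import cached_download, hf_hub_url

url = hf_hub_url("namespace/repository", filename="config.json")

# Default behavior: content-addressed filename (derived from url + etag),
# plus a "<filename>.json" metadata file recording the url and etag.
path = cached_download(url)

# With force_filename: saved as "config.json" in the cache dir, and the
# metadata file is skipped entirely (this is what snapshot_download uses).
path = cached_download(url, force_filename="config.json")
```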

src/huggingface_hub/hf_api.py

Lines changed: 42 additions & 4 deletions

@@ -15,6 +15,7 @@


 import os
+import warnings
 from os.path import expanduser
 from typing import Dict, List, Optional, Tuple

@@ -36,7 +37,7 @@ def __init__(self, **kwargs):
             setattr(self, k, v)


-class ModelSibling:
+class ModelFile:
     """
     Data structure that represents a public file inside a model, accessible from huggingface.co
     """
@@ -55,6 +56,7 @@ class ModelInfo:
     def __init__(
         self,
         modelId: Optional[str] = None,  # id of model
+        sha: Optional[str] = None,  # commit sha at the specified revision
         tags: List[str] = [],
         pipeline_tag: Optional[str] = None,
         siblings: Optional[
@@ -63,10 +65,11 @@ def __init__(
         **kwargs
     ):
         self.modelId = modelId
+        self.sha = sha
         self.tags = tags
         self.pipeline_tag = pipeline_tag
         self.siblings = (
-            [ModelSibling(**x) for x in siblings] if siblings is not None else None
+            [ModelFile(**x) for x in siblings] if siblings is not None else None
         )
         for k, v in kwargs.items():
             setattr(self, k, v)
@@ -108,16 +111,51 @@ def logout(self, token: str) -> None:
         r = requests.post(path, headers={"authorization": "Bearer {}".format(token)})
         r.raise_for_status()

-    def model_list(self) -> List[ModelInfo]:
+    def list_models(self, filter: Optional[str] = None) -> List[ModelInfo]:
         """
         Get the public list of all the models on huggingface.co
         """
         path = "{}/api/models".format(self.endpoint)
-        r = requests.get(path)
+        params = {"filter": filter, "full": True} if filter is not None else None
+        r = requests.get(path, params=params)
         r.raise_for_status()
         d = r.json()
         return [ModelInfo(**x) for x in d]

+    def model_list(self) -> List[ModelInfo]:
+        """
+        Deprecated method name, renamed to `list_models`.
+
+        Get the public list of all the models on huggingface.co
+        """
+        warnings.warn(
+            "This method has been renamed to `list_models` for consistency and will be removed in a future version."
+        )
+        return self.list_models()
+
+    def model_info(
+        self, repo_id: str, revision: Optional[str] = None, token: Optional[str] = None
+    ) -> ModelInfo:
+        """
+        Get info on one specific model on huggingface.co
+
+        Model can be private if you pass an acceptable token.
+        """
+        path = (
+            "{}/api/models/{repo_id}".format(self.endpoint, repo_id=repo_id)
+            if revision is None
+            else "{}/api/models/{repo_id}/revision/{revision}".format(
+                self.endpoint, repo_id=repo_id, revision=revision
+            )
+        )
+        headers = (
+            {"authorization": "Bearer {}".format(token)} if token is not None else None
+        )
+        r = requests.get(path, headers=headers)
+        r.raise_for_status()
+        d = r.json()
+        return ModelInfo(**d)
+
     def list_repos_objs(
         self, token: str, organization: Optional[str] = None
     ) -> List[RepoObj]:
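
A short sketch of how the renamed and new methods behave (the repo id and token below are placeholders):

```python
from huggingface_hub import HfApi

api = HfApi()

# The old name still works, but emits a deprecation warning and
# delegates to `list_models`.
models = api.model_list()

# Equivalent, plus the new tag filter (which also requests the `full` payload).
models = api.list_models(filter="translation")

# Private models resolve too, if you pass an acceptable token.
info = api.model_info("namespace/repository", token="my-api-token")
print(info.sha)
```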
src/huggingface_hub/snapshot_download.py

Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
+import os
+from pathlib import Path
+from typing import Optional, Union
+
+from .constants import HUGGINGFACE_HUB_CACHE
+from .file_download import cached_download, hf_hub_url
+from .hf_api import HfApi
+
+
+REPO_ID_SEPARATOR = "__"
+# ^ make sure this substring is not allowed in repo_ids on hf.co
+
+
+def snapshot_download(
+    repo_id: str,
+    revision: Optional[str] = None,
+    cache_dir: Union[str, Path, None] = None,
+) -> str:
+    """
+    Downloads a whole snapshot of a repo's files at the specified revision.
+    This is useful when you want all files from a repo, because you don't know
+    which ones you will need a priori.
+    All files are nested inside a folder in order to keep their actual filename
+    relative to that folder.
+
+    An alternative would be to just clone a repo but this would require that
+    the user always has git and git-lfs installed, and properly configured.
+
+    Note: at some point maybe this format of storage should actually replace
+    the flat storage structure we've used so far (initially from allennlp
+    if I remember correctly).
+
+    Return:
+        Local folder path (string) of repo snapshot
+    """
+    if cache_dir is None:
+        cache_dir = HUGGINGFACE_HUB_CACHE
+    if isinstance(cache_dir, Path):
+        cache_dir = str(cache_dir)
+
+    _api = HfApi()
+    model_info = _api.model_info(repo_id=repo_id, revision=revision)
+
+    storage_folder = os.path.join(
+        cache_dir, repo_id.replace("/", REPO_ID_SEPARATOR) + "." + model_info.sha
+    )
+
+    for model_file in model_info.siblings:
+        url = hf_hub_url(
+            repo_id, filename=model_file.rfilename, revision=model_info.sha
+        )
+        relative_filepath = os.path.join(*model_file.rfilename.split("/"))
+
+        # Create potential nested dir
+        nested_dirname = os.path.dirname(
+            os.path.join(storage_folder, relative_filepath)
+        )
+        os.makedirs(nested_dirname, exist_ok=True)
+
+        path = cached_download(
+            url, cache_dir=storage_folder, force_filename=relative_filepath
+        )
+
+        if os.path.exists(path + ".lock"):
+            os.remove(path + ".lock")
+
+    return storage_folder
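
Design note: each snapshot folder is keyed by `repo_id` plus the resolved commit sha, so two revisions of the same repo never collide on disk; and because `cached_download` is called with `force_filename`, the folder contains exactly the repo's files (no `.json` metadata sidecars, and leftover `.lock` files are removed). A hypothetical layout, assuming a repo `namespace/repository` resolved to sha `abc123`:

```python
from huggingface_hub import snapshot_download

folder = snapshot_download("namespace/repository")
# e.g. <cache_dir>/namespace__repository.abc123/
#          config.json
#          pytorch_model.bin
#          nested/dir/somefile.json   <- repo-relative paths are preserved
```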

tests/test_file_download.py

Lines changed: 28 additions & 26 deletions

@@ -23,32 +23,24 @@
 from huggingface_hub.file_download import cached_download, filename_to_url, hf_hub_url

 from .testing_utils import (
-    DUMMY_UNKWOWN_IDENTIFIER,
+    DUMMY_MODEL_ID,
+    DUMMY_MODEL_ID_PINNED_SHA1,
+    DUMMY_MODEL_ID_PINNED_SHA256,
+    DUMMY_MODEL_ID_REVISION_INVALID,
+    DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT,
     SAMPLE_DATASET_IDENTIFIER,
     OfflineSimulationMode,
     offline,
 )


-MODEL_ID = DUMMY_UNKWOWN_IDENTIFIER
-# An actual model hosted on huggingface.co
+REVISION_ID_DEFAULT = "main"
+# Default branch name

 DATASET_ID = SAMPLE_DATASET_IDENTIFIER
 # An actual dataset hosted on huggingface.co


-REVISION_ID_DEFAULT = "main"
-# Default branch name
-REVISION_ID_ONE_SPECIFIC_COMMIT = "f2c752cfc5c0ab6f4bdec59acea69eefbee381c2"
-# One particular commit (not the top of `main`)
-REVISION_ID_INVALID = "aaaaaaa"
-# This commit does not exist, so we should 404.
-
-PINNED_SHA1 = "d9e9f15bc825e4b2c9249e9578f884bbcb5e3684"
-# Sha-1 of config.json on the top of `main`, for checking purposes
-PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d3"
-# Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes
-
 DATASET_REVISION_ID_ONE_SPECIFIC_COMMIT = "e25d55a1c4933f987c46cc75d8ffadd67f257c61"
 # One particular commit for DATASET_ID
 DATASET_SAMPLE_PY_FILE = "custom_squad.py"
@@ -62,10 +54,12 @@ def test_bogus_url(self):

     def test_no_connection(self):
         invalid_url = hf_hub_url(
-            MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID
+            DUMMY_MODEL_ID,
+            filename=CONFIG_NAME,
+            revision=DUMMY_MODEL_ID_REVISION_INVALID,
         )
         valid_url = hf_hub_url(
-            MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT
+            DUMMY_MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT
         )
         self.assertIsNotNone(cached_download(valid_url, force_download=True))
         for offline_mode in OfflineSimulationMode:
@@ -78,39 +72,47 @@ def test_no_connection(self):

     def test_file_not_found(self):
         # Valid revision (None) but missing file.
-        url = hf_hub_url(MODEL_ID, filename="missing.bin")
+        url = hf_hub_url(DUMMY_MODEL_ID, filename="missing.bin")
         with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"):
             _ = cached_download(url)

     def test_revision_not_found(self):
         # Valid file but missing revision
-        url = hf_hub_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID)
+        url = hf_hub_url(
+            DUMMY_MODEL_ID,
+            filename=CONFIG_NAME,
+            revision=DUMMY_MODEL_ID_REVISION_INVALID,
+        )
         with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"):
             _ = cached_download(url)

     def test_standard_object(self):
-        url = hf_hub_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT)
+        url = hf_hub_url(
+            DUMMY_MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT
+        )
         filepath = cached_download(url, force_download=True)
         metadata = filename_to_url(filepath)
-        self.assertEqual(metadata, (url, f'"{PINNED_SHA1}"'))
+        self.assertEqual(metadata, (url, f'"{DUMMY_MODEL_ID_PINNED_SHA1}"'))

     def test_standard_object_rev(self):
         # Same object, but different revision
         url = hf_hub_url(
-            MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_ONE_SPECIFIC_COMMIT
+            DUMMY_MODEL_ID,
+            filename=CONFIG_NAME,
+            revision=DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT,
         )
         filepath = cached_download(url, force_download=True)
         metadata = filename_to_url(filepath)
-        self.assertNotEqual(metadata[1], f'"{PINNED_SHA1}"')
+        self.assertNotEqual(metadata[1], f'"{DUMMY_MODEL_ID_PINNED_SHA1}"')
         # Caution: check that the etag is *not* equal to the one from `test_standard_object`

     def test_lfs_object(self):
         url = hf_hub_url(
-            MODEL_ID, filename=PYTORCH_WEIGHTS_NAME, revision=REVISION_ID_DEFAULT
+            DUMMY_MODEL_ID, filename=PYTORCH_WEIGHTS_NAME, revision=REVISION_ID_DEFAULT
         )
         filepath = cached_download(url, force_download=True)
         metadata = filename_to_url(filepath)
-        self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))
+        self.assertEqual(metadata, (url, f'"{DUMMY_MODEL_ID_PINNED_SHA256}"'))

     def test_dataset_standard_object_rev(self):
         url = hf_hub_url(
@@ -129,7 +131,7 @@ def test_dataset_standard_object_rev(self):
         # now let's download
         filepath = cached_download(url, force_download=True)
         metadata = filename_to_url(filepath)
-        self.assertNotEqual(metadata[1], f'"{PINNED_SHA1}"')
+        self.assertNotEqual(metadata[1], f'"{DUMMY_MODEL_ID_PINNED_SHA1}"')

     def test_dataset_lfs_object(self):
         url = hf_hub_url(
