Skip to content

Commit 2bfc9dd

Browse files
authored
Support for git-based dataset repos (#14)
* Export more constants * unrelated change= do not gobble certain kinds of requests.ConnectionError * datasets: file_download cc @lhoestq * Exposing this as a constant * Implement hf_api for datasets + Split constants into their own file * integration test + CLI
1 parent 407c838 commit 2bfc9dd

File tree

9 files changed

+181
-44
lines changed

9 files changed

+181
-44
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Integration inside a library is super simple. We expose two functions, `hf_hub_u
4242
### `hf_hub_url`
4343

4444
`hf_hub_url()` takes:
45-
- a model id (like `julien-c/EsperBERTo-small` i.e. a user or organization name and a repo name, separated by `/`),
45+
- a repo id (e.g. a model id like `julien-c/EsperBERTo-small` i.e. a user or organization name and a repo name, separated by `/`),
4646
- a filename (like `pytorch_model.bin`),
4747
- and an optional git revision id (can be a branch name, a tag, or a commit hash)
4848

src/huggingface_hub/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,15 @@
1818

1919
__version__ = "0.0.1"
2020

21-
from .file_download import HUGGINGFACE_CO_URL_TEMPLATE, cached_download, hf_hub_url
21+
from .constants import (
22+
CONFIG_NAME,
23+
FLAX_WEIGHTS_NAME,
24+
HUGGINGFACE_CO_URL_HOME,
25+
HUGGINGFACE_CO_URL_TEMPLATE,
26+
PYTORCH_WEIGHTS_NAME,
27+
REPO_TYPE_DATASET,
28+
TF2_WEIGHTS_NAME,
29+
TF_WEIGHTS_NAME,
30+
)
31+
from .file_download import cached_download, hf_hub_url
2232
from .hf_api import HfApi, HfFolder

src/huggingface_hub/commands/user.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@
1818
from typing import List, Union
1919

2020
from huggingface_hub.commands import BaseHuggingfaceCLICommand
21+
from huggingface_hub.constants import (
22+
REPO_TYPE_DATASET,
23+
REPO_TYPE_DATASET_URL_PREFIX,
24+
REPO_TYPES,
25+
)
2126
from huggingface_hub.hf_api import HfApi, HfFolder
2227
from requests.exceptions import HTTPError
2328

@@ -57,7 +62,12 @@ def register_subcommand(parser: ArgumentParser):
5762
repo_create_parser.add_argument(
5863
"name",
5964
type=str,
60-
help="Name for your model's repo. Will be namespaced under your username to build the model id.",
65+
help="Name for your repo. Will be namespaced under your username to build the repo id.",
66+
)
67+
repo_create_parser.add_argument(
68+
"--type",
69+
type=str,
70+
help='Optional: repo_type: set to "dataset" if creating a dataset, default is model.',
6171
)
6272
repo_create_parser.add_argument(
6373
"--organization", type=str, help="Optional: organization namespace."
@@ -223,11 +233,16 @@ def run(self):
223233
self.args.organization if self.args.organization is not None else user
224234
)
225235

226-
print(
227-
"You are about to create {}".format(
228-
ANSI.bold(namespace + "/" + self.args.name)
229-
)
230-
)
236+
repo_id = f"{namespace}/{self.args.name}"
237+
238+
if self.args.type not in REPO_TYPES:
239+
print("Invalid repo --type")
240+
exit(1)
241+
242+
if self.args.type == REPO_TYPE_DATASET:
243+
repo_id = REPO_TYPE_DATASET_URL_PREFIX + repo_id
244+
245+
print("You are about to create {}".format(ANSI.bold(repo_id)))
231246

232247
if not self.args.yes:
233248
choice = input("Proceed? [Y/n] ").lower()
@@ -236,7 +251,10 @@ def run(self):
236251
exit()
237252
try:
238253
url = self._api.create_repo(
239-
token, name=self.args.name, organization=self.args.organization
254+
token,
255+
name=self.args.name,
256+
organization=self.args.organization,
257+
repo_type=self.args.type,
240258
)
241259
except HTTPError as e:
242260
print(e)

src/huggingface_hub/constants.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import os
2+
3+
4+
# Constants for file downloads
5+
6+
PYTORCH_WEIGHTS_NAME = "pytorch_model.bin"
7+
TF2_WEIGHTS_NAME = "tf_model.h5"
8+
TF_WEIGHTS_NAME = "model.ckpt"
9+
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
10+
CONFIG_NAME = "config.json"
11+
12+
HUGGINGFACE_CO_URL_HOME = "https://huggingface.co/"
13+
14+
HUGGINGFACE_CO_URL_TEMPLATE = (
15+
"https://huggingface.co/{repo_id}/resolve/{revision}/{filename}"
16+
)
17+
18+
REPO_TYPE_DATASET = "dataset"
19+
REPO_TYPES = [None, REPO_TYPE_DATASET]
20+
21+
REPO_TYPE_DATASET_URL_PREFIX = "datasets/"
22+
23+
24+
# default cache
25+
hf_cache_home = os.path.expanduser(
26+
os.getenv(
27+
"HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")
28+
)
29+
)
30+
default_cache_path = os.path.join(hf_cache_home, "hub")
31+
32+
HUGGINGFACE_HUB_CACHE = os.getenv("HUGGINGFACE_HUB_CACHE", default_cache_path)

src/huggingface_hub/file_download.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@
1818
from filelock import FileLock
1919

2020
from . import __version__
21+
from .constants import (
22+
HUGGINGFACE_CO_URL_TEMPLATE,
23+
HUGGINGFACE_HUB_CACHE,
24+
REPO_TYPE_DATASET,
25+
REPO_TYPE_DATASET_URL_PREFIX,
26+
REPO_TYPES,
27+
)
2128
from .hf_api import HfFolder
2229

2330

@@ -55,34 +62,11 @@ def is_tf_available():
5562
return _tf_available
5663

5764

58-
# Constants for file downloads
59-
60-
PYTORCH_WEIGHTS_NAME = "pytorch_model.bin"
61-
TF2_WEIGHTS_NAME = "tf_model.h5"
62-
TF_WEIGHTS_NAME = "model.ckpt"
63-
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
64-
CONFIG_NAME = "config.json"
65-
66-
HUGGINGFACE_CO_URL_TEMPLATE = (
67-
"https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
68-
)
69-
70-
71-
# default cache
72-
hf_cache_home = os.path.expanduser(
73-
os.getenv(
74-
"HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")
75-
)
76-
)
77-
default_cache_path = os.path.join(hf_cache_home, "hub")
78-
79-
HUGGINGFACE_HUB_CACHE = os.getenv("HUGGINGFACE_HUB_CACHE", default_cache_path)
80-
81-
8265
def hf_hub_url(
83-
model_id: str,
66+
repo_id: str,
8467
filename: str,
8568
subfolder: Optional[str] = None,
69+
repo_type: Optional[str] = None,
8670
revision: Optional[str] = None,
8771
) -> str:
8872
"""
@@ -103,10 +87,16 @@ def hf_hub_url(
10387
if subfolder is not None:
10488
filename = f"{subfolder}/{filename}"
10589

90+
if repo_type not in REPO_TYPES:
91+
raise ValueError("Invalid repo type")
92+
93+
if repo_type == REPO_TYPE_DATASET:
94+
repo_id = REPO_TYPE_DATASET_URL_PREFIX + repo_id
95+
10696
if revision is None:
10797
revision = "main"
10898
return HUGGINGFACE_CO_URL_TEMPLATE.format(
109-
model_id=model_id, revision=revision, filename=filename
99+
repo_id=repo_id, revision=revision, filename=filename
110100
)
111101

112102

@@ -286,8 +276,17 @@ def cached_download(
286276
# between the HEAD and the GET (unlikely, but hey).
287277
if 300 <= r.status_code <= 399:
288278
url_to_download = r.headers["Location"]
289-
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
290-
# etag is already None
279+
except (
280+
requests.exceptions.ConnectionError,
281+
requests.exceptions.Timeout,
282+
) as exc:
283+
# Actually raise for those subclasses of ConnectionError:
284+
if isinstance(exc, requests.exceptions.SSLError) or isinstance(
285+
exc, requests.exceptions.ProxyError
286+
):
287+
raise exc
288+
# Otherwise, our Internet connection is down.
289+
# etag is None
291290
pass
292291

293292
filename = url_to_filename(url, etag)

src/huggingface_hub/hf_api.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
import requests
2222

23+
from .constants import REPO_TYPES
24+
2325

2426
ENDPOINT = "https://huggingface.co"
2527

@@ -139,6 +141,7 @@ def create_repo(
139141
name: str,
140142
organization: Optional[str] = None,
141143
private: Optional[bool] = None,
144+
repo_type: Optional[str] = None,
142145
exist_ok=False,
143146
lfsmultipartthresh: Optional[int] = None,
144147
) -> str:
@@ -150,12 +153,20 @@ def create_repo(
150153
Params:
151154
private: Whether the model repo should be private (requires a paid huggingface.co account)
152155
156+
repo_type: Set to "dataset" if creating a dataset, default is model
157+
153158
exist_ok: Do not raise an error if repo already exists
154159
155160
lfsmultipartthresh: Optional: internal param for testing purposes.
156161
"""
157162
path = "{}/api/repos/create".format(self.endpoint)
163+
164+
if repo_type not in REPO_TYPES:
165+
raise ValueError("Invalid repo type")
166+
158167
json = {"name": name, "organization": organization, "private": private}
168+
if repo_type is not None:
169+
json["type"] = repo_type
159170
if lfsmultipartthresh is not None:
160171
json["lfsmultipartthresh"] = lfsmultipartthresh
161172
r = requests.post(
@@ -169,7 +180,13 @@ def create_repo(
169180
d = r.json()
170181
return d["url"]
171182

172-
def delete_repo(self, token: str, name: str, organization: Optional[str] = None):
183+
def delete_repo(
184+
self,
185+
token: str,
186+
name: str,
187+
organization: Optional[str] = None,
188+
repo_type: Optional[str] = None,
189+
):
173190
"""
174191
HuggingFace git-based system, used for models.
175192
@@ -178,10 +195,18 @@ def delete_repo(self, token: str, name: str, organization: Optional[str] = None)
178195
CAUTION(this is irreversible).
179196
"""
180197
path = "{}/api/repos/delete".format(self.endpoint)
198+
199+
if repo_type not in REPO_TYPES:
200+
raise ValueError("Invalid repo type")
201+
202+
json = {"name": name, "organization": organization}
203+
if repo_type is not None:
204+
json["type"] = repo_type
205+
181206
r = requests.delete(
182207
path,
183208
headers={"authorization": "Bearer {}".format(token)},
184-
json={"name": name, "organization": organization},
209+
json=json,
185210
)
186211
r.raise_for_status()
187212

tests/test_file_download.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,23 @@
1515
import unittest
1616

1717
import requests
18-
from huggingface_hub.file_download import (
18+
from huggingface_hub.constants import (
1919
CONFIG_NAME,
2020
PYTORCH_WEIGHTS_NAME,
21-
cached_download,
22-
filename_to_url,
23-
hf_hub_url,
21+
REPO_TYPE_DATASET,
2422
)
23+
from huggingface_hub.file_download import cached_download, filename_to_url, hf_hub_url
2524

26-
from .testing_utils import DUMMY_UNKWOWN_IDENTIFIER
25+
from .testing_utils import DUMMY_UNKWOWN_IDENTIFIER, SAMPLE_DATASET_IDENTIFIER
2726

2827

2928
MODEL_ID = DUMMY_UNKWOWN_IDENTIFIER
3029
# An actual model hosted on huggingface.co
3130

31+
DATASET_ID = SAMPLE_DATASET_IDENTIFIER
32+
# An actual dataset hosted on huggingface.co
33+
34+
3235
REVISION_ID_DEFAULT = "main"
3336
# Default branch name
3437
REVISION_ID_ONE_SPECIFIC_COMMIT = "f2c752cfc5c0ab6f4bdec59acea69eefbee381c2"
@@ -41,6 +44,10 @@
4144
PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d3"
4245
# Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes
4346

47+
DATASET_REVISION_ID_ONE_SPECIFIC_COMMIT = "e25d55a1c4933f987c46cc75d8ffadd67f257c61"
48+
# One particular commit for DATASET_ID
49+
DATASET_SAMPLE_PY_FILE = "custom_squad.py"
50+
4451

4552
class CachedDownloadTests(unittest.TestCase):
4653
def test_bogus_url(self):
@@ -86,3 +93,36 @@ def test_lfs_object(self):
8693
filepath = cached_download(url, force_download=True)
8794
metadata = filename_to_url(filepath)
8895
self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))
96+
97+
def test_dataset_standard_object_rev(self):
98+
url = hf_hub_url(
99+
DATASET_ID,
100+
filename=DATASET_SAMPLE_PY_FILE,
101+
repo_type=REPO_TYPE_DATASET,
102+
revision=DATASET_REVISION_ID_ONE_SPECIFIC_COMMIT,
103+
)
104+
# We can also just get the same url by prefixing "datasets" to repo_id:
105+
url2 = hf_hub_url(
106+
repo_id=f"datasets/{DATASET_ID}",
107+
filename=DATASET_SAMPLE_PY_FILE,
108+
revision=DATASET_REVISION_ID_ONE_SPECIFIC_COMMIT,
109+
)
110+
self.assertEqual(url, url2)
111+
# now let's download
112+
filepath = cached_download(url, force_download=True)
113+
metadata = filename_to_url(filepath)
114+
self.assertNotEqual(metadata[1], f'"{PINNED_SHA1}"')
115+
116+
def test_dataset_lfs_object(self):
117+
url = hf_hub_url(
118+
DATASET_ID,
119+
filename="dev-v1.1.json",
120+
repo_type=REPO_TYPE_DATASET,
121+
revision=DATASET_REVISION_ID_ONE_SPECIFIC_COMMIT,
122+
)
123+
filepath = cached_download(url, force_download=True)
124+
metadata = filename_to_url(filepath)
125+
self.assertEqual(
126+
metadata,
127+
(url, '"95aa6a52d5d6a735563366753ca50492a658031da74f301ac5238b03966972c9"'),
128+
)

tests/test_hf_api.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import time
2020
import unittest
2121

22+
from huggingface_hub.constants import REPO_TYPE_DATASET
2223
from huggingface_hub.hf_api import HfApi, HfFolder, ModelInfo, RepoObj
2324
from requests.exceptions import HTTPError
2425

@@ -33,6 +34,7 @@
3334

3435
REPO_NAME = "my-model-{}".format(int(time.time() * 10e3))
3536
REPO_NAME_LARGE_FILE = "my-model-largefiles-{}".format(int(time.time() * 10e3))
37+
DATASET_REPO_NAME = "my-dataset-{}".format(int(time.time() * 10e3))
3638
WORKING_REPO_DIR = os.path.join(
3739
os.path.dirname(os.path.abspath(__file__)), "fixtures/working_repo"
3840
)
@@ -78,6 +80,14 @@ def test_create_and_delete_repo(self):
7880
self._api.create_repo(token=self._token, name=REPO_NAME)
7981
self._api.delete_repo(token=self._token, name=REPO_NAME)
8082

83+
def test_create_and_delete_dataset_repo(self):
84+
self._api.create_repo(
85+
token=self._token, name=REPO_NAME, repo_type=REPO_TYPE_DATASET
86+
)
87+
self._api.delete_repo(
88+
token=self._token, name=REPO_NAME, repo_type=REPO_TYPE_DATASET
89+
)
90+
8191

8292
class HfApiPublicTest(unittest.TestCase):
8393
def test_staging_model_list(self):

tests/testing_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"
99
# Example model ids
1010

11+
SAMPLE_DATASET_IDENTIFIER = "lhoestq/custom_squad"
12+
# Example dataset ids
13+
1114

1215
def parse_flag_from_env(key, default=False):
1316
try:

0 commit comments

Comments
 (0)