Skip to content

Commit 9edcb30

Browse files
authored
πŸ‘¨β€πŸ’» Configure HF Hub URL with environment variable (#815)
* Make ENDPOINT configurable via an environment variable * ⬆ Needs requests>=2.27 for JSONDecodeError See https://docs.python-requests.org/en/latest/community/updates/#id2 * πŸ”§ Add private param to repo_type_and_id_from_hf_id Only for testing purposes * πŸ’„ Code quality * 🩹 Parentheses are important * πŸ‘Œ Suggested implementation * πŸ”₯ We don't need this anymore, do we ? * Rename hf_api to client
1 parent b287fba commit 9edcb30

File tree

6 files changed

+35
-30
lines changed

6 files changed

+35
-30
lines changed

β€Žsetup.pyβ€Ž

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def get_version() -> str:
1313

1414
install_requires = [
1515
"filelock",
16-
"requests",
16+
"requests>=2.27",
1717
"tqdm",
1818
"pyyaml",
1919
"typing-extensions>=3.7.4.3", # to be able to import TypeAlias
@@ -27,11 +27,7 @@ def get_version() -> str:
2727
"torch",
2828
]
2929

30-
extras["tensorflow"] = [
31-
"tensorflow",
32-
"pydot",
33-
"graphviz"
34-
]
30+
extras["tensorflow"] = ["tensorflow", "pydot", "graphviz"]
3531

3632
extras["testing"] = [
3733
"pytest",

β€Žsrc/huggingface_hub/constants.pyβ€Ž

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,10 @@
2424
os.environ.get("HUGGINGFACE_CO_STAGING", "NO").upper() in ENV_VARS_TRUE_VALUES
2525
)
2626

27-
ENDPOINT = (
27+
ENDPOINT = os.getenv("HF_HUB_URL") or (
2828
"https://moon-staging.huggingface.co" if _staging_mode else "https://huggingface.co"
2929
)
3030

31-
3231
HUGGINGFACE_CO_URL_TEMPLATE = ENDPOINT + "/{repo_id}/resolve/{revision}/{filename}"
3332

3433
REPO_TYPE_DATASET = "dataset"

β€Žsrc/huggingface_hub/hf_api.pyβ€Ž

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import os
16+
import re
1617
import subprocess
1718
import sys
1819
import warnings
@@ -80,7 +81,7 @@ def _validate_repo_id_deprecation(repo_id, name, organization):
8081
return name, organization
8182

8283

83-
def repo_type_and_id_from_hf_id(hf_id: str):
84+
def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None):
8485
"""
8586
Returns the repo type and ID from a huggingface.co URL linking to a
8687
repository
@@ -94,16 +95,19 @@ def repo_type_and_id_from_hf_id(hf_id: str):
9495
- <repo_type>/<namespace>/<repo_id>
9596
- <namespace>/<repo_id>
9697
- <repo_id>
98+
hub_url (`str`, *optional*):
99+
The URL of the HuggingFace Hub, defaults to https://huggingface.co
97100
"""
98-
is_hf_url = "huggingface.co" in hf_id and "@" not in hf_id
101+
hub_url = re.sub(r"https?://", "", hub_url if hub_url is not None else ENDPOINT)
102+
is_hf_url = hub_url in hf_id and "@" not in hf_id
99103
url_segments = hf_id.split("/")
100104
is_hf_id = len(url_segments) <= 3
101105

102106
if is_hf_url:
103107
namespace, repo_id = url_segments[-2:]
104-
if namespace == "huggingface.co":
108+
if namespace == hub_url:
105109
namespace = None
106-
if len(url_segments) > 2 and "huggingface.co" not in url_segments[-3]:
110+
if len(url_segments) > 2 and hub_url not in url_segments[-3]:
107111
repo_type = url_segments[-3]
108112
else:
109113
repo_type = None

β€Žsrc/huggingface_hub/repository.pyβ€Ž

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88
from contextlib import contextmanager
99
from pathlib import Path
1010
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
11+
from urllib.parse import urlparse
1112

1213
from tqdm.auto import tqdm
1314

1415
from huggingface_hub.constants import REPO_TYPES_URL_PREFIXES, REPOCARD_NAME
1516
from huggingface_hub.repocard import metadata_load, metadata_save
1617

17-
from .hf_api import ENDPOINT, HfApi, HfFolder, repo_type_and_id_from_hf_id
18+
from .hf_api import HfApi, HfFolder, repo_type_and_id_from_hf_id
1819
from .lfs import LFS_MULTIPART_UPLOAD_COMMAND
1920
from .utils import logging
2021

@@ -441,6 +442,7 @@ def __init__(
441442
revision: Optional[str] = None,
442443
private: bool = False,
443444
skip_lfs_files: bool = False,
445+
client: Optional[HfApi] = None,
444446
):
445447
"""
446448
Instantiate a local clone of a git repo.
@@ -482,6 +484,9 @@ def __init__(
482484
whether the repository is private or not.
483485
skip_lfs_files (`bool`, *optional*, defaults to `False`):
484486
whether to skip git-LFS files or not.
487+
client (`HfApi`, *optional*):
488+
Instance of HfApi to use when calling the HF Hub API.
489+
A new instance will be created if this is left to `None`.
485490
"""
486491

487492
os.makedirs(local_dir, exist_ok=True)
@@ -490,6 +495,7 @@ def __init__(
490495
self.command_queue = []
491496
self.private = private
492497
self.skip_lfs_files = skip_lfs_files
498+
self.client = client if client is not None else HfApi()
493499

494500
self.check_git_versions()
495501

@@ -513,7 +519,7 @@ def __init__(
513519
if self.huggingface_token is not None and (
514520
git_email is None or git_user is None
515521
):
516-
user = HfApi().whoami(self.huggingface_token)
522+
user = self.client.whoami(self.huggingface_token)
517523

518524
if git_email is None:
519525
git_email = user["email"]
@@ -631,34 +637,36 @@ def clone_from(self, repo_url: str, use_auth_token: Union[bool, str, None] = Non
631637
"Couldn't load Hugging Face Authorization Token. Credentials are required to work with private repositories."
632638
" Please login in using `huggingface-cli login` or provide your token manually with the `use_auth_token` key."
633639
)
634-
api = HfApi()
635-
636-
if "huggingface.co" in repo_url or (
640+
hub_url = self.client.endpoint
641+
if hub_url in repo_url or (
637642
"http" not in repo_url and len(repo_url.split("/")) <= 2
638643
):
639-
repo_type, namespace, repo_id = repo_type_and_id_from_hf_id(repo_url)
644+
repo_type, namespace, repo_id = repo_type_and_id_from_hf_id(
645+
repo_url, hub_url=hub_url
646+
)
640647

641648
if repo_type is not None:
642649
self.repo_type = repo_type
643650

644-
repo_url = ENDPOINT + "/"
651+
repo_url = hub_url + "/"
645652

646653
if self.repo_type in REPO_TYPES_URL_PREFIXES:
647654
repo_url += REPO_TYPES_URL_PREFIXES[self.repo_type]
648655

649656
if token is not None:
650-
whoami_info = api.whoami(token)
657+
whoami_info = self.client.whoami(token)
651658
user = whoami_info["name"]
652659
valid_organisations = [org["name"] for org in whoami_info["orgs"]]
653660

654661
if namespace is not None:
655662
repo_id = f"{namespace}/{repo_id}"
656663
repo_url += repo_id
657664

658-
repo_url = repo_url.replace("https://", f"https://user:{token}@")
665+
scheme = urlparse(repo_url).scheme
666+
repo_url = repo_url.replace(f"{scheme}://", f"{scheme}://user:{token}@")
659667

660668
if namespace == user or namespace in valid_organisations:
661-
api.create_repo(
669+
self.client.create_repo(
662670
repo_id=repo_id,
663671
token=token,
664672
repo_type=self.repo_type,
@@ -671,7 +679,7 @@ def clone_from(self, repo_url: str, use_auth_token: Union[bool, str, None] = Non
671679
repo_url += repo_id
672680

673681
# For error messages, it's cleaner to show the repo url without the token.
674-
clean_repo_url = re.sub(r"https://.*@", "https://", repo_url)
682+
clean_repo_url = re.sub(r"(https?)://.*@", r"\1://", repo_url)
675683
try:
676684
subprocess.run(
677685
"git lfs install".split(),

β€Žtests/test_hf_api.pyβ€Ž

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1277,4 +1277,7 @@ def test_repo_type_and_id_from_hf_id(self):
12771277
}
12781278

12791279
for key, value in possible_values.items():
1280-
self.assertEqual(repo_type_and_id_from_hf_id(key), tuple(value))
1280+
self.assertEqual(
1281+
repo_type_and_id_from_hf_id(key, hub_url="https://huggingface.co"),
1282+
tuple(value),
1283+
)

β€Žtests/testing_utils.pyβ€Ž

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,12 +169,7 @@ def with_production_testing(func):
169169
ENDPOINT_PRODUCTION,
170170
)
171171

172-
repository = patch(
173-
"huggingface_hub.repository.ENDPOINT",
174-
ENDPOINT_PRODUCTION,
175-
)
176-
177-
return repository(hf_api(file_download(func)))
172+
return hf_api(file_download(func))
178173

179174

180175
def retry_endpoint(function, number_of_tries: int = 3, wait_time: int = 5):

0 commit comments

Comments
Β (0)