Skip to content

Commit 0c7986e

Browse files
committed
Fix URL when uploading to proxy (#2167)
* Fix URL when uploading to proxy
* Fix multi-part LFS upload
1 parent 0dd879b commit 0c7986e

File tree

6 files changed

+52
-17
lines changed

6 files changed

+52
-17
lines changed

src/huggingface_hub/_commit_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def _upload_lfs_files(
399399
def _wrapped_lfs_upload(batch_action) -> None:
400400
try:
401401
operation = oid2addop[batch_action["oid"]]
402-
lfs_upload(operation=operation, lfs_batch_action=batch_action, headers=headers)
402+
lfs_upload(operation=operation, lfs_batch_action=batch_action, headers=headers, endpoint=endpoint)
403403
except Exception as exc:
404404
raise RuntimeError(f"Error while uploading '{operation.path_in_repo}' to the Hub.") from exc
405405

src/huggingface_hub/hf_api.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,6 @@
8282
deserialize_event,
8383
)
8484
from .constants import (
85-
_HF_DEFAULT_ENDPOINT,
86-
_HF_DEFAULT_STAGING_ENDPOINT,
8785
DEFAULT_ETAG_TIMEOUT,
8886
DEFAULT_REQUEST_TIMEOUT,
8987
DEFAULT_REVISION,
@@ -123,6 +121,7 @@
123121
build_hf_headers,
124122
experimental,
125123
filter_repo_objects,
124+
fix_hf_endpoint_in_url,
126125
get_session,
127126
hf_raise_for_status,
128127
logging,
@@ -431,14 +430,7 @@ class RepoUrl(str):
431430
"""
432431

433432
def __new__(cls, url: Any, endpoint: Optional[str] = None):
434-
# check if a proxy has been set => if yes, update the returned URL to use the proxy
435-
if ENDPOINT not in (_HF_DEFAULT_ENDPOINT, _HF_DEFAULT_STAGING_ENDPOINT):
436-
url = url.replace(_HF_DEFAULT_ENDPOINT, ENDPOINT)
437-
url = url.replace(_HF_DEFAULT_STAGING_ENDPOINT, ENDPOINT)
438-
if endpoint not in (None, _HF_DEFAULT_ENDPOINT, _HF_DEFAULT_STAGING_ENDPOINT):
439-
url = url.replace(_HF_DEFAULT_ENDPOINT, endpoint)
440-
url = url.replace(_HF_DEFAULT_STAGING_ENDPOINT, endpoint)
441-
433+
url = fix_hf_endpoint_in_url(url, endpoint=endpoint)
442434
return super(RepoUrl, cls).__new__(cls, url)
443435

444436
def __init__(self, url: Any, endpoint: Optional[str] = None) -> None:

src/huggingface_hub/lfs.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
from .utils import (
3333
build_hf_headers,
34+
fix_hf_endpoint_in_url,
3435
get_session,
3536
hf_raise_for_status,
3637
http_backoff,
@@ -193,6 +194,7 @@ def lfs_upload(
193194
lfs_batch_action: Dict,
194195
token: Optional[str] = None,
195196
headers: Optional[Dict[str, str]] = None,
197+
endpoint: Optional[str] = None,
196198
) -> None:
197199
"""
198200
Handles uploading a given object to the Hub with the LFS protocol.
@@ -230,22 +232,24 @@ def lfs_upload(
230232
# 2. Upload file (either single part or multi-part)
231233
header = upload_action.get("header", {})
232234
chunk_size = header.get("chunk_size")
235+
upload_url = fix_hf_endpoint_in_url(upload_action["href"], endpoint=endpoint)
233236
if chunk_size is not None:
234237
try:
235238
chunk_size = int(chunk_size)
236239
except (ValueError, TypeError):
237240
raise ValueError(
238241
f"Malformed response from LFS batch endpoint: `chunk_size` should be an integer. Got '{chunk_size}'."
239242
)
240-
_upload_multi_part(operation=operation, header=header, chunk_size=chunk_size, upload_url=upload_action["href"])
243+
_upload_multi_part(operation=operation, header=header, chunk_size=chunk_size, upload_url=upload_url)
241244
else:
242-
_upload_single_part(operation=operation, upload_url=upload_action["href"])
245+
_upload_single_part(operation=operation, upload_url=upload_url)
243246

244247
# 3. Verify upload went well
245248
if verify_action is not None:
246249
_validate_lfs_action(verify_action)
250+
verify_url = fix_hf_endpoint_in_url(verify_action["href"], endpoint)
247251
verify_resp = get_session().post(
248-
verify_action["href"],
252+
verify_url,
249253
headers=build_hf_headers(token=token, headers=headers),
250254
json={"oid": operation.upload_info.sha256.hex(), "size": operation.upload_info.size},
251255
)

src/huggingface_hub/utils/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,14 @@
4747
from ._git_credential import list_credential_helpers, set_git_credential, unset_git_credential
4848
from ._headers import LocalTokenNotFoundError, build_hf_headers, get_token_to_send
4949
from ._hf_folder import HfFolder
50-
from ._http import OfflineModeIsEnabled, configure_http_backend, get_session, http_backoff, reset_sessions
50+
from ._http import (
51+
OfflineModeIsEnabled,
52+
configure_http_backend,
53+
fix_hf_endpoint_in_url,
54+
get_session,
55+
http_backoff,
56+
reset_sessions,
57+
)
5158
from ._pagination import paginate
5259
from ._paths import IGNORE_GIT_FOLDER_PATTERNS, filter_repo_objects
5360
from ._runtime import (

src/huggingface_hub/utils/_http.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import uuid
2222
from functools import lru_cache
2323
from http import HTTPStatus
24-
from typing import Callable, Tuple, Type, Union
24+
from typing import Callable, Optional, Tuple, Type, Union
2525

2626
import requests
2727
from requests import Response
@@ -306,3 +306,16 @@ def http_backoff(
306306

307307
# Update sleep time for next retry
308308
sleep_time = min(max_wait_time, sleep_time * 2) # Exponential backoff
309+
310+
311+
def fix_hf_endpoint_in_url(url: str, endpoint: Optional[str]) -> str:
    """Replace the default Hub endpoint in a URL with a custom one.

    This is useful when going through a proxy or mirror and the Hugging Face Hub
    returns a URL that still points at the default endpoint.

    Args:
        url (`str`):
            The URL to fix, typically one returned by the Hub.
        endpoint (`str`, *optional*):
            The custom endpoint to use. If `None` or empty, `constants.ENDPOINT`
            is used instead. A trailing `/` is ignored.

    Returns:
        `str`: `url` with every occurrence of the default (production or staging)
        endpoint replaced by the custom endpoint. The URL is returned unchanged
        when the resolved endpoint is empty or is itself one of the defaults.
    """
    # An explicit endpoint wins over the globally configured one. Trailing slashes
    # are stripped so that "https://mirror.co/" does not produce
    # "https://mirror.co//api/..." and so that a default endpoint written with a
    # trailing slash is still recognised as a default.
    resolved = (endpoint or constants.ENDPOINT or "").rstrip("/")
    default_endpoints = (constants._HF_DEFAULT_ENDPOINT, constants._HF_DEFAULT_STAGING_ENDPOINT)

    # No proxy configured => nothing to rewrite.
    if not resolved or resolved in default_endpoints:
        return url

    # A proxy has been set => update the returned URL to use it.
    for default_endpoint in default_endpoints:
        url = url.replace(default_endpoint, resolved)
    return url

tests/test_utils_http.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,19 @@
33
import time
44
import unittest
55
from multiprocessing import Process, Queue
6-
from typing import Generator
6+
from typing import Generator, Optional
77
from unittest.mock import Mock, call, patch
88
from uuid import UUID
99

10+
import pytest
1011
import requests
1112
from requests import ConnectTimeout, HTTPError
1213

1314
from huggingface_hub.constants import ENDPOINT
1415
from huggingface_hub.utils._http import (
1516
OfflineModeIsEnabled,
1617
configure_http_backend,
18+
fix_hf_endpoint_in_url,
1719
get_session,
1820
http_backoff,
1921
reset_sessions,
@@ -292,3 +294,20 @@ def _is_uuid(string: str) -> bool:
292294
except ValueError:
293295
return False
294296
return str(uuid_obj) == string
297+
298+
299+
@pytest.mark.parametrize(
    "base_url,endpoint,expected_url",
    [
        # No custom endpoint: a staging URL is left untouched.
        (
            "https://hub-ci.huggingface.co/resolve/...",
            None,
            "https://hub-ci.huggingface.co/resolve/...",
        ),
        # No custom endpoint: a production URL is left untouched.
        (
            "https://huggingface.co/resolve/...",
            None,
            "https://huggingface.co/resolve/...",
        ),
        # Custom endpoint: the staging host is swapped for the mirror.
        (
            "https://hub-ci.huggingface.co/api/models",
            "https://mirror.co",
            "https://mirror.co/api/models",
        ),
        # Custom endpoint: the production host is swapped for the mirror.
        (
            "https://huggingface.co/api/models",
            "https://mirror.co",
            "https://mirror.co/api/models",
        ),
    ],
)
def test_fix_hf_endpoint_in_url(base_url: str, endpoint: Optional[str], expected_url: str) -> None:
    """Check that default Hub hosts are rewritten only when a custom endpoint is given."""
    fixed_url = fix_hf_endpoint_in_url(base_url, endpoint)
    assert fixed_url == expected_url

0 commit comments

Comments
 (0)