diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py
index 4fc063796a..3e980f31a2 100644
--- a/src/huggingface_hub/file_download.py
+++ b/src/huggingface_hub/file_download.py
@@ -83,6 +83,19 @@
 # Regex to check if the file etag IS a valid sha256
 REGEX_SHA256 = re.compile(r"^[0-9a-f]{64}$")
 
+# Extra domains that `_request_wrapper` may follow absolute redirects to, in addition to
+# the domain of the initial request. Comma-separated; entries are trimmed and lowercased,
+# and a listed domain also covers its subdomains.
+# Example: HF_DOWNLOAD_REDIRECT_ALLOWLIST=opendns.com
+REDIRECT_ALLOWLIST = [
+    domain
+    for domain in (
+        entry.strip().lstrip(".").lower()
+        for entry in os.environ.get("HF_DOWNLOAD_REDIRECT_ALLOWLIST", "").split(",")
+    )
+    if domain
+]
+
 _are_symlinks_supported_in_dir: Dict[str, bool] = {}
 
 
@@ -262,7 +275,7 @@ def hf_hub_url(
 
 
 def _request_wrapper(
-    method: HTTP_METHOD_T, url: str, *, follow_relative_redirects: bool = False, **params
+    method: HTTP_METHOD_T, url: str, *, follow_relative_redirects: bool = False, base_domain: str = "", **params
 ) -> requests.Response:
     """Wrapper around requests methods to follow relative redirects if `follow_relative_redirects=True` even when
     `allow_redirection=False`.
@@ -283,6 +296,11 @@ def _request_wrapper(
     """
     # Recursively follow relative redirects
     if follow_relative_redirects:
+        # Remember the domain of the very first request so that a later absolute redirect
+        # back to it is still followed. Recursive calls pass it along via `base_domain`.
+        if not base_domain:
+            base_domain = urlparse(url).netloc.lower()
+
         response = _request_wrapper(
             method=method,
             url=url,
@@ -294,15 +312,32 @@ def _request_wrapper(
         # This is useful in case of a renamed repository.
         if 300 <= response.status_code <= 399:
             parsed_target = urlparse(response.headers["Location"])
-            if parsed_target.netloc == "":
+            target_netloc = parsed_target.netloc.lower()
+            # Match on the hostname (no port / userinfo) and on a label boundary, so that allowlisting
+            # "opendns.com" accepts "cdn.opendns.com" but rejects "evil-opendns.com".
+            target_host = parsed_target.hostname or ""
+            is_allowlisted = any(
+                target_host == domain or target_host.endswith("." + domain) for domain in REDIRECT_ALLOWLIST
+            )
+            if target_netloc == "" or target_netloc == base_domain or is_allowlisted:
                 # This means it is a relative 'location' headers, as allowed by RFC 7231.
                 # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
                 # We want to follow this relative redirect !
                 #
                 # Highly inspired by `resolve_redirects` from requests library.
                 # See https://github.com/psf/requests/blob/main/requests/sessions.py#L159
-                next_url = urlparse(url)._replace(path=parsed_target.path).geturl()
-                return _request_wrapper(method=method, url=next_url, follow_relative_redirects=True, **params)
+                #
+                # Absolute redirects are followed as well when they target the initial domain or an
+                # allowlisted one.
+                # NOTE(review): `**params` is forwarded unchanged; if it carries auth headers they reach
+                # the allowlisted domain too. Confirm this is intended.
+                if target_netloc == "":
+                    next_url = urlparse(url)._replace(path=parsed_target.path, query=parsed_target.query).geturl()
+                else:
+                    next_url = parsed_target.geturl()
+                return _request_wrapper(
+                    method=method, url=next_url, follow_relative_redirects=True, base_domain=base_domain, **params
+                )
         return response
 
     # Perform request and return if status_code is not in the retry list.