Skip to content

Commit be902d8

Browse files
lhoestqjulien-c
andauthored
Add HF_HUB_OFFLINE env var (#22)
* offline mode simulator * add HF_HUB_OFFLINE env var * add tests * doc tweak * Re-align from transformers and @aaugustin cc @lhoestq Co-authored-by: Julien Chaumond <[email protected]>
1 parent 5b1c1e6 commit be902d8

File tree

5 files changed

+221
-15
lines changed

5 files changed

+221
-15
lines changed

src/huggingface_hub/constants.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import os
22

33

4+
# Possible values for env variables
5+
6+
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
7+
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
8+
49
# Constants for file downloads
510

611
PYTORCH_WEIGHTS_NAME = "pytorch_model.bin"
@@ -30,3 +35,9 @@
3035
default_cache_path = os.path.join(hf_cache_home, "hub")
3136

3237
HUGGINGFACE_HUB_CACHE = os.getenv("HUGGINGFACE_HUB_CACHE", default_cache_path)
38+
39+
# Offline mode is on only when the HF_HUB_OFFLINE environment variable holds one of the
# recognised truthy values (compared case-insensitively). The "AUTO" fallback used when
# the variable is unset, and any other value, leave offline mode off.
HF_HUB_OFFLINE = (
    os.environ.get("HF_HUB_OFFLINE", "AUTO").upper() in ENV_VARS_TRUE_VALUES
)

src/huggingface_hub/file_download.py

Lines changed: 82 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import os
77
import sys
88
import tempfile
9+
import time
910
from contextlib import contextmanager
1011
from functools import partial
1112
from hashlib import sha256
@@ -16,6 +17,7 @@
1617

1718
import requests
1819
from filelock import FileLock
20+
from huggingface_hub import constants
1921

2022
from . import __version__
2123
from .constants import (
@@ -171,20 +173,89 @@ def http_user_agent(
171173
return ua
172174

173175

176+
class OfflineModeIsEnabled(ConnectionError):
    """Error raised when a network call is attempted while offline mode is enabled."""
178+
179+
180+
def _raise_if_offline_mode_is_enabled(msg: Optional[str] = None):
    """Raise `OfflineModeIsEnabled` when `constants.HF_HUB_OFFLINE` is truthy.

    Args:
        msg: Optional detail appended to the base "Offline mode is enabled." message.

    Raises:
        OfflineModeIsEnabled: A `ConnectionError` subclass, raised only when offline
            mode is on. Nothing happens otherwise.
    """
    if not constants.HF_HUB_OFFLINE:
        return

    if msg is None:
        error_message = "Offline mode is enabled."
    else:
        error_message = "Offline mode is enabled. " + str(msg)
    raise OfflineModeIsEnabled(error_message)
188+
189+
190+
def _request_with_retry(
    method: str,
    url: str,
    max_retries: int = 0,
    base_wait_time: float = 0.5,
    max_wait_time: float = 2,
    timeout: float = 10.0,
    **params,
) -> requests.Response:
    """Wrapper around requests to retry in case it fails with a ConnectTimeout, with exponential backoff.

    Only `requests.exceptions.ConnectTimeout` triggers a retry; any other exception
    propagates immediately.

    Note that if the environment variable HF_HUB_OFFLINE is set to 1, then a OfflineModeIsEnabled error is raised
    before any request is attempted.

    Args:
        method (str): HTTP method, such as 'GET' or 'HEAD'
        url (str): The URL of the resource to fetch
        max_retries (int): Maximum number of retries, defaults to 0 (no retries)
        base_wait_time (float): Duration (in seconds) to wait before retrying the first time. Wait time between
            retries then grows exponentially, capped by max_wait_time.
        max_wait_time (float): Maximum amount of time between two retries, in seconds
        timeout (float): Timeout (in seconds) forwarded to `requests.request`
        **params: Params to pass to `requests.request`

    Returns:
        The `requests.Response` of the first attempt that did not time out.

    Raises:
        OfflineModeIsEnabled: If offline mode is enabled.
        requests.exceptions.ConnectTimeout: If the connection still times out once
            `max_retries` retries have been used up.
    """
    _raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
    tries = 0
    while True:
        tries += 1
        try:
            return requests.request(
                method=method.upper(), url=url, timeout=timeout, **params
            )
        except requests.exceptions.ConnectTimeout:
            if tries > max_retries:
                # Retries exhausted: re-raise with the original traceback.
                raise
            logger.info(
                f"{method} request to {url} timed out, retrying... [{tries}/{max_retries}]"
            )
            # Exponential backoff: base, 2 * base, 4 * base, ... never above max_wait_time.
            sleep_time = min(max_wait_time, base_wait_time * 2 ** (tries - 1))
            time.sleep(sleep_time)
233+
234+
174235
def http_get(
175236
url: str,
176237
temp_file: BinaryIO,
177238
proxies=None,
178239
resume_size=0,
179240
headers: Optional[Dict[str, str]] = None,
241+
timeout=10.0,
242+
max_retries=0,
180243
):
181244
"""
182245
Download remote file. Do not gobble up errors.
183246
"""
184247
headers = copy.deepcopy(headers)
185248
if resume_size > 0:
186249
headers["Range"] = "bytes=%d-" % (resume_size,)
187-
r = requests.get(url, stream=True, proxies=proxies, headers=headers)
250+
r = _request_with_retry(
251+
method="GET",
252+
url=url,
253+
stream=True,
254+
proxies=proxies,
255+
headers=headers,
256+
timeout=timeout,
257+
max_retries=max_retries,
258+
)
188259
r.raise_for_status()
189260
content_length = r.headers.get("Content-Length")
190261
total = resume_size + int(content_length) if content_length is not None else None
@@ -254,8 +325,9 @@ def cached_download(
254325
etag = None
255326
if not local_files_only:
256327
try:
257-
r = requests.head(
258-
url,
328+
r = _request_with_retry(
329+
method="HEAD",
330+
url=url,
259331
headers=headers,
260332
allow_redirects=False,
261333
proxies=proxies,
@@ -276,15 +348,14 @@ def cached_download(
276348
# between the HEAD and the GET (unlikely, but hey).
277349
if 300 <= r.status_code <= 399:
278350
url_to_download = r.headers["Location"]
351+
except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
352+
# Actually raise for those subclasses of ConnectionError
353+
raise
279354
except (
280355
requests.exceptions.ConnectionError,
281356
requests.exceptions.Timeout,
282-
) as exc:
283-
# Actually raise for those subclasses of ConnectionError:
284-
if isinstance(exc, requests.exceptions.SSLError) or isinstance(
285-
exc, requests.exceptions.ProxyError
286-
):
287-
raise exc
357+
OfflineModeIsEnabled,
358+
):
288359
# Otherwise, our Internet connection is down.
289360
# etag is None
290361
pass
@@ -297,7 +368,7 @@ def cached_download(
297368
# etag is None == we don't have a connection or we passed local_files_only.
298369
# try to get the last downloaded one
299370
if etag is None:
300-
if os.path.exists(cache_path):
371+
if os.path.exists(cache_path) and not force_download:
301372
return cache_path
302373
else:
303374
matching_files = [
@@ -307,7 +378,7 @@ def cached_download(
307378
)
308379
if not file.endswith(".json") and not file.endswith(".lock")
309380
]
310-
if len(matching_files) > 0:
381+
if len(matching_files) > 0 and not force_download:
311382
return os.path.join(cache_dir, matching_files[-1])
312383
else:
313384
# If files cannot be found and local_files_only=True,

tests/test_file_download.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@
2222
)
2323
from huggingface_hub.file_download import cached_download, filename_to_url, hf_hub_url
2424

25-
from .testing_utils import DUMMY_UNKWOWN_IDENTIFIER, SAMPLE_DATASET_IDENTIFIER
25+
from .testing_utils import (
26+
DUMMY_UNKWOWN_IDENTIFIER,
27+
SAMPLE_DATASET_IDENTIFIER,
28+
OfflineSimulationMode,
29+
offline,
30+
)
2631

2732

2833
MODEL_ID = DUMMY_UNKWOWN_IDENTIFIER
@@ -51,13 +56,26 @@
5156

5257
class CachedDownloadTests(unittest.TestCase):
5358
def test_bogus_url(self):
    """`cached_download` on a bogus URL must raise a ValueError mentioning "Connection error"."""
    bogus_url = "https://bogus"
    with self.assertRaisesRegex(ValueError, "Connection error"):
        cached_download(bogus_url)
6062

63+
def test_no_connection(self):
    """Check `cached_download` under every simulated offline mode.

    A forced download of the valid URL is performed first, outside any simulation.
    Then, for each offline mode, the invalid-revision URL and a forced re-download must
    both raise a "Connection error" ValueError, while a plain call on the valid URL
    must still return a non-None result.
    """
    url_valid_revision = hf_hub_url(
        MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT
    )
    url_invalid_revision = hf_hub_url(
        MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID
    )

    self.assertIsNotNone(cached_download(url_valid_revision, force_download=True))

    for simulation_mode in OfflineSimulationMode:
        with offline(mode=simulation_mode):
            with self.assertRaisesRegex(ValueError, "Connection error"):
                cached_download(url_invalid_revision)
            with self.assertRaisesRegex(ValueError, "Connection error"):
                cached_download(url_valid_revision, force_download=True)
            self.assertIsNotNone(cached_download(url_valid_revision))
78+
6179
def test_file_not_found(self):
6280
# Valid revision (None) but missing file.
6381
url = hf_hub_url(MODEL_ID, filename="missing.bin")

tests/test_offline_utils.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from io import BytesIO
2+
3+
import pytest
4+
5+
import requests
6+
from huggingface_hub.file_download import http_get
7+
8+
from .testing_utils import (
9+
OfflineSimulationMode,
10+
RequestWouldHangIndefinitelyError,
11+
offline,
12+
)
13+
14+
15+
def test_offline_with_timeout():
    """In CONNECTION_TIMES_OUT mode, calls without a timeout are rejected and the others time out."""
    hub_url = "https://huggingface.co"
    with offline(OfflineSimulationMode.CONNECTION_TIMES_OUT):
        # No timeout given: the simulator refuses the call instead of hanging.
        with pytest.raises(RequestWouldHangIndefinitelyError):
            requests.request("GET", hub_url)
        # Explicit timeout: the call is expected to time out.
        with pytest.raises(requests.exceptions.ConnectTimeout):
            requests.request("GET", hub_url, timeout=1.0)
        # http_get is called without an explicit timeout and must time out as well.
        with pytest.raises(requests.exceptions.ConnectTimeout):
            http_get(hub_url, BytesIO())
23+
24+
25+
def test_offline_with_connection_error():
    """In CONNECTION_FAILS mode, both `requests.request` and `http_get` raise a requests ConnectionError."""
    hub_url = "https://huggingface.co"
    expected_error = requests.exceptions.ConnectionError
    with offline(OfflineSimulationMode.CONNECTION_FAILS):
        with pytest.raises(expected_error):
            requests.request("GET", hub_url)
        with pytest.raises(expected_error):
            http_get(hub_url, BytesIO())
31+
32+
33+
def test_offline_with_datasets_offline_mode_enabled():
    """In HF_HUB_OFFLINE_SET_TO_1 mode, `http_get` raises a (builtin) ConnectionError."""
    hub_url = "https://huggingface.co"
    with offline(OfflineSimulationMode.HF_HUB_OFFLINE_SET_TO_1), pytest.raises(
        ConnectionError
    ):
        http_get(hub_url, BytesIO())

tests/testing_utils.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import os
22
import unittest
3+
from contextlib import contextmanager
34
from distutils.util import strtobool
5+
from enum import Enum
6+
from unittest.mock import patch
47

58

69
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
@@ -55,3 +58,70 @@ def require_git_lfs(test_case):
5558
return unittest.skip("test of git lfs workflow")(test_case)
5659
else:
5760
return test_case
61+
62+
63+
class RequestWouldHangIndefinitelyError(Exception):
    """Raised by the timeout simulation when a request is made without any timeout set."""
65+
66+
67+
class OfflineSimulationMode(Enum):
    """Ways in which the `offline` test helper can simulate the absence of a network."""

    # Network calls fail with a connection error.
    CONNECTION_FAILS = 0
    # Requests are made to time out.
    CONNECTION_TIMES_OUT = 1
    # The library's HF_HUB_OFFLINE flag is switched on.
    HF_HUB_OFFLINE_SET_TO_1 = 2
71+
72+
73+
@contextmanager
def offline(mode=OfflineSimulationMode.CONNECTION_FAILS, timeout=1e-16):
    """
    Simulate offline mode for the duration of the `with` block.

    There are three offline simulation modes:

    CONNECTION_FAILS (default mode): `socket.socket` is patched with a function that
        raises `socket.error`.
    CONNECTION_TIMES_OUT: `requests.request` is patched so that each call is sent to an
        invalid address, with the caller's timeout replaced by `timeout`.
        The default timeout value is low (1e-16) to speed up the tests.
        A call made without any timeout raises RequestWouldHangIndefinitelyError.
    HF_HUB_OFFLINE_SET_TO_1: `huggingface_hub.constants.HF_HUB_OFFLINE` is patched to True,
        as if the HF_HUB_OFFLINE environment variable had been set to 1.

    Raises:
        ValueError: if `mode` is none of the three modes above.
    """
    import socket

    from requests import request as online_request

    def timeout_request(method, url, **kwargs):
        # A request without a timeout would never return: refuse it outright.
        if kwargs.get("timeout") is None:
            raise RequestWouldHangIndefinitelyError(
                f"Tried a call to {url} in offline mode with no timeout set. Please set a timeout."
            )
        # Send the request to an invalid url instead, so that the connection hangs.
        unreachable_url = "https://10.255.255.1"
        kwargs["timeout"] = timeout
        try:
            return online_request(method, unreachable_url, **kwargs)
        except Exception as e:
            # Cosmetic only: make the error mention the url that was asked for
            # rather than the invalid address that was actually contacted.
            e.request.url = url
            inner_error = e.args[0]
            prettified = inner_error.args[0].replace(
                "10.255.255.1", f"OfflineMock[{url}]"
            )
            inner_error.args = (prettified,)
            e.args = (inner_error,)
            raise

    def offline_socket(*args, **kwargs):
        raise socket.error("Offline mode is enabled.")

    # Pick the patcher matching the requested mode, then apply it around the caller's block.
    if mode is OfflineSimulationMode.CONNECTION_FAILS:
        # inspired from https://stackoverflow.com/a/18601897
        simulation = patch("socket.socket", offline_socket)
    elif mode is OfflineSimulationMode.CONNECTION_TIMES_OUT:
        # inspired from https://stackoverflow.com/a/904609
        simulation = patch("requests.request", timeout_request)
    elif mode is OfflineSimulationMode.HF_HUB_OFFLINE_SET_TO_1:
        simulation = patch("huggingface_hub.constants.HF_HUB_OFFLINE", True)
    else:
        raise ValueError("Please use a value from the OfflineSimulationMode enum.")

    with simulation:
        yield

0 commit comments

Comments
 (0)