 import os
 import sys
 import tempfile
+import time
 from contextlib import contextmanager
 from functools import partial
 from hashlib import sha256
@@ -16,6 +17,7 @@
 
 import requests
 from filelock import FileLock
+from huggingface_hub import constants
 
 from . import __version__
 from .constants import (
@@ -171,20 +173,89 @@ def http_user_agent(
     return ua
 
 
+class OfflineModeIsEnabled(ConnectionError):
+    pass
+
+
+def _raise_if_offline_mode_is_enabled(msg: Optional[str] = None):
+    """Raise an OfflineModeIsEnabled error (subclass of ConnectionError) if HF_HUB_OFFLINE is True."""
+    if constants.HF_HUB_OFFLINE:
+        raise OfflineModeIsEnabled(
+            "Offline mode is enabled."
+            if msg is None
+            else "Offline mode is enabled. " + str(msg)
+        )
+
+
+def _request_with_retry(
+    method: str,
+    url: str,
+    max_retries: int = 0,
+    base_wait_time: float = 0.5,
+    max_wait_time: float = 2,
+    timeout: float = 10.0,
+    **params,
+) -> requests.Response:
+    """Wrapper around requests to retry in case it fails with a ConnectTimeout, with exponential backoff.
+
+    Note that if the environment variable HF_HUB_OFFLINE is set to 1, an OfflineModeIsEnabled error is raised.
+
+    Args:
+        method (str): HTTP method, such as 'GET' or 'HEAD'.
+        url (str): The URL of the resource to fetch.
+        max_retries (int): Maximum number of retries, defaults to 0 (no retries).
+        base_wait_time (float): Duration (in seconds) to wait before retrying the first time. Wait time between
+            retries then grows exponentially, capped by max_wait_time.
+        max_wait_time (float): Maximum amount of time between two retries, in seconds.
+        timeout (float): Timeout (in seconds) to pass to `requests.request`.
+        **params: Additional params to pass to `requests.request`.
+    """
+    _raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
+    tries, success = 0, False
+    while not success:
+        tries += 1
+        try:
+            response = requests.request(
+                method=method.upper(), url=url, timeout=timeout, **params
+            )
+            success = True
+        except requests.exceptions.ConnectTimeout as err:
+            # Only connect timeouts are retried; other errors propagate immediately.
+            if tries > max_retries:
+                raise err
+            else:
+                logger.info(
+                    f"{method} request to {url} timed out, retrying... [{tries/max_retries}]"
+                )
+                sleep_time = min(
+                    max_wait_time, base_wait_time * 2 ** (tries - 1)
+                )  # Exponential backoff, capped at max_wait_time
+                time.sleep(sleep_time)
+    return response
+
+
 def http_get(
     url: str,
     temp_file: BinaryIO,
     proxies=None,
     resume_size=0,
     headers: Optional[Dict[str, str]] = None,
+    timeout=10.0,
+    max_retries=0,
 ):
     """
     Download remote file. Do not gobble up errors.
     """
     headers = copy.deepcopy(headers)
     if resume_size > 0:
         headers["Range"] = "bytes=%d-" % (resume_size,)
-    r = requests.get(url, stream=True, proxies=proxies, headers=headers)
+    r = _request_with_retry(
+        method="GET",
+        url=url,
+        stream=True,
+        proxies=proxies,
+        headers=headers,
+        timeout=timeout,
+        max_retries=max_retries,
+    )
     r.raise_for_status()
     content_length = r.headers.get("Content-Length")
     total = resume_size + int(content_length) if content_length is not None else None
@@ -254,8 +325,9 @@ def cached_download(
     etag = None
     if not local_files_only:
         try:
-            r = requests.head(
-                url,
+            r = _request_with_retry(
+                method="HEAD",
+                url=url,
                 headers=headers,
                 allow_redirects=False,
                 proxies=proxies,
@@ -276,15 +348,14 @@ def cached_download(
             # between the HEAD and the GET (unlikely, but hey).
             if 300 <= r.status_code <= 399:
                 url_to_download = r.headers["Location"]
+        except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
+            # Actually raise for those subclasses of ConnectionError
+            raise
         except (
             requests.exceptions.ConnectionError,
             requests.exceptions.Timeout,
-        ) as exc:
-            # Actually raise for those subclasses of ConnectionError:
-            if isinstance(exc, requests.exceptions.SSLError) or isinstance(
-                exc, requests.exceptions.ProxyError
-            ):
-                raise exc
+            OfflineModeIsEnabled,
+        ):
             # Otherwise, our Internet connection is down.
             # etag is None
             pass
@@ -297,7 +368,7 @@ def cached_download(
     # etag is None == we don't have a connection or we passed local_files_only.
     # try to get the last downloaded one
     if etag is None:
-        if os.path.exists(cache_path):
+        if os.path.exists(cache_path) and not force_download:
             return cache_path
         else:
             matching_files = [
@@ -307,7 +378,7 @@ def cached_download(
                 )
                 if not file.endswith(".json") and not file.endswith(".lock")
             ]
-            if len(matching_files) > 0:
+            if len(matching_files) > 0 and not force_download:
                 return os.path.join(cache_dir, matching_files[-1])
             else:
                 # If files cannot be found and local_files_only=True,
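
For reference, below is a minimal standalone sketch (not part of the commit above) of the retry-with-exponential-backoff pattern that `_request_with_retry` introduces. The helper name `fetch_with_retry` and the example URL are made up for illustration, and only `requests` and `time` are assumed; with the defaults shown, the waits between attempts are 0.5 s, 1.0 s, then 2.0 s, capped by `max_wait_time`.

import time
import requests


def fetch_with_retry(url, max_retries=3, base_wait_time=0.5, max_wait_time=2.0, timeout=10.0):
    """Retry a GET on ConnectTimeout, doubling the wait each attempt, capped at max_wait_time."""
    tries = 0
    while True:
        tries += 1
        try:
            return requests.request("GET", url, timeout=timeout)
        except requests.exceptions.ConnectTimeout:
            if tries > max_retries:
                raise
            # With the defaults above: 0.5 s, then 1.0 s, then 2.0 s (capped).
            time.sleep(min(max_wait_time, base_wait_time * 2 ** (tries - 1)))


# response = fetch_with_retry("https://huggingface.co")  # hypothetical usage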