1818from filelock import FileLock
1919
2020from . import __version__
21+ from .constants import (
22+ HUGGINGFACE_CO_URL_TEMPLATE ,
23+ HUGGINGFACE_HUB_CACHE ,
24+ REPO_TYPE_DATASET ,
25+ REPO_TYPE_DATASET_URL_PREFIX ,
26+ REPO_TYPES ,
27+ )
2128from .hf_api import HfFolder
2229
2330
@@ -55,34 +62,11 @@ def is_tf_available():
5562 return _tf_available
5663
5764
58- # Constants for file downloads
59-
60- PYTORCH_WEIGHTS_NAME = "pytorch_model.bin"
61- TF2_WEIGHTS_NAME = "tf_model.h5"
62- TF_WEIGHTS_NAME = "model.ckpt"
63- FLAX_WEIGHTS_NAME = "flax_model.msgpack"
64- CONFIG_NAME = "config.json"
65-
66- HUGGINGFACE_CO_URL_TEMPLATE = (
67- "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
68- )
69-
70-
71- # default cache
72- hf_cache_home = os .path .expanduser (
73- os .getenv (
74- "HF_HOME" , os .path .join (os .getenv ("XDG_CACHE_HOME" , "~/.cache" ), "huggingface" )
75- )
76- )
77- default_cache_path = os .path .join (hf_cache_home , "hub" )
78-
79- HUGGINGFACE_HUB_CACHE = os .getenv ("HUGGINGFACE_HUB_CACHE" , default_cache_path )
80-
81-
8265def hf_hub_url (
83- model_id : str ,
66+ repo_id : str ,
8467 filename : str ,
8568 subfolder : Optional [str ] = None ,
69+ repo_type : Optional [str ] = None ,
8670 revision : Optional [str ] = None ,
8771) -> str :
8872 """
@@ -103,10 +87,16 @@ def hf_hub_url(
10387 if subfolder is not None :
10488 filename = f"{ subfolder } /{ filename } "
10589
90+ if repo_type not in REPO_TYPES :
91+ raise ValueError ("Invalid repo type" )
92+
93+ if repo_type == REPO_TYPE_DATASET :
94+ repo_id = REPO_TYPE_DATASET_URL_PREFIX + repo_id
95+
10696 if revision is None :
10797 revision = "main"
10898 return HUGGINGFACE_CO_URL_TEMPLATE .format (
109- model_id = model_id , revision = revision , filename = filename
99+ repo_id = repo_id , revision = revision , filename = filename
110100 )
111101
112102
@@ -286,8 +276,17 @@ def cached_download(
286276 # between the HEAD and the GET (unlikely, but hey).
287277 if 300 <= r .status_code <= 399 :
288278 url_to_download = r .headers ["Location" ]
289- except (requests .exceptions .ConnectionError , requests .exceptions .Timeout ):
290- # etag is already None
279+ except (
280+ requests .exceptions .ConnectionError ,
281+ requests .exceptions .Timeout ,
282+ ) as exc :
283+ # Actually raise for those subclasses of ConnectionError:
284+ if isinstance (exc , requests .exceptions .SSLError ) or isinstance (
285+ exc , requests .exceptions .ProxyError
286+ ):
287+ raise exc
288+ # Otherwise, our Internet connection is down.
289+ # etag is None
291290 pass
292291
293292 filename = url_to_filename (url , etag )
0 commit comments