 import os
+from glob import glob
 from pathlib import Path
 from typing import Dict, Optional, Union

-from .constants import HUGGINGFACE_HUB_CACHE
+from .constants import DEFAULT_REVISION, HUGGINGFACE_HUB_CACHE
 from .file_download import cached_download, hf_hub_url
 from .hf_api import HfApi, HfFolder
+from .utils import logging


 REPO_ID_SEPARATOR = "__"
 # ^ make sure this substring is not allowed in repo_ids on hf.co


+logger = logging.get_logger(__name__)
+
+
 def snapshot_download(
     repo_id: str,
     revision: Optional[str] = None,
     cache_dir: Union[str, Path, None] = None,
     library_name: Optional[str] = None,
     library_version: Optional[str] = None,
     user_agent: Union[Dict, str, None] = None,
+    proxies=None,
+    etag_timeout=10,
+    resume_download=False,
     use_auth_token: Union[bool, str, None] = None,
+    local_files_only=False,
 ) -> str:
     """
     Downloads a whole snapshot of a repo's files at the specified revision.
@@ -39,6 +48,8 @@ def snapshot_download(
3948 """
4049 if cache_dir is None :
4150 cache_dir = HUGGINGFACE_HUB_CACHE
51+ if revision is None :
52+ revision = DEFAULT_REVISION
4253 if isinstance (cache_dir , Path ):
4354 cache_dir = str (cache_dir )
4455
@@ -53,18 +64,96 @@ def snapshot_download(
     else:
         token = None

-    _api = HfApi()
-    model_info = _api.model_info(repo_id=repo_id, revision=revision, token=token)
+    # remove all `/` occurrences to correctly convert the repo id to a directory name
+    repo_id_flattened = repo_id.replace("/", REPO_ID_SEPARATOR)
+
+    # if we have no internet connection, look for the
+    # last modified folder in the cache
+    if local_files_only:
+        # possible repos have a <path/to/cache_dir>/<flattened_repo_id> prefix
+        repo_folders_prefix = os.path.join(cache_dir, repo_id_flattened)
+
+        # list all possible folders that can correspond to the repo_id
+        # and are of the format <flattened-repo-id>.<revision>.<commit-sha>,
+        # i.e. all cached repos that could match the requested revision.
+        # There are 3 cases to consider.
+
+        # 1) cached repos of format <repo_id>.{revision}.<any-hash>
+        #    -> in this case {revision} has to be a branch
+        repo_folders_branch = glob(repo_folders_prefix + "." + revision + ".*")
+
+        # 2) cached repos of format <repo_id>.{revision}
+        #    -> in this case {revision} has to be a commit sha
+        repo_folders_commit_only = glob(repo_folders_prefix + "." + revision)
+
+        # 3) cached repos of format <repo_id>.<any-branch>.{revision}
+        #    -> in this case {revision} also has to be a commit sha
+        repo_folders_branch_commit = glob(repo_folders_prefix + ".*." + revision)
+
+        # combine all possible matching cached repo folders
+        repo_folders = (
+            repo_folders_branch + repo_folders_commit_only + repo_folders_branch_commit
+        )
+
+        if len(repo_folders) == 0:
+            raise ValueError(
+                "Cannot find the requested files in the cached path and outgoing traffic has been"
+                " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
+                " to False."
+            )

-    storage_folder = os.path.join(
-        cache_dir, repo_id.replace("/", REPO_ID_SEPARATOR) + "." + model_info.sha
-    )
+        # check whether the repo id was previously cached from a commit sha revision
+        # while the passed {revision} is not a commit sha;
+        # in this case snapshotting the repo locally might lead to unexpected
+        # behavior that the user should be warned about

-    for model_file in model_info.siblings:
-        url = hf_hub_url(
-            repo_id, filename=model_file.rfilename, revision=model_info.sha
+        # get all folders that were cached with just a sha commit revision
+        all_repo_folders_from_sha = set(glob(repo_folders_prefix + ".*")) - set(
+            glob(repo_folders_prefix + ".*.*")
         )
-        relative_filepath = os.path.join(*model_file.rfilename.split("/"))
+        # 1) is there any repo id that was previously cached from a commit sha?
+        has_a_sha_revision_been_cached = len(all_repo_folders_from_sha) > 0
+        # 2) is the passed {revision} a branch?
+        is_revision_a_branch = (
+            len(repo_folders_commit_only + repo_folders_branch_commit) == 0
+        )
+
+        if has_a_sha_revision_been_cached and is_revision_a_branch:
+            # -> in this case warn the user
+            logger.warning(
+                f"The repo {repo_id} was previously downloaded from a commit hash revision "
+                f"and has created the following cached directories {all_repo_folders_from_sha}."
+                f" In this case, trying to load a repo from the branch {revision} in offline "
+                "mode might lead to unexpected behavior by not taking into account the latest "
+                "commits."
+            )
+
+        # find the last modified folder
+        storage_folder = max(repo_folders, key=os.path.getmtime)
+
+        # get the commit sha
+        repo_id_sha = storage_folder.split(".")[-1]
+        model_files = os.listdir(storage_folder)
+    else:
+        # if we have an internet connection, retrieve the correct folder name from the Hugging Face API
+        _api = HfApi()
+        model_info = _api.model_info(repo_id=repo_id, revision=revision, token=token)
+
+        storage_folder = os.path.join(cache_dir, repo_id_flattened + "." + revision)
+
+        # if the passed revision is not identical to the commit sha,
+        # then the revision has to be a branch name, e.g. "main";
+        # in this case make sure that the branch name is included in the
+        # cached storage folder name
+        if revision != model_info.sha:
+            storage_folder += f".{model_info.sha}"
+
+        repo_id_sha = model_info.sha
+        model_files = [f.rfilename for f in model_info.siblings]
+
+    for model_file in model_files:
+        url = hf_hub_url(repo_id, filename=model_file, revision=repo_id_sha)
+        relative_filepath = os.path.join(*model_file.split("/"))

         # Create potential nested dir
         nested_dirname = os.path.dirname(
@@ -79,7 +168,11 @@ def snapshot_download(
             library_name=library_name,
             library_version=library_version,
             user_agent=user_agent,
+            proxies=proxies,
+            etag_timeout=etag_timeout,
+            resume_download=resume_download,
             use_auth_token=use_auth_token,
+            local_files_only=local_files_only,
         )

         if os.path.exists(path + ".lock"):
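A minimal usage sketch (not part of the diff above) showing how the new parameters introduced here might be exercised; the repo id, cache path, and values are hypothetical placeholders.

# Usage sketch only: repo id and cache path are hypothetical, not from the diff.
from huggingface_hub import snapshot_download

# Online call: resolves the revision via the Hub API and caches the files
# under <cache_dir>/<flattened-repo-id>.<revision>[.<commit-sha>].
folder = snapshot_download(
    "some-user/some-model",     # hypothetical repo id
    revision="main",
    cache_dir="/tmp/hf-cache",  # hypothetical cache location
    resume_download=True,       # new: resume partially downloaded files
    etag_timeout=10,            # new: timeout (seconds) for metadata requests
)

# Offline call: skips the Hub API and reuses the most recently modified
# cache folder on disk that matches the flattened repo id and revision.
offline_folder = snapshot_download(
    "some-user/some-model",
    revision="main",
    cache_dir="/tmp/hf-cache",
    local_files_only=True,      # new: never hit the network
)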