Skip to content

Commit 99bae57

Browse files
[Snapshot download] Allow to load local repo id with snapshot download (#505)
* [Snapshotdownload] Add more parameter names * up * fix style * add revision to cache * update tests * make style * improve snapshot download * add tests * delete dummy folders * Remove tree * finish cette mer** Co-authored-by: Lysandre Debut <[email protected]>
1 parent 8a8c89f commit 99bae57

File tree

4 files changed

+260
-11
lines changed

4 files changed

+260
-11
lines changed

src/huggingface_hub/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
CONFIG_NAME = "config.json"
1616
REPOCARD_NAME = "README.md"
1717

18+
DEFAULT_REVISION = "main"
19+
1820
HUGGINGFACE_CO_URL_HOME = "https://huggingface.co/"
1921

2022
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}

src/huggingface_hub/file_download.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from . import __version__
2323
from .constants import (
24+
DEFAULT_REVISION,
2425
HUGGINGFACE_CO_URL_TEMPLATE,
2526
HUGGINGFACE_HUB_CACHE,
2627
REPO_TYPES,
@@ -110,7 +111,7 @@ def hf_hub_url(
110111
repo_id = REPO_TYPES_URL_PREFIXES[repo_type] + repo_id
111112

112113
if revision is None:
113-
revision = "main"
114+
revision = DEFAULT_REVISION
114115
return HUGGINGFACE_CO_URL_TEMPLATE.format(
115116
repo_id=repo_id, revision=revision, filename=filename
116117
)

src/huggingface_hub/snapshot_download.py

Lines changed: 103 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,33 @@
11
import os
2+
from glob import glob
23
from pathlib import Path
34
from typing import Dict, Optional, Union
45

5-
from .constants import HUGGINGFACE_HUB_CACHE
6+
from .constants import DEFAULT_REVISION, HUGGINGFACE_HUB_CACHE
67
from .file_download import cached_download, hf_hub_url
78
from .hf_api import HfApi, HfFolder
9+
from .utils import logging
810

911

1012
REPO_ID_SEPARATOR = "__"
1113
# ^ make sure this substring is not allowed in repo_ids on hf.co
1214

1315

16+
logger = logging.get_logger(__name__)
17+
18+
1419
def snapshot_download(
1520
repo_id: str,
1621
revision: Optional[str] = None,
1722
cache_dir: Union[str, Path, None] = None,
1823
library_name: Optional[str] = None,
1924
library_version: Optional[str] = None,
2025
user_agent: Union[Dict, str, None] = None,
26+
proxies=None,
27+
etag_timeout=10,
28+
resume_download=False,
2129
use_auth_token: Union[bool, str, None] = None,
30+
local_files_only=False,
2231
) -> str:
2332
"""
2433
Downloads a whole snapshot of a repo's files at the specified revision.
@@ -39,6 +48,8 @@ def snapshot_download(
3948
"""
4049
if cache_dir is None:
4150
cache_dir = HUGGINGFACE_HUB_CACHE
51+
if revision is None:
52+
revision = DEFAULT_REVISION
4253
if isinstance(cache_dir, Path):
4354
cache_dir = str(cache_dir)
4455

@@ -53,18 +64,96 @@ def snapshot_download(
5364
else:
5465
token = None
5566

56-
_api = HfApi()
57-
model_info = _api.model_info(repo_id=repo_id, revision=revision, token=token)
67+
# remove all `/` occurrences to correctly convert repo to directory name
68+
repo_id_flattened = repo_id.replace("/", REPO_ID_SEPARATOR)
69+
70+
# if we have no internet connection we will look for the
71+
# last modified folder in the cache
72+
if local_files_only:
73+
# possible repos have <path/to/cache_dir>/<flatten_repo_id> prefix
74+
repo_folders_prefix = os.path.join(cache_dir, repo_id_flattened)
75+
76+
# list all possible folders that can correspond to the repo_id
77+
# and are of the format <flattened-repo-id>.<revision>.<commit-sha>
78+
# now let's list all cached repos that have to be included in the revision.
79+
# There are 3 cases that we have to consider.
80+
81+
# 1) cached repos of format <repo_id>.{revision}.<any-hash>
82+
# -> in this case {revision} has to be a branch
83+
repo_folders_branch = glob(repo_folders_prefix + "." + revision + ".*")
84+
85+
# 2) cached repos of format <repo_id>.{revision}
86+
# -> in this case {revision} has to be a commit sha
87+
repo_folders_commit_only = glob(repo_folders_prefix + "." + revision)
88+
89+
# 3) cached repos of format <repo_id>.<any-branch>.{revision}
90+
# -> in this case {revision} also has to be a commit sha
91+
repo_folders_branch_commit = glob(repo_folders_prefix + ".*." + revision)
92+
93+
# combine all possible fetched cached repos
94+
repo_folders = (
95+
repo_folders_branch + repo_folders_commit_only + repo_folders_branch_commit
96+
)
97+
98+
if len(repo_folders) == 0:
99+
raise ValueError(
100+
"Cannot find the requested files in the cached path and outgoing traffic has been"
101+
" disabled. To enable model look-ups and downloads online, set 'local_files_only'"
102+
" to False."
103+
)
58104

59-
storage_folder = os.path.join(
60-
cache_dir, repo_id.replace("/", REPO_ID_SEPARATOR) + "." + model_info.sha
61-
)
105+
# check if repo id was previously cached from a commit sha revision
106+
# and passed {revision} is not a commit sha
107+
# in this case snapshotting repos locally might lead to unexpected
108+
# behavior the user should be warned about
62109

63-
for model_file in model_info.siblings:
64-
url = hf_hub_url(
65-
repo_id, filename=model_file.rfilename, revision=model_info.sha
110+
# get all folders that were cached with just a sha commit revision
111+
all_repo_folders_from_sha = set(glob(repo_folders_prefix + ".*")) - set(
112+
glob(repo_folders_prefix + ".*.*")
66113
)
67-
relative_filepath = os.path.join(*model_file.rfilename.split("/"))
114+
# 1) is there any repo id that was previously cached from a commit sha?
115+
has_a_sha_revision_been_cached = len(all_repo_folders_from_sha) > 0
116+
# 2) is the passed {revision} is a branch
117+
is_revision_a_branch = (
118+
len(repo_folders_commit_only + repo_folders_branch_commit) == 0
119+
)
120+
121+
if has_a_sha_revision_been_cached and is_revision_a_branch:
122+
# -> in this case let's warn the user
123+
logger.warn(
124+
f"The repo {repo_id} was previously downloaded from a commit hash revision "
125+
f"and has created the following cached directories {all_repo_folders_from_sha}."
126+
f" In this case, trying to load a repo from the branch {revision} in offline "
127+
"mode might lead to unexpected behavior by not taking into account the latest "
128+
"commits."
129+
)
130+
131+
# find last modified folder
132+
storage_folder = max(repo_folders, key=os.path.getmtime)
133+
134+
# get commit sha
135+
repo_id_sha = storage_folder.split(".")[-1]
136+
model_files = os.listdir(storage_folder)
137+
else:
138+
# if we have internet connection we retrieve the correct folder name from the huggingface api
139+
_api = HfApi()
140+
model_info = _api.model_info(repo_id=repo_id, revision=revision, token=token)
141+
142+
storage_folder = os.path.join(cache_dir, repo_id_flattened + "." + revision)
143+
144+
# if passed revision is not identical to the commit sha
145+
# then revision has to be a branch name, e.g. "main"
146+
# in this case make sure that the branch name is included in the
147+
# cached storage folder name
148+
if revision != model_info.sha:
149+
storage_folder += f".{model_info.sha}"
150+
151+
repo_id_sha = model_info.sha
152+
model_files = [f.rfilename for f in model_info.siblings]
153+
154+
for model_file in model_files:
155+
url = hf_hub_url(repo_id, filename=model_file, revision=repo_id_sha)
156+
relative_filepath = os.path.join(*model_file.split("/"))
68157

69158
# Create potential nested dir
70159
nested_dirname = os.path.dirname(
@@ -79,7 +168,11 @@ def snapshot_download(
79168
library_name=library_name,
80169
library_version=library_version,
81170
user_agent=user_agent,
171+
proxies=proxies,
172+
etag_timeout=etag_timeout,
173+
resume_download=resume_download,
82174
use_auth_token=use_auth_token,
175+
local_files_only=local_files_only,
83176
)
84177

85178
if os.path.exists(path + ".lock"):

tests/test_snapshot_download.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import shutil
23
import tempfile
34
import time
45
import unittest
@@ -46,8 +47,15 @@ def setUp(self) -> None:
4647

4748
self.second_commit_hash = repo.git_head_hash()
4849

50+
with repo.commit("Add file to other branch", branch="other"):
51+
with open("dummy_file_2.txt", "w+") as f:
52+
f.write("v4")
53+
54+
self.third_commit_hash = repo.git_head_hash()
55+
4956
def tearDown(self) -> None:
5057
self._api.delete_repo(name=REPO_NAME, token=self._token)
58+
shutil.rmtree(REPO_NAME)
5159

5260
def test_download_model(self):
5361
# Test `main` branch
@@ -155,3 +163,148 @@ def test_download_private_model(self):
155163
self._api.update_repo_visibility(
156164
token=self._token, name=REPO_NAME, private=False
157165
)
166+
167+
def test_download_model_local_only(self):
168+
# Test no branch specified
169+
with tempfile.TemporaryDirectory() as tmpdirname:
170+
# first download folder to cache it
171+
snapshot_download(f"{USER}/{REPO_NAME}", cache_dir=tmpdirname)
172+
173+
# now load from cache
174+
storage_folder = snapshot_download(
175+
f"{USER}/{REPO_NAME}",
176+
cache_dir=tmpdirname,
177+
local_files_only=True,
178+
)
179+
180+
# folder contains the two files contributed and the .gitattributes
181+
folder_contents = os.listdir(storage_folder)
182+
self.assertEqual(len(folder_contents), 3)
183+
self.assertTrue("dummy_file.txt" in folder_contents)
184+
self.assertTrue("dummy_file_2.txt" in folder_contents)
185+
self.assertTrue(".gitattributes" in folder_contents)
186+
187+
with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f:
188+
contents = f.read()
189+
self.assertEqual(contents, "v2")
190+
191+
# folder name contains the revision's commit sha.
192+
self.assertTrue(self.second_commit_hash in storage_folder)
193+
194+
# Test with specific revision branch
195+
with tempfile.TemporaryDirectory() as tmpdirname:
196+
# first download folder to cache it
197+
snapshot_download(
198+
f"{USER}/{REPO_NAME}",
199+
revision="other",
200+
cache_dir=tmpdirname,
201+
)
202+
203+
# now load from cache
204+
storage_folder = snapshot_download(
205+
f"{USER}/{REPO_NAME}",
206+
revision="other",
207+
cache_dir=tmpdirname,
208+
local_files_only=True,
209+
)
210+
211+
# folder contains the two files contributed and the .gitattributes
212+
folder_contents = os.listdir(storage_folder)
213+
self.assertEqual(len(folder_contents), 3)
214+
self.assertTrue("dummy_file.txt" in folder_contents)
215+
self.assertTrue(".gitattributes" in folder_contents)
216+
217+
with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f:
218+
contents = f.read()
219+
self.assertEqual(contents, "v2")
220+
221+
# folder name contains the revision's commit sha.
222+
self.assertTrue(self.third_commit_hash in storage_folder)
223+
224+
# Test with specific revision hash
225+
with tempfile.TemporaryDirectory() as tmpdirname:
226+
# first download folder to cache it
227+
snapshot_download(
228+
f"{USER}/{REPO_NAME}",
229+
revision=self.first_commit_hash,
230+
cache_dir=tmpdirname,
231+
)
232+
233+
# now load from cache
234+
storage_folder = snapshot_download(
235+
f"{USER}/{REPO_NAME}",
236+
revision=self.first_commit_hash,
237+
cache_dir=tmpdirname,
238+
local_files_only=True,
239+
)
240+
241+
# folder contains the single file contributed and the .gitattributes
242+
folder_contents = os.listdir(storage_folder)
243+
self.assertEqual(len(folder_contents), 2)
244+
self.assertTrue("dummy_file.txt" in folder_contents)
245+
self.assertTrue(".gitattributes" in folder_contents)
246+
247+
with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f:
248+
contents = f.read()
249+
self.assertEqual(contents, "v1")
250+
251+
# folder name contains the revision's commit sha.
252+
self.assertTrue(self.first_commit_hash in storage_folder)
253+
254+
def test_download_model_local_only_multiple(self):
255+
# Test `main` branch
256+
with tempfile.TemporaryDirectory() as tmpdirname:
257+
# download both from branch and from commit
258+
snapshot_download(
259+
f"{USER}/{REPO_NAME}",
260+
cache_dir=tmpdirname,
261+
)
262+
263+
snapshot_download(
264+
f"{USER}/{REPO_NAME}",
265+
revision=self.first_commit_hash,
266+
cache_dir=tmpdirname,
267+
)
268+
269+
# now load from cache and make sure a warning is raised
270+
with self.assertWarns(Warning):
271+
snapshot_download(
272+
f"{USER}/{REPO_NAME}",
273+
cache_dir=tmpdirname,
274+
local_files_only=True,
275+
)
276+
277+
# cache multiple commits and make sure correct commit is taken
278+
with tempfile.TemporaryDirectory() as tmpdirname:
279+
# first download folder to cache it
280+
snapshot_download(
281+
f"{USER}/{REPO_NAME}",
282+
cache_dir=tmpdirname,
283+
)
284+
285+
# now load folder from another branch
286+
snapshot_download(
287+
f"{USER}/{REPO_NAME}",
288+
revision="other",
289+
cache_dir=tmpdirname,
290+
)
291+
292+
# now make sure that loading "main" branch gives correct branch
293+
storage_folder = snapshot_download(
294+
f"{USER}/{REPO_NAME}",
295+
cache_dir=tmpdirname,
296+
local_files_only=True,
297+
)
298+
299+
# folder contains the two files contributed and the .gitattributes
300+
folder_contents = os.listdir(storage_folder)
301+
self.assertEqual(len(folder_contents), 3)
302+
self.assertTrue("dummy_file.txt" in folder_contents)
303+
self.assertTrue(".gitattributes" in folder_contents)
304+
305+
with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f:
306+
contents = f.read()
307+
self.assertEqual(contents, "v2")
308+
309+
# folder name contains the 2nd commit sha and not the 3rd
310+
self.assertTrue(self.second_commit_hash in storage_folder)

0 commit comments

Comments
 (0)