Skip to content

Commit 10e7445

Browse files
authored
Add auth to snapshot download (#340)
* Add API token to snapshot download
* Fix style
* Replace token with use_auth_token
* Add truthiness check for token
* Assign string value for token
* Add case for None token
* Refactor unit tests
* Add test cases
* Style
* Add unit test for private model info
* Revert "Add unit test for private model info"
  This reverts commit eca1e8b.
* Make repo public after test
1 parent d914261 commit 10e7445

File tree

2 files changed: +78 -2 lines changed

2 files changed: +78 -2 lines changed

src/huggingface_hub/snapshot_download.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from .constants import HUGGINGFACE_HUB_CACHE
66
from .file_download import cached_download, hf_hub_url
7-
from .hf_api import HfApi
7+
from .hf_api import HfApi, HfFolder
88

99

1010
REPO_ID_SEPARATOR = "__"
@@ -18,6 +18,7 @@ def snapshot_download(
1818
library_name: Optional[str] = None,
1919
library_version: Optional[str] = None,
2020
user_agent: Union[Dict, str, None] = None,
21+
use_auth_token: Union[bool, str, None] = None,
2122
) -> str:
2223
"""
2324
Downloads a whole snapshot of a repo's files at the specified revision.
@@ -41,8 +42,19 @@ def snapshot_download(
4142
if isinstance(cache_dir, Path):
4243
cache_dir = str(cache_dir)
4344

45+
if isinstance(use_auth_token, str):
46+
token = use_auth_token
47+
elif use_auth_token:
48+
token = HfFolder.get_token()
49+
if token is None:
50+
raise EnvironmentError(
51+
"You specified use_auth_token=True, but a Hugging Face token was not found."
52+
)
53+
else:
54+
token = None
55+
4456
_api = HfApi()
45-
model_info = _api.model_info(repo_id=repo_id, revision=revision)
57+
model_info = _api.model_info(repo_id=repo_id, revision=revision, token=token)
4658

4759
storage_folder = os.path.join(
4860
cache_dir, repo_id.replace("/", REPO_ID_SEPARATOR) + "." + model_info.sha
@@ -67,6 +79,7 @@ def snapshot_download(
6779
library_name=library_name,
6880
library_version=library_version,
6981
user_agent=user_agent,
82+
use_auth_token=use_auth_token,
7083
)
7184

7285
if os.path.exists(path + ".lock"):

tests/test_snapshot_download.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import time
44
import unittest
55

6+
import requests
67
from huggingface_hub import HfApi, Repository
8+
from huggingface_hub.hf_api import HfFolder
79
from huggingface_hub.snapshot_download import snapshot_download
810
from tests.testing_constants import ENDPOINT_STAGING, PASS, USER
911

@@ -88,3 +90,64 @@ def test_download_model(self):
8890

8991
# folder name contains the revision's commit sha.
9092
self.assertTrue(self.first_commit_hash in storage_folder)
93+
94+
def test_download_private_model(self):
95+
self._api.update_repo_visibility(self._token, REPO_NAME, private=True)
96+
97+
# Test download fails without token
98+
with tempfile.TemporaryDirectory() as tmpdirname:
99+
with self.assertRaisesRegex(
100+
requests.exceptions.HTTPError, "404 Client Error"
101+
):
102+
_ = snapshot_download(
103+
f"{USER}/{REPO_NAME}", revision="main", cache_dir=tmpdirname
104+
)
105+
106+
# Test we can download with token from cache
107+
with tempfile.TemporaryDirectory() as tmpdirname:
108+
HfFolder.save_token(self._token)
109+
storage_folder = snapshot_download(
110+
f"{USER}/{REPO_NAME}",
111+
revision="main",
112+
cache_dir=tmpdirname,
113+
use_auth_token=True,
114+
)
115+
116+
# folder contains the two files contributed and the .gitattributes
117+
folder_contents = os.listdir(storage_folder)
118+
self.assertEqual(len(folder_contents), 3)
119+
self.assertTrue("dummy_file.txt" in folder_contents)
120+
self.assertTrue("dummy_file_2.txt" in folder_contents)
121+
self.assertTrue(".gitattributes" in folder_contents)
122+
123+
with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f:
124+
contents = f.read()
125+
self.assertEqual(contents, "v2")
126+
127+
# folder name contains the revision's commit sha.
128+
self.assertTrue(self.second_commit_hash in storage_folder)
129+
130+
# Test we can download with explicit token
131+
with tempfile.TemporaryDirectory() as tmpdirname:
132+
storage_folder = snapshot_download(
133+
f"{USER}/{REPO_NAME}",
134+
revision="main",
135+
cache_dir=tmpdirname,
136+
use_auth_token=self._token,
137+
)
138+
139+
# folder contains the two files contributed and the .gitattributes
140+
folder_contents = os.listdir(storage_folder)
141+
self.assertEqual(len(folder_contents), 3)
142+
self.assertTrue("dummy_file.txt" in folder_contents)
143+
self.assertTrue("dummy_file_2.txt" in folder_contents)
144+
self.assertTrue(".gitattributes" in folder_contents)
145+
146+
with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f:
147+
contents = f.read()
148+
self.assertEqual(contents, "v2")
149+
150+
# folder name contains the revision's commit sha.
151+
self.assertTrue(self.second_commit_hash in storage_folder)
152+
153+
self._api.update_repo_visibility(self._token, REPO_NAME, private=False)

0 commit comments

Comments
 (0)