|
3 | 3 | import time |
4 | 4 | import unittest |
5 | 5 |
|
| 6 | +import requests |
6 | 7 | from huggingface_hub import HfApi, Repository |
| 8 | +from huggingface_hub.hf_api import HfFolder |
7 | 9 | from huggingface_hub.snapshot_download import snapshot_download |
8 | 10 | from tests.testing_constants import ENDPOINT_STAGING, PASS, USER |
9 | 11 |
|
@@ -88,3 +90,64 @@ def test_download_model(self): |
88 | 90 |
|
89 | 91 | # folder name contains the revision's commit sha. |
90 | 92 | self.assertTrue(self.first_commit_hash in storage_folder) |
| 93 | + |
| 94 | + def test_download_private_model(self): |
| 95 | + self._api.update_repo_visibility(self._token, REPO_NAME, private=True) |
| 96 | + |
| 97 | + # Test download fails without token |
| 98 | + with tempfile.TemporaryDirectory() as tmpdirname: |
| 99 | + with self.assertRaisesRegex( |
| 100 | + requests.exceptions.HTTPError, "404 Client Error" |
| 101 | + ): |
| 102 | + _ = snapshot_download( |
| 103 | + f"{USER}/{REPO_NAME}", revision="main", cache_dir=tmpdirname |
| 104 | + ) |
| 105 | + |
| 106 | + # Test we can download with token from cache |
| 107 | + with tempfile.TemporaryDirectory() as tmpdirname: |
| 108 | + HfFolder.save_token(self._token) |
| 109 | + storage_folder = snapshot_download( |
| 110 | + f"{USER}/{REPO_NAME}", |
| 111 | + revision="main", |
| 112 | + cache_dir=tmpdirname, |
| 113 | + use_auth_token=True, |
| 114 | + ) |
| 115 | + |
| 116 | + # folder contains the two files contributed and the .gitattributes |
| 117 | + folder_contents = os.listdir(storage_folder) |
| 118 | + self.assertEqual(len(folder_contents), 3) |
| 119 | + self.assertTrue("dummy_file.txt" in folder_contents) |
| 120 | + self.assertTrue("dummy_file_2.txt" in folder_contents) |
| 121 | + self.assertTrue(".gitattributes" in folder_contents) |
| 122 | + |
| 123 | + with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f: |
| 124 | + contents = f.read() |
| 125 | + self.assertEqual(contents, "v2") |
| 126 | + |
| 127 | + # folder name contains the revision's commit sha. |
| 128 | + self.assertTrue(self.second_commit_hash in storage_folder) |
| 129 | + |
| 130 | + # Test we can download with explicit token |
| 131 | + with tempfile.TemporaryDirectory() as tmpdirname: |
| 132 | + storage_folder = snapshot_download( |
| 133 | + f"{USER}/{REPO_NAME}", |
| 134 | + revision="main", |
| 135 | + cache_dir=tmpdirname, |
| 136 | + use_auth_token=self._token, |
| 137 | + ) |
| 138 | + |
| 139 | + # folder contains the two files contributed and the .gitattributes |
| 140 | + folder_contents = os.listdir(storage_folder) |
| 141 | + self.assertEqual(len(folder_contents), 3) |
| 142 | + self.assertTrue("dummy_file.txt" in folder_contents) |
| 143 | + self.assertTrue("dummy_file_2.txt" in folder_contents) |
| 144 | + self.assertTrue(".gitattributes" in folder_contents) |
| 145 | + |
| 146 | + with open(os.path.join(storage_folder, "dummy_file.txt"), "r") as f: |
| 147 | + contents = f.read() |
| 148 | + self.assertEqual(contents, "v2") |
| 149 | + |
| 150 | + # folder name contains the revision's commit sha. |
| 151 | + self.assertTrue(self.second_commit_hash in storage_folder) |
| 152 | + |
| 153 | + self._api.update_repo_visibility(self._token, REPO_NAME, private=False) |
0 commit comments