Skip to content

Commit cc28c37

Browse files
Disable logging in with organization token (#780)
* disabled organization token * checking under token _validate_or_retrieve_token() * added test for HfFolder token is organization token * refactored token retrieval & validation * refactored _validate_or_retrieve_token * refactored for repetitions * Update src/huggingface_hub/hf_api.py Co-authored-by: Adrin Jalali <[email protected]> * added docstring * added returns and raises Co-authored-by: Adrin Jalali <[email protected]>
1 parent 3e72bc8 commit cc28c37

File tree

4 files changed

+79
-60
lines changed

4 files changed

+79
-60
lines changed

src/huggingface_hub/commands/user.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,8 +382,8 @@ def _login(hf_api, username=None, password=None, token=None):
382382
print(e)
383383
print(ANSI.red(e.response.text))
384384
exit(1)
385-
elif not hf_api._is_valid_token(token):
386-
raise ValueError("Invalid token passed.")
385+
else:
386+
token, name = hf_api._validate_or_retrieve_token(token)
387387

388388
hf_api.set_access_token(token)
389389
HfFolder.save_token(token)

src/huggingface_hub/hf_api.py

Lines changed: 53 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,6 @@ def whoami(self, token: Optional[str] = None) -> Dict:
523523
"You need to pass a valid `token` or login by using `huggingface-cli "
524524
"login`"
525525
)
526-
527526
path = f"{self.endpoint}/api/whoami-v2"
528527
r = requests.get(path, headers={"authorization": f"Bearer {token}"})
529528
try:
@@ -553,19 +552,26 @@ def _is_valid_token(self, token: str):
553552
except HTTPError:
554553
return False
555554

556-
def _validate_or_retrieve_token(self, token: Optional[Union[str, bool]] = None):
555+
def _validate_or_retrieve_token(
556+
self,
557+
token: Optional[str] = None,
558+
name: Optional[str] = None,
559+
function_name: Optional[str] = None,
560+
):
557561
"""
558-
Either retrieves stored token or validates passed token.
559-
562+
Retrieves and validates stored token or validates passed token.
560563
Args:
561-
token (`str`, *optional*):
562-
The token to check for validity
563-
564+
token (``str``, `optional`):
565+
Hugging Face token. Will default to the locally saved token if not provided.
566+
name (``str``, `optional`):
567+
Name of the repository. This is deprecated in favor of repo_id and will be removed in v0.7.
568+
function_name (``str``, `optional`):
569+
If _validate_or_retrieve_token is called from a function, name of that function to be passed inside deprecation warning.
564570
Returns:
565-
`str`: The valid token
566-
571+
Validated token and the name of the repository.
567572
Raises:
568-
`ValueError`: if the token is invalid.
573+
:class:`EnvironmentError`: If the token is not passed and there's no token saved locally.
574+
:class:`ValueError`: If organization token or invalid token is passed.
569575
"""
570576
if token is None or token is True:
571577
token = HfFolder.get_token()
@@ -574,9 +580,22 @@ def _validate_or_retrieve_token(self, token: Optional[Union[str, bool]] = None):
574580
"You need to provide a `token` or be logged in to Hugging "
575581
"Face with `huggingface-cli login`."
576582
)
577-
elif not self._is_valid_token(token):
578-
raise ValueError("Invalid token passed!")
579-
return token
583+
if name is not None:
584+
if self._is_valid_token(name):
585+
# TODO(0.6) REMOVE
586+
warnings.warn(
587+
f"`{function_name}` now takes `token` as an optional positional argument. "
588+
"Be sure to adapt your code!",
589+
FutureWarning,
590+
)
591+
token, name = name, token
592+
if isinstance(token, str):
593+
if token.startswith("api_org"):
594+
raise ValueError("You must use your personal account token.")
595+
if not self._is_valid_token(token):
596+
raise ValueError("Invalid token passed!")
597+
598+
return token, name
580599

581600
def logout(self, token: Optional[str] = None) -> None:
582601
"""
@@ -758,7 +777,7 @@ def list_models(
758777
"""
759778
path = f"{self.endpoint}/api/models"
760779
if use_auth_token:
761-
token = self._validate_or_retrieve_token(use_auth_token)
780+
token, name = self._validate_or_retrieve_token(use_auth_token)
762781
headers = {"authorization": f"Bearer {token}"} if use_auth_token else None
763782
params = {}
764783
if filter is not None:
@@ -955,7 +974,7 @@ def list_datasets(
955974
"""
956975
path = f"{self.endpoint}/api/datasets"
957976
if use_auth_token:
958-
token = self._validate_or_retrieve_token(use_auth_token)
977+
token, name = self._validate_or_retrieve_token(use_auth_token)
959978
headers = {"authorization": f"Bearer {token}"} if use_auth_token else None
960979
params = {}
961980
if filter is not None:
@@ -1240,18 +1259,10 @@ def create_repo(
12401259
name, organization = _validate_repo_id_deprecation(repo_id, name, organization)
12411260

12421261
path = f"{self.endpoint}/api/repos/create"
1243-
if token is None:
1244-
token = self._validate_or_retrieve_token()
1245-
elif not self._is_valid_token(token):
1246-
if self._is_valid_token(name):
1247-
warnings.warn(
1248-
"`create_repo` now takes `token` as an optional positional argument. "
1249-
"Be sure to adapt your code!",
1250-
FutureWarning,
1251-
)
1252-
token, name = name, token
1253-
else:
1254-
raise ValueError("Invalid token passed!")
1262+
1263+
token, name = self._validate_or_retrieve_token(
1264+
token, name, function_name="create_repo"
1265+
)
12551266

12561267
checked_name = repo_type_and_id_from_hf_id(name)
12571268

@@ -1369,18 +1380,10 @@ def delete_repo(
13691380
name, organization = _validate_repo_id_deprecation(repo_id, name, organization)
13701381

13711382
path = f"{self.endpoint}/api/repos/delete"
1372-
if token is None:
1373-
token = self._validate_or_retrieve_token()
1374-
elif not self._is_valid_token(token):
1375-
if self._is_valid_token(name):
1376-
warnings.warn(
1377-
"`delete_repo` now takes `token` as an optional positional argument. "
1378-
"Be sure to adapt your code!",
1379-
FutureWarning,
1380-
)
1381-
token, name = name, token
1382-
else:
1383-
raise ValueError("Invalid token passed!")
1383+
1384+
token, name = self._validate_or_retrieve_token(
1385+
token, name, function_name="delete_repo"
1386+
)
13841387

13851388
checked_name = repo_type_and_id_from_hf_id(name)
13861389

@@ -1480,18 +1483,9 @@ def update_repo_visibility(
14801483

14811484
name, organization = _validate_repo_id_deprecation(repo_id, name, organization)
14821485

1483-
if token is None:
1484-
token = self._validate_or_retrieve_token()
1485-
elif not self._is_valid_token(token):
1486-
if self._is_valid_token(name):
1487-
warnings.warn(
1488-
"`update_repo_visibility` now takes `token` as an optional positional argument. "
1489-
"Be sure to adapt your code!",
1490-
FutureWarning,
1491-
)
1492-
token, name, private = name, private, token
1493-
else:
1494-
raise ValueError("Invalid token passed!")
1486+
token, name = self._validate_or_retrieve_token(
1487+
token, name, function_name="update_repo_visibility"
1488+
)
14951489

14961490
if organization is None:
14971491
namespace = self.whoami(token)["name"]
@@ -1548,7 +1542,8 @@ def move_repo(
15481542
15491543
- [1] https://huggingface.co/settings/tokens
15501544
"""
1551-
token = self._validate_or_retrieve_token(token)
1545+
1546+
token, name = self._validate_or_retrieve_token(token)
15521547

15531548
if len(from_id.split("/")) != 2:
15541549
raise ValueError(
@@ -1664,9 +1659,11 @@ def upload_file(
16641659
if repo_type not in REPO_TYPES:
16651660
raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}")
16661661

1667-
if token is None:
1668-
token = self._validate_or_retrieve_token()
1669-
elif not self._is_valid_token(token):
1662+
try:
1663+
token, name = self._validate_or_retrieve_token(
1664+
token, function_name="upload_file"
1665+
)
1666+
except ValueError: # if token is invalid or organization token
16701667
if self._is_valid_token(path_or_fileobj):
16711668
warnings.warn(
16721669
"`upload_file` now takes `token` as an optional positional argument. "
@@ -1769,8 +1766,7 @@ def delete_file(
17691766
if repo_type not in REPO_TYPES:
17701767
raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}")
17711768

1772-
if token is None:
1773-
token = self._validate_or_retrieve_token()
1769+
token, name = self._validate_or_retrieve_token(token)
17741770

17751771
if repo_type in REPO_TYPES_URL_PREFIXES:
17761772
repo_id = REPO_TYPES_URL_PREFIXES[repo_type] + repo_id

src/huggingface_hub/hub_mixin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def push_to_hub(
276276
token = HfFolder.get_token()
277277
if token is None:
278278
raise ValueError(
279-
"You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and "
279+
"You must login to the Hugging Face hub on this computer by typing `huggingface-cli login` and "
280280
"entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
281281
"token as the `use_auth_token` argument."
282282
)

tests/test_hf_api.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,12 @@ def test_login_cli(self):
151151
read_from_credential_store(USERNAME_PLACEHOLDER), (None, None)
152152
)
153153

154+
def test_login_cli_org_fail(self):
155+
with pytest.raises(
156+
ValueError, match="You must use your personal account token."
157+
):
158+
_login(self._api, token="api_org_dummy_token")
159+
154160
def test_login_deprecation_error(self):
155161
with pytest.warns(
156162
FutureWarning,
@@ -561,6 +567,23 @@ def test_upload_file_bytesio(self):
561567
finally:
562568
self._api.delete_repo(repo_id=REPO_NAME, token=self._token)
563569

570+
@retry_endpoint
571+
def test_create_repo_org_token_fail(self):
572+
REPO_NAME = repo_name("org")
573+
with pytest.raises(
574+
ValueError, match="You must use your personal account token."
575+
):
576+
self._api.create_repo(repo_id=REPO_NAME, token="api_org_dummy_token")
577+
578+
@retry_endpoint
579+
def test_create_repo_org_token_none_fail(self):
580+
REPO_NAME = repo_name("org")
581+
HfFolder.save_token("api_org_dummy_token")
582+
with pytest.raises(
583+
ValueError, match="You must use your personal account token."
584+
):
585+
self._api.create_repo(repo_id=REPO_NAME)
586+
564587
@retry_endpoint
565588
def test_upload_file_conflict(self):
566589
REPO_NAME = repo_name("conflict")

0 commit comments

Comments
 (0)