Skip to content

Commit eeb0e73

Browse files
committed
Paginated results in list_user_access (#3535)
1 parent 5b7f672 commit eeb0e73

File tree

2 files changed

+26
-29
lines changed

2 files changed

+26
-29
lines changed

src/huggingface_hub/hf_api.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8511,7 +8511,7 @@ def delete_collection_item(
85118511
@validate_hf_hub_args
85128512
def list_pending_access_requests(
85138513
self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
8514-
) -> list[AccessRequest]:
8514+
) -> Iterable[AccessRequest]:
85158515
"""
85168516
Get pending access requests for a given gated repo.
85178517
@@ -8534,7 +8534,7 @@ def list_pending_access_requests(
85348534
To disable authentication, pass `False`.
85358535
85368536
Returns:
8537-
`list[AccessRequest]`: A list of [`AccessRequest`] objects. Each time contains a `username`, `email`,
8537+
`Iterable[AccessRequest]`: An iterable of [`AccessRequest`] objects. Each time contains a `username`, `email`,
85388538
`status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute will
85398539
be populated with user's answers.
85408540
@@ -8550,7 +8550,7 @@ def list_pending_access_requests(
85508550
>>> from huggingface_hub import list_pending_access_requests, accept_access_request
85518551
85528552
# List pending requests
8553-
>>> requests = list_pending_access_requests("meta-llama/Llama-2-7b")
8553+
>>> requests = list(list_pending_access_requests("meta-llama/Llama-2-7b"))
85548554
>>> len(requests)
85558555
411
85568556
>>> requests[0]
@@ -8570,12 +8570,12 @@ def list_pending_access_requests(
85708570
>>> accept_access_request("meta-llama/Llama-2-7b", "clem")
85718571
```
85728572
"""
8573-
return self._list_access_requests(repo_id, "pending", repo_type=repo_type, token=token)
8573+
yield from self._list_access_requests(repo_id, "pending", repo_type=repo_type, token=token)
85748574

85758575
@validate_hf_hub_args
85768576
def list_accepted_access_requests(
85778577
self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
8578-
) -> list[AccessRequest]:
8578+
) -> Iterable[AccessRequest]:
85798579
"""
85808580
Get accepted access requests for a given gated repo.
85818581
@@ -8600,7 +8600,7 @@ def list_accepted_access_requests(
86008600
To disable authentication, pass `False`.
86018601
86028602
Returns:
8603-
`list[AccessRequest]`: A list of [`AccessRequest`] objects. Each time contains a `username`, `email`,
8603+
`Iterable[AccessRequest]`: An iterable of [`AccessRequest`] objects. Each time contains a `username`, `email`,
86048604
`status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute will
86058605
be populated with user's answers.
86068606
@@ -8615,7 +8615,7 @@ def list_accepted_access_requests(
86158615
```py
86168616
>>> from huggingface_hub import list_accepted_access_requests
86178617
8618-
>>> requests = list_accepted_access_requests("meta-llama/Llama-2-7b")
8618+
>>> requests = list(list_accepted_access_requests("meta-llama/Llama-2-7b"))
86198619
>>> len(requests)
86208620
411
86218621
>>> requests[0]
@@ -8632,12 +8632,12 @@ def list_accepted_access_requests(
86328632
]
86338633
```
86348634
"""
8635-
return self._list_access_requests(repo_id, "accepted", repo_type=repo_type, token=token)
8635+
yield from self._list_access_requests(repo_id, "accepted", repo_type=repo_type, token=token)
86368636

86378637
@validate_hf_hub_args
86388638
def list_rejected_access_requests(
86398639
self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
8640-
) -> list[AccessRequest]:
8640+
) -> Iterable[AccessRequest]:
86418641
"""
86428642
Get rejected access requests for a given gated repo.
86438643
@@ -8662,7 +8662,7 @@ def list_rejected_access_requests(
86628662
To disable authentication, pass `False`.
86638663
86648664
Returns:
8665-
`list[AccessRequest]`: A list of [`AccessRequest`] objects. Each time contains a `username`, `email`,
8665+
`Iterable[AccessRequest]`: An iterable of [`AccessRequest`] objects. Each time contains a `username`, `email`,
86668666
`status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute will
86678667
be populated with user's answers.
86688668
@@ -8677,7 +8677,7 @@ def list_rejected_access_requests(
86778677
```py
86788678
>>> from huggingface_hub import list_rejected_access_requests
86798679
8680-
>>> requests = list_rejected_access_requests("meta-llama/Llama-2-7b")
8680+
>>> requests = list(list_rejected_access_requests("meta-llama/Llama-2-7b"))
86818681
>>> len(requests)
86828682
411
86838683
>>> requests[0]
@@ -8694,36 +8694,33 @@ def list_rejected_access_requests(
86948694
]
86958695
```
86968696
"""
8697-
return self._list_access_requests(repo_id, "rejected", repo_type=repo_type, token=token)
8697+
yield from self._list_access_requests(repo_id, "rejected", repo_type=repo_type, token=token)
86988698

86998699
def _list_access_requests(
87008700
self,
87018701
repo_id: str,
87028702
status: Literal["accepted", "rejected", "pending"],
87038703
repo_type: Optional[str] = None,
87048704
token: Union[bool, str, None] = None,
8705-
) -> list[AccessRequest]:
8705+
) -> Iterable[AccessRequest]:
87068706
if repo_type not in constants.REPO_TYPES:
87078707
raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}")
87088708
if repo_type is None:
87098709
repo_type = constants.REPO_TYPE_MODEL
87108710

8711-
response = get_session().get(
8711+
for request in paginate(
87128712
f"{constants.ENDPOINT}/api/{repo_type}s/{repo_id}/user-access-request/{status}",
8713+
params={},
87138714
headers=self._build_hf_headers(token=token),
8714-
)
8715-
hf_raise_for_status(response)
8716-
return [
8717-
AccessRequest(
8715+
):
8716+
yield AccessRequest(
87188717
username=request["user"]["user"],
87198718
fullname=request["user"]["fullname"],
87208719
email=request["user"].get("email"),
87218720
status=request["status"],
87228721
timestamp=parse_datetime(request["timestamp"]),
87238722
fields=request.get("fields"), # only if custom fields in form
87248723
)
8725-
for request in response.json()
8726-
]
87278724

87288725
@validate_hf_hub_args
87298726
def cancel_access_request(

tests/test_hf_api.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4099,18 +4099,18 @@ def tearDown(self) -> None:
40994099

41004100
def test_access_requests_normal_usage(self) -> None:
41014101
# No access requests initially
4102-
requests = self._api.list_accepted_access_requests(self.repo_id)
4102+
requests = list(self._api.list_accepted_access_requests(self.repo_id))
41034103
assert len(requests) == 0
4104-
requests = self._api.list_pending_access_requests(self.repo_id)
4104+
requests = list(self._api.list_pending_access_requests(self.repo_id))
41054105
assert len(requests) == 0
4106-
requests = self._api.list_rejected_access_requests(self.repo_id)
4106+
requests = list(self._api.list_rejected_access_requests(self.repo_id))
41074107
assert len(requests) == 0
41084108

41094109
# Grant access to a user
41104110
self._api.grant_access(self.repo_id, OTHER_USER)
41114111

41124112
# User is in accepted list
4113-
requests = self._api.list_accepted_access_requests(self.repo_id)
4113+
requests = list(self._api.list_accepted_access_requests(self.repo_id))
41144114
assert len(requests) == 1
41154115
request = requests[0]
41164116
assert isinstance(request, AccessRequest)
@@ -4121,23 +4121,23 @@ def test_access_requests_normal_usage(self) -> None:
41214121

41224122
# Cancel access
41234123
self._api.cancel_access_request(self.repo_id, OTHER_USER)
4124-
requests = self._api.list_accepted_access_requests(self.repo_id)
4124+
requests = list(self._api.list_accepted_access_requests(self.repo_id))
41254125
assert len(requests) == 0 # not accepted anymore
4126-
requests = self._api.list_pending_access_requests(self.repo_id)
4126+
requests = list(self._api.list_pending_access_requests(self.repo_id))
41274127
assert len(requests) == 1
41284128
assert requests[0].username == OTHER_USER
41294129

41304130
# Reject access
41314131
self._api.reject_access_request(self.repo_id, OTHER_USER, rejection_reason="This is a rejection reason")
4132-
requests = self._api.list_pending_access_requests(self.repo_id)
4132+
requests = list(self._api.list_pending_access_requests(self.repo_id))
41334133
assert len(requests) == 0 # not pending anymore
4134-
requests = self._api.list_rejected_access_requests(self.repo_id)
4134+
requests = list(self._api.list_rejected_access_requests(self.repo_id))
41354135
assert len(requests) == 1
41364136
assert requests[0].username == OTHER_USER
41374137

41384138
# Accept again
41394139
self._api.accept_access_request(self.repo_id, OTHER_USER)
4140-
requests = self._api.list_accepted_access_requests(self.repo_id)
4140+
requests = list(self._api.list_accepted_access_requests(self.repo_id))
41414141
assert len(requests) == 1
41424142
assert requests[0].username == OTHER_USER
41434143

0 commit comments

Comments
 (0)