Skip to content

Commit 711f688

Browse files
authored
Add support for pagination in list_models, list_datasets and list_spaces (#1176)
* Add support for (future) pagination * Fix and improve handling of the Link header * Limit pagination * Add comment
1 parent 337351d commit 711f688

File tree

4 files changed

+144
-84
lines changed

4 files changed

+144
-84
lines changed

src/huggingface_hub/hf_api.py

Lines changed: 14 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import re
1818
import warnings
1919
from dataclasses import dataclass, field
20+
from itertools import islice
2021
from pathlib import Path
2122
from typing import Any, BinaryIO, Dict, Iterable, Iterator, List, Optional, Tuple, Union
2223
from urllib.parse import quote
@@ -71,6 +72,7 @@
7172
_deprecate_method,
7273
_deprecate_positional_args,
7374
)
75+
from .utils._pagination import paginate
7476
from .utils._typing import Literal, TypedDict
7577
from .utils.endpoint_helpers import (
7678
AttributeDictionary,
@@ -808,15 +810,11 @@ def list_models(
808810
params.update({"config": True})
809811
if cardData:
810812
params.update({"cardData": True})
811-
r = requests.get(path, params=params, headers=headers)
812-
hf_raise_for_status(r)
813-
items = [ModelInfo(**x) for x in r.json()]
814813

815-
# If pagination has been enabled server-side, older versions of `huggingface_hub`
816-
# are deprecated as output is truncated.
817-
_warn_if_truncated(
818-
items, total_count=r.headers.get("X-Total-Count"), limit=limit
819-
)
814+
data = paginate(path, params=params, headers=headers)
815+
if limit is not None:
816+
data = islice(data, limit) # Do not iterate over all pages
817+
items = [ModelInfo(**x) for x in data]
820818

821819
if emissions_thresholds is not None:
822820
if cardData is None:
@@ -1015,17 +1013,11 @@ def list_datasets(
10151013
params.update({"limit": limit})
10161014
if full or cardData:
10171015
params.update({"full": True})
1018-
r = requests.get(path, params=params, headers=headers)
1019-
hf_raise_for_status(r)
1020-
items = [DatasetInfo(**x) for x in r.json()]
10211016

1022-
# If pagination has been enabled server-side, older versions of `huggingface_hub`
1023-
# are deprecated as output is truncated.
1024-
_warn_if_truncated(
1025-
items, total_count=r.headers.get("X-Total-Count"), limit=limit
1026-
)
1027-
1028-
return items
1017+
data = paginate(path, params=params, headers=headers)
1018+
if limit is not None:
1019+
data = islice(data, limit) # Do not iterate over all pages
1020+
return [DatasetInfo(**x) for x in data]
10291021

10301022
def _unpack_dataset_filter(self, dataset_filter: DatasetFilter):
10311023
"""
@@ -1162,17 +1154,11 @@ def list_spaces(
11621154
params.update({"datasets": datasets})
11631155
if models is not None:
11641156
params.update({"models": models})
1165-
r = requests.get(path, params=params, headers=headers)
1166-
hf_raise_for_status(r)
1167-
items = [SpaceInfo(**x) for x in r.json()]
1168-
1169-
# If pagination has been enabled server-side, older versions of `huggingface_hub`
1170-
# are deprecated as output is truncated.
1171-
_warn_if_truncated(
1172-
items, total_count=r.headers.get("X-Total-Count"), limit=limit
1173-
)
11741157

1175-
return items
1158+
data = paginate(path, params=params, headers=headers)
1159+
if limit is not None:
1160+
data = islice(data, limit) # Do not iterate over all pages
1161+
return [SpaceInfo(**x) for x in data]
11761162

11771163
@validate_hf_hub_args
11781164
def model_info(
@@ -3474,38 +3460,6 @@ def _parse_revision_from_pr_url(pr_url: str) -> str:
34743460
return f"refs/pr/{re_match[1]}"
34753461

34763462

3477-
def _warn_if_truncated(
3478-
items: List[Any], limit: Optional[int], total_count: Optional[str]
3479-
) -> None:
3480-
# TODO: remove this once pagination is properly implemented in `huggingface_hub`.
3481-
if total_count is None:
3482-
# Total count header not implemented
3483-
return
3484-
3485-
try:
3486-
total_count_int = int(total_count)
3487-
except ValueError:
3488-
# Total count header not implemented properly server-side
3489-
return
3490-
3491-
if len(items) == total_count_int:
3492-
# All items have been returned => not truncated
3493-
return
3494-
3495-
if limit is not None and len(items) == limit:
3496-
# `limit` is set => truncation is expected
3497-
return
3498-
3499-
# Otherwise, pagination has been enabled server-side and the output has been
3500-
# truncated by server => warn user.
3501-
warnings.warn(
3502-
"The list of repos returned by the server has been truncated. Listing repos"
3503-
" from the Hub using `list_models`, `list_datasets` and `list_spaces` now"
3504-
" requires pagination. To get the full list of repos, please consider upgrading"
3505-
" `huggingface_hub` to its latest version."
3506-
)
3507-
3508-
35093463
api = HfApi()
35103464

35113465
set_access_token = api.set_access_token
src/huggingface_hub/utils/_pagination.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# coding=utf-8
2+
# Copyright 2022-present, the HuggingFace Inc. team.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Contains utilities to handle pagination on Huggingface Hub."""
16+
from typing import Dict, Iterable, Optional
17+
18+
import requests
19+
20+
from . import hf_raise_for_status, logging
21+
22+
23+
logger = logging.get_logger(__name__)
24+
25+
26+
def paginate(path: str, params: Dict, headers: Dict) -> Iterable:
    """Lazily yield every item of a Hub listing, following pagination links.

    The first request is sent to `path` with `params` and `headers`. Each
    response is passed to `hf_raise_for_status` and its JSON payload is yielded
    item by item. The "next" link advertised by the response (same "Link"
    header format as GitHub) is then requested with `headers` only. Iteration
    stops at the first response that does not advertise a "next" link.

    Pagination is not mandatory on the Hub for now, but the number of repos per
    page is expected to be limited server-side at some point for performance
    reasons. This helper keeps `huggingface_hub` compliant with that change.

    This is a generator: no request is sent until it is iterated, so callers can
    truncate it (e.g. with `itertools.islice`) to avoid fetching every page.

    See:
    - https://requests.readthedocs.io/en/latest/api/#requests.Response.links
    - https://docs.github.com/en/rest/guides/traversing-with-pagination#link-header
    """
    response = requests.get(path, params=params, headers=headers)
    while True:
        hf_raise_for_status(response)
        yield from response.json()

        # The next link already contains the query params, so `params` is not
        # sent again on follow-up requests.
        next_page = _get_next_page(response)
        if next_page is None:
            return
        logger.debug(f"Pagination detected. Requesting next page: {next_page}")
        response = requests.get(next_page, headers=headers)
51+
52+
53+
def _get_next_page(response: requests.Response) -> Optional[str]:
    """Return the URL of the "next" link of `response`, or `None` if there is none."""
    next_link = response.links.get("next", {})
    return next_link.get("url")

tests/test_hf_api.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@
5555
ModelSearchArguments,
5656
RepoFile,
5757
SpaceInfo,
58-
_warn_if_truncated,
5958
erase_from_credential_store,
6059
read_from_credential_store,
6160
repo_type_and_id_from_hf_id,
@@ -2297,26 +2296,3 @@ def _assert_token_is(
22972296
self, mock_build_hf_headers: Mock, expected_value: str
22982297
) -> None:
22992298
self.assertEqual(mock_build_hf_headers.call_args[1]["token"], expected_value)
2300-
2301-
2302-
class WarnIfTruncatedTest(unittest.TestCase):
2303-
def test_warn_if_truncated(self) -> None:
2304-
# Can't tell if output is truncated
2305-
_warn_if_truncated([1, 2, 3], limit=None, total_count=None)
2306-
2307-
# Can't tell if output is truncated
2308-
_warn_if_truncated([1, 2, 3], limit=None, total_count="foo")
2309-
2310-
# All items returned
2311-
_warn_if_truncated([1, 2, 3], limit=None, total_count="3")
2312-
2313-
# Output is truncated (no limit, received 3)
2314-
with self.assertWarns(UserWarning):
2315-
_warn_if_truncated([1, 2, 3], limit=None, total_count="5")
2316-
2317-
# Output is truncated (limit is 4, received 3)
2318-
with self.assertWarns(UserWarning):
2319-
_warn_if_truncated([1, 2, 3], limit=4, total_count="5")
2320-
2321-
# Output is truncated by the user
2322-
_warn_if_truncated([1, 2, 3], limit=3, total_count="5")

tests/test_utils_pagination.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import unittest
2+
from unittest.mock import Mock, call, patch
3+
4+
from huggingface_hub.utils._pagination import paginate
5+
6+
from .testing_utils import handle_injection_in_test
7+
8+
9+
class TestPagination(unittest.TestCase):
    """Tests for the `paginate` helper (mocked and against the GitHub API)."""

    @patch("huggingface_hub.utils._pagination.requests.get")
    @patch("huggingface_hub.utils._pagination.hf_raise_for_status")
    @handle_injection_in_test
    def test_mocked_paginate(
        self, mock_get: Mock, mock_hf_raise_for_status: Mock
    ) -> None:
        """Three mocked pages are concatenated; params are only sent with the first."""
        mock_params = Mock()
        mock_headers = Mock()

        # Simulate page 1
        mock_response_page_1 = Mock()
        mock_response_page_1.json.return_value = [1, 2, 3]
        mock_response_page_1.links = {"next": {"url": "url_p2"}}

        # Simulate page 2
        mock_response_page_2 = Mock()
        mock_response_page_2.json.return_value = [4, 5, 6]
        mock_response_page_2.links = {"next": {"url": "url_p3"}}

        # Simulate page 3 (last page: no "next" link)
        mock_response_page_3 = Mock()
        mock_response_page_3.json.return_value = [7, 8]
        mock_response_page_3.links = {}

        # Mock response
        mock_get.side_effect = [
            mock_response_page_1,
            mock_response_page_2,
            mock_response_page_3,
        ]

        results = paginate("url", params=mock_params, headers=mock_headers)

        # Requests are made only when the generator is consumed
        self.assertEqual(mock_get.call_count, 0)

        # Results after concatenating pages
        self.assertListEqual(list(results), [1, 2, 3, 4, 5, 6, 7, 8])

        # All pages requested: 3 requests, 3 raise for status
        self.assertEqual(mock_get.call_count, 3)
        self.assertEqual(mock_hf_raise_for_status.call_count, 3)

        # Params not passed to next pages
        self.assertListEqual(
            mock_get.call_args_list,
            [
                call("url", params=mock_params, headers=mock_headers),
                call("url_p2", headers=mock_headers),
                call("url_p3", headers=mock_headers),
            ],
        )

    def test_paginate_github_api(self) -> None:
        """Paginate over the huggingface repos on GitHub (real network call)."""
        # GitHub sizes pages with `per_page` (a `limit` query param is ignored).
        # With 4 repos per page, reaching index 6 requires following at least
        # one "next" link. Stop there to avoid loading all repos.
        for num, _ in enumerate(
            paginate(
                "https://api.github.com/orgs/huggingface/repos?per_page=4",
                params={},
                headers={},
            )
        ):
            if num == 6:
                break
        else:
            self.fail("Did not get more than 6 repos")

0 commit comments

Comments
 (0)