|
17 | 17 | import re |
18 | 18 | import warnings |
19 | 19 | from dataclasses import dataclass, field |
| 20 | +from itertools import islice |
20 | 21 | from pathlib import Path |
21 | 22 | from typing import Any, BinaryIO, Dict, Iterable, Iterator, List, Optional, Tuple, Union |
22 | 23 | from urllib.parse import quote |
|
71 | 72 | _deprecate_method, |
72 | 73 | _deprecate_positional_args, |
73 | 74 | ) |
| 75 | +from .utils._pagination import paginate |
74 | 76 | from .utils._typing import Literal, TypedDict |
75 | 77 | from .utils.endpoint_helpers import ( |
76 | 78 | AttributeDictionary, |
@@ -808,15 +810,11 @@ def list_models( |
808 | 810 | params.update({"config": True}) |
809 | 811 | if cardData: |
810 | 812 | params.update({"cardData": True}) |
811 | | - r = requests.get(path, params=params, headers=headers) |
812 | | - hf_raise_for_status(r) |
813 | | - items = [ModelInfo(**x) for x in r.json()] |
814 | 813 |
|
815 | | - # If pagination has been enabled server-side, older versions of `huggingface_hub` |
816 | | - # are deprecated as output is truncated. |
817 | | - _warn_if_truncated( |
818 | | - items, total_count=r.headers.get("X-Total-Count"), limit=limit |
819 | | - ) |
| 814 | + data = paginate(path, params=params, headers=headers) |
| 815 | + if limit is not None: |
| 816 | + data = islice(data, limit) # Do not iterate over all pages |
| 817 | + items = [ModelInfo(**x) for x in data] |
820 | 818 |
|
821 | 819 | if emissions_thresholds is not None: |
822 | 820 | if cardData is None: |
@@ -1015,17 +1013,11 @@ def list_datasets( |
1015 | 1013 | params.update({"limit": limit}) |
1016 | 1014 | if full or cardData: |
1017 | 1015 | params.update({"full": True}) |
1018 | | - r = requests.get(path, params=params, headers=headers) |
1019 | | - hf_raise_for_status(r) |
1020 | | - items = [DatasetInfo(**x) for x in r.json()] |
1021 | 1016 |
|
1022 | | - # If pagination has been enabled server-side, older versions of `huggingface_hub` |
1023 | | - # are deprecated as output is truncated. |
1024 | | - _warn_if_truncated( |
1025 | | - items, total_count=r.headers.get("X-Total-Count"), limit=limit |
1026 | | - ) |
1027 | | - |
1028 | | - return items |
| 1017 | + data = paginate(path, params=params, headers=headers) |
| 1018 | + if limit is not None: |
| 1019 | + data = islice(data, limit) # Do not iterate over all pages |
| 1020 | + return [DatasetInfo(**x) for x in data] |
1029 | 1021 |
|
1030 | 1022 | def _unpack_dataset_filter(self, dataset_filter: DatasetFilter): |
1031 | 1023 | """ |
@@ -1162,17 +1154,11 @@ def list_spaces( |
1162 | 1154 | params.update({"datasets": datasets}) |
1163 | 1155 | if models is not None: |
1164 | 1156 | params.update({"models": models}) |
1165 | | - r = requests.get(path, params=params, headers=headers) |
1166 | | - hf_raise_for_status(r) |
1167 | | - items = [SpaceInfo(**x) for x in r.json()] |
1168 | | - |
1169 | | - # If pagination has been enabled server-side, older versions of `huggingface_hub` |
1170 | | - # are deprecated as output is truncated. |
1171 | | - _warn_if_truncated( |
1172 | | - items, total_count=r.headers.get("X-Total-Count"), limit=limit |
1173 | | - ) |
1174 | 1157 |
|
1175 | | - return items |
| 1158 | + data = paginate(path, params=params, headers=headers) |
| 1159 | + if limit is not None: |
| 1160 | + data = islice(data, limit) # Do not iterate over all pages |
| 1161 | + return [SpaceInfo(**x) for x in data] |
1176 | 1162 |
|
1177 | 1163 | @validate_hf_hub_args |
1178 | 1164 | def model_info( |
@@ -3474,38 +3460,6 @@ def _parse_revision_from_pr_url(pr_url: str) -> str: |
3474 | 3460 | return f"refs/pr/{re_match[1]}" |
3475 | 3461 |
|
3476 | 3462 |
|
3477 | | -def _warn_if_truncated( |
3478 | | - items: List[Any], limit: Optional[int], total_count: Optional[str] |
3479 | | -) -> None: |
3480 | | - # TODO: remove this once pagination is properly implemented in `huggingface_hub`. |
3481 | | - if total_count is None: |
3482 | | - # Total count header not implemented |
3483 | | - return |
3484 | | - |
3485 | | - try: |
3486 | | - total_count_int = int(total_count) |
3487 | | - except ValueError: |
3488 | | - # Total count header not implemented properly server-side |
3489 | | - return |
3490 | | - |
3491 | | - if len(items) == total_count_int: |
3492 | | - # All items have been returned => not truncated |
3493 | | - return |
3494 | | - |
3495 | | - if limit is not None and len(items) == limit: |
3496 | | - # `limit` is set => truncation is expected |
3497 | | - return |
3498 | | - |
3499 | | - # Otherwise, pagination has been enabled server-side and the output has been |
3500 | | - # truncated by server => warn user. |
3501 | | - warnings.warn( |
3502 | | - "The list of repos returned by the server has been truncated. Listing repos" |
3503 | | - " from the Hub using `list_models`, `list_datasets` and `list_spaces` now" |
3504 | | - " requires pagination. To get the full list of repos, please consider upgrading" |
3505 | | - " `huggingface_hub` to its latest version." |
3506 | | - ) |
3507 | | - |
3508 | | - |
3509 | 3463 | api = HfApi() |
3510 | 3464 |
|
3511 | 3465 | set_access_token = api.set_access_token |
|
0 commit comments