Skip to content

Commit 411e378

Browse files
authored
Add 'gated' search parameter (#2448)
1 parent c9c39b8 commit 411e378

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

src/huggingface_hub/hf_api.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1595,6 +1595,7 @@ def list_models(
15951595
# Search-query parameter
15961596
filter: Union[str, Iterable[str], None] = None,
15971597
author: Optional[str] = None,
1598+
gated: Optional[bool] = None,
15981599
library: Optional[Union[str, List[str]]] = None,
15991600
language: Optional[Union[str, List[str]]] = None,
16001601
model_name: Optional[str] = None,
@@ -1624,6 +1625,10 @@ def list_models(
16241625
author (`str`, *optional*):
16251626
A string which identify the author (user or organization) of the
16261627
returned models
1628+
gated (`bool`, *optional*):
1629+
A boolean to filter models on the Hub that are gated or not. By default, all models are returned.
1630+
If `gated=True` is passed, only gated models are returned.
1631+
If `gated=False` is passed, only non-gated models are returned.
16271632
library (`str` or `List`, *optional*):
16281633
A string or list of strings of foundational libraries models were
16291634
originally trained from, such as pytorch, tensorflow, or allennlp.
@@ -1749,6 +1754,8 @@ def list_models(
17491754
# Handle other query params
17501755
if author:
17511756
params["author"] = author
1757+
if gated is not None:
1758+
params["gated"] = gated
17521759
if pipeline_tag:
17531760
params["pipeline_tag"] = pipeline_tag
17541761
search_list = []
@@ -1795,6 +1802,7 @@ def list_datasets(
17951802
author: Optional[str] = None,
17961803
benchmark: Optional[Union[str, List[str]]] = None,
17971804
dataset_name: Optional[str] = None,
1805+
gated: Optional[bool] = None,
17981806
language_creators: Optional[Union[str, List[str]]] = None,
17991807
language: Optional[Union[str, List[str]]] = None,
18001808
multilinguality: Optional[Union[str, List[str]]] = None,
@@ -1826,6 +1834,10 @@ def list_datasets(
18261834
dataset_name (`str`, *optional*):
18271835
A string or list of strings that can be used to identify datasets on
18281836
the Hub by its name, such as `SQAC` or `wikineural`
1837+
gated (`bool`, *optional*):
1838+
A boolean to filter datasets on the Hub that are gated or not. By default, all datasets are returned.
1839+
If `gated=True` is passed, only gated datasets are returned.
1840+
If `gated=False` is passed, only non-gated datasets are returned.
18291841
language_creators (`str` or `List`, *optional*):
18301842
A string or list of strings that can be used to identify datasets on
18311843
the Hub with how the data was curated, such as `crowdsourced` or
@@ -1954,6 +1966,8 @@ def list_datasets(
19541966
# Handle other query params
19551967
if author:
19561968
params["author"] = author
1969+
if gated is not None:
1970+
params["gated"] = gated
19571971
search_list = []
19581972
if dataset_name:
19591973
search_list.append(dataset_name)

tests/test_hf_api.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1785,6 +1785,14 @@ def test_list_models_expand_cannot_be_used_with_other_params(self):
17851785
with self.assertRaises(ValueError):
17861786
next(self._api.list_models(expand=["author"], cardData=True))
17871787

1788+
def test_list_models_gated_only(self):
1789+
for model in self._api.list_models(expand=["gated"], gated=True, limit=5):
1790+
assert model.gated in ("auto", "manual")
1791+
1792+
def test_list_models_non_gated_only(self):
1793+
for model in self._api.list_models(expand=["gated"], gated=False, limit=5):
1794+
assert model.gated is False
1795+
17881796
def test_model_info(self):
17891797
model = self._api.model_info(repo_id=DUMMY_MODEL_ID)
17901798
self.assertIsInstance(model, ModelInfo)
@@ -2009,6 +2017,14 @@ def test_list_datasets_expand_cannot_be_used_with_full(self):
20092017
with self.assertRaises(ValueError):
20102018
next(self._api.list_datasets(expand=["author"], full=True))
20112019

2020+
def test_list_datasets_gated_only(self):
2021+
for dataset in self._api.list_datasets(expand=["gated"], gated=True, limit=5):
2022+
assert dataset.gated in ("auto", "manual")
2023+
2024+
def test_list_datasets_non_gated_only(self):
2025+
for dataset in self._api.list_datasets(expand=["gated"], gated=False, limit=5):
2026+
assert dataset.gated is False
2027+
20122028
def test_filter_datasets_with_card_data(self):
20132029
assert any(dataset.card_data is not None for dataset in self._api.list_datasets(full=True, limit=50))
20142030
assert all(dataset.card_data is None for dataset in self._api.list_datasets(full=False, limit=50))

0 commit comments

Comments
 (0)