Skip to content

Commit 67dbcd6

Browse files
Deprecate library/tags/task/... filtering in list_models (#3318)
* Deprecate library/tags/task/... filtering in list_models * adapt docs * expect deprecation in tests --------- Co-authored-by: Celina Hanouti <[email protected]>
1 parent a46daf5 commit 67dbcd6

File tree

3 files changed

+25
-21
lines changed

3 files changed

+25
-21
lines changed

docs/source/en/guides/search.md

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,7 @@ The list helpers have several attributes like:
3333
Let's see an example to get all models on the Hub that does image classification, have been trained on the imagenet dataset and that runs with PyTorch.
3434

3535
```py
36-
models = hf_api.list_models(
37-
task="image-classification",
38-
library="pytorch",
39-
trained_dataset="imagenet",
40-
)
36+
models = hf_api.list_models(filter=["image-classification", "pytorch", "imagenet"])
4137
```
4238

4339
While filtering, you can also sort the models and take only the top results. For example,

src/huggingface_hub/hf_api.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@
137137
_get_token_from_file,
138138
_get_token_from_google_colab,
139139
)
140-
from .utils._deprecation import _deprecate_method
140+
from .utils._deprecation import _deprecate_arguments, _deprecate_method
141141
from .utils._runtime import is_xet_available
142142
from .utils._typing import CallableT
143143
from .utils.endpoint_helpers import _is_emission_within_threshold
@@ -1855,6 +1855,9 @@ def get_dataset_tags(self) -> Dict:
18551855
hf_raise_for_status(r)
18561856
return r.json()
18571857

1858+
@_deprecate_arguments(
1859+
version="1.0", deprecated_args=["language", "library", "task", "tags"], custom_message="Use `filter` instead."
1860+
)
18581861
@validate_hf_hub_args
18591862
def list_models(
18601863
self,
@@ -1865,12 +1868,8 @@ def list_models(
18651868
gated: Optional[bool] = None,
18661869
inference: Optional[Literal["warm"]] = None,
18671870
inference_provider: Optional[Union[Literal["all"], "PROVIDER_T", List["PROVIDER_T"]]] = None,
1868-
library: Optional[Union[str, List[str]]] = None,
1869-
language: Optional[Union[str, List[str]]] = None,
18701871
model_name: Optional[str] = None,
1871-
task: Optional[Union[str, List[str]]] = None,
18721872
trained_dataset: Optional[Union[str, List[str]]] = None,
1873-
tags: Optional[Union[str, List[str]]] = None,
18741873
search: Optional[str] = None,
18751874
pipeline_tag: Optional[str] = None,
18761875
emissions_thresholds: Optional[Tuple[float, float]] = None,
@@ -1884,13 +1883,19 @@ def list_models(
18841883
cardData: bool = False,
18851884
fetch_config: bool = False,
18861885
token: Union[bool, str, None] = None,
1886+
# Deprecated arguments - use `filter` instead
1887+
language: Optional[Union[str, List[str]]] = None,
1888+
library: Optional[Union[str, List[str]]] = None,
1889+
tags: Optional[Union[str, List[str]]] = None,
1890+
task: Optional[Union[str, List[str]]] = None,
18871891
) -> Iterable[ModelInfo]:
18881892
"""
18891893
List models hosted on the Huggingface Hub, given some filters.
18901894
18911895
Args:
18921896
filter (`str` or `Iterable[str]`, *optional*):
18931897
A string or list of string to filter models on the Hub.
1898+
Models can be filtered by library, language, task, tags, and more.
18941899
author (`str`, *optional*):
18951900
A string which identify the author (user or organization) of the
18961901
returned models.
@@ -1904,23 +1909,19 @@ def list_models(
19041909
A string to filter models on the Hub that are served by a specific provider.
19051910
Pass `"all"` to get all models served by at least one provider.
19061911
library (`str` or `List`, *optional*):
1907-
A string or list of strings of foundational libraries models were
1908-
originally trained from, such as pytorch, tensorflow, or allennlp.
1912+
Deprecated. Pass a library name in `filter` to filter models by library.
19091913
language (`str` or `List`, *optional*):
1910-
A string or list of strings of languages, both by name and country
1911-
code, such as "en" or "English"
1914+
Deprecated. Pass a language in `filter` to filter models by language.
19121915
model_name (`str`, *optional*):
19131916
A string that contain complete or partial names for models on the
19141917
Hub, such as "bert" or "bert-base-cased"
19151918
task (`str` or `List`, *optional*):
1916-
A string or list of strings of tasks models were designed for, such
1917-
as: "fill-mask" or "automatic-speech-recognition"
1919+
Deprecated. Pass a task in `filter` to filter models by task.
19181920
trained_dataset (`str` or `List`, *optional*):
19191921
A string tag or a list of string tags of the trained dataset for a
19201922
model on the Hub.
19211923
tags (`str` or `List`, *optional*):
1922-
A string tag or a list of tags to filter models on the Hub by, such
1923-
as `text-generation` or `spacy`.
1924+
Deprecated. Pass tags in `filter` to filter models by tags.
19241925
search (`str`, *optional*):
19251926
A string that will be contained in the returned model ids.
19261927
pipeline_tag (`str`, *optional*):
@@ -1991,7 +1992,7 @@ def list_models(
19911992
if expand and (full or cardData or fetch_config):
19921993
raise ValueError("`expand` cannot be used if `full`, `cardData` or `fetch_config` are passed.")
19931994

1994-
if emissions_thresholds is not None and cardData is None:
1995+
if emissions_thresholds is not None and not cardData:
19951996
raise ValueError("`emissions_thresholds` were passed without setting `cardData=True`.")
19961997

19971998
path = f"{self.endpoint}/api/models"
@@ -2074,6 +2075,7 @@ def list_models(
20742075
if emissions_thresholds is None or _is_emission_within_threshold(model_info, *emissions_thresholds):
20752076
yield model_info
20762077

2078+
@_deprecate_arguments(version="1.0", deprecated_args=["tags"], custom_message="Use `filter` instead.")
20772079
@validate_hf_hub_args
20782080
def list_datasets(
20792081
self,
@@ -2088,7 +2090,6 @@ def list_datasets(
20882090
language: Optional[Union[str, List[str]]] = None,
20892091
multilinguality: Optional[Union[str, List[str]]] = None,
20902092
size_categories: Optional[Union[str, List[str]]] = None,
2091-
tags: Optional[Union[str, List[str]]] = None,
20922093
task_categories: Optional[Union[str, List[str]]] = None,
20932094
task_ids: Optional[Union[str, List[str]]] = None,
20942095
search: Optional[str] = None,
@@ -2100,6 +2101,8 @@ def list_datasets(
21002101
expand: Optional[List[ExpandDatasetProperty_T]] = None,
21012102
full: Optional[bool] = None,
21022103
token: Union[bool, str, None] = None,
2104+
# Deprecated arguments - use `filter` instead
2105+
tags: Optional[Union[str, List[str]]] = None,
21032106
) -> Iterable[DatasetInfo]:
21042107
"""
21052108
List datasets hosted on the Huggingface Hub, given some filters.
@@ -2134,7 +2137,7 @@ def list_datasets(
21342137
the Hub by the size of the dataset such as `100K<n<1M` or
21352138
`1M<n<10M`.
21362139
tags (`str` or `List`, *optional*):
2137-
A string tag or a list of tags to filter datasets on the Hub.
2140+
Deprecated. Pass tags in `filter` to filter datasets by tags.
21382141
task_categories (`str` or `List`, *optional*):
21392142
A string or list of strings that can be used to identify datasets on
21402143
the Hub by the designed task, such as `audio_classification` or

tests/test_hf_api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2168,6 +2168,7 @@ def test_filter_datasets_with_card_data(self):
21682168
assert any(dataset.card_data is not None for dataset in self._api.list_datasets(full=True, limit=50))
21692169
assert all(dataset.card_data is None for dataset in self._api.list_datasets(full=False, limit=50))
21702170

2171+
@expect_deprecation("list_datasets")
21712172
def test_filter_datasets_by_tag(self):
21722173
for dataset in self._api.list_datasets(tags="fiftyone", limit=5):
21732174
assert "fiftyone" in dataset.tags
@@ -2278,13 +2279,15 @@ def test_failing_filter_models_by_author_and_model_name(self):
22782279
models = list(self._api.list_models(author="muellerzr", model_name="testme"))
22792280
assert len(models) == 0
22802281

2282+
@expect_deprecation("list_models")
22812283
def test_filter_models_with_library(self):
22822284
models = list(self._api.list_models(author="microsoft", model_name="wavlm-base-sd", library="tensorflow"))
22832285
assert len(models) == 0
22842286

22852287
models = list(self._api.list_models(author="microsoft", model_name="wavlm-base-sd", library="pytorch"))
22862288
assert len(models) > 0
22872289

2290+
@expect_deprecation("list_models")
22882291
def test_filter_models_with_task(self):
22892292
models = list(self._api.list_models(task="fill-mask", model_name="albert-base-v2"))
22902293
assert models[0].pipeline_tag == "fill-mask"
@@ -2295,11 +2298,13 @@ def test_filter_models_with_task(self):
22952298
models = list(self._api.list_models(task="dummytask"))
22962299
assert len(models) == 0
22972300

2301+
@expect_deprecation("list_models")
22982302
def test_filter_models_by_language(self):
22992303
for language in ["en", "fr", "zh"]:
23002304
for model in self._api.list_models(language=language, limit=5):
23012305
assert language in model.tags
23022306

2307+
@expect_deprecation("list_models")
23032308
def test_filter_models_with_tag(self):
23042309
models = list(self._api.list_models(author="HuggingFaceBR4", tags=["tensorboard"]))
23052310
assert models[0].id.startswith("HuggingFaceBR4/")

0 commit comments

Comments
 (0)