
Commit 5a48cf2

Implement a Model Filter class (#553)
* Update error on tests
* Implement model search by Filter
* Fix init import
* Clean
* Rename and fix imports
* api -> endpoint
* Fix test
* New version
* rm carbon emissions for now
* Fixup docs
* clean
* author_or_organization -> author
* Keep consistent
* Clean up logic for model filter
* Finish implementation, need to write tests
* Proper doc
* Update init
* Add all tests
* partially there
* Better docstring
* Fix test
* Add tests with filter
* typo fix
* Flip query
* Clean
* Include doc example
* framework -> library
* Clean up filter if/else
* hub -> Hub
* List of strings doc
* Fixup tests
* Passing
1 parent bf00b94 commit 5a48cf2


6 files changed: +619 -164 lines changed

src/huggingface_hub/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -32,8 +32,10 @@
 )
 from .file_download import cached_download, hf_hub_download, hf_hub_url
 from .hf_api import (
+    DatasetSearchArguments,
     HfApi,
     HfFolder,
+    ModelSearchArguments,
     create_repo,
     dataset_info,
     delete_file,
@@ -65,3 +67,4 @@
 from .repository import Repository
 from .snapshot_download import snapshot_download
 from .utils import logging
+from .utils.endpoint_helpers import DatasetFilter, ModelFilter
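The newly exported names can be used together as in the following sketch. It is a minimal example based on the docstring examples in the hf_api.py diff below; the string values are illustrative and the results depend on what is currently hosted on the Hub.

from huggingface_hub import DatasetFilter, HfApi, ModelFilter

api = HfApi()

# Text-classification models trained on the "common_voice" dataset
models = api.list_models(
    filter=ModelFilter(task="text-classification", trained_dataset="common_voice")
)

# Russian language-modeling datasets
datasets = api.list_datasets(
    filter=DatasetFilter(languages="ru", task_ids="language-modeling")
)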

src/huggingface_hub/hf_api.py

Lines changed: 158 additions & 19 deletions
@@ -32,7 +32,13 @@
     REPO_TYPES_URL_PREFIXES,
     SPACES_SDK_TYPES,
 )
-from .utils.tags import AttributeDictionary, DatasetTags, ModelTags
+from .utils.endpoint_helpers import (
+    AttributeDictionary,
+    DatasetFilter,
+    DatasetTags,
+    ModelFilter,
+    ModelTags,
+)


 if sys.version_info >= (3, 8):
@@ -271,7 +277,6 @@ class ModelSearchArguments(AttributeDictionary):
     A nested namespace object holding all possible values for properties of
     models currently hosted in the Hub with tab-completion.
     If a value starts with a number, it will only exist in the dictionary
-
     Example:
         >>> args = ModelSearchArgs()
         >>> args.author_or_organization.huggingface
@@ -298,15 +303,14 @@ def clean(s: str):
                 name = model.modelId
             model_name_dict[name] = clean(name)
         self["model_name"] = model_name_dict
-        self["author_or_organization"] = author_dict
+        self["author"] = author_dict


 class DatasetSearchArguments(AttributeDictionary):
     """
     A nested namespace object holding all possible values for properties of
     datasets currently hosted in the Hub with tab-completion.
     If a value starts with a number, it will only exist in the dictionary
-
     Example:
         >>> args = DatasetSearchArguments()
         >>> args.author_or_organization.huggingface
@@ -333,7 +337,7 @@ def clean(s: str):
                 name = dataset.id
             dataset_name_dict[name] = clean(name)
         self["dataset_name"] = dataset_name_dict
-        self["author_or_organization"] = author_dict
+        self["author"] = author_dict


 def write_to_credential_store(username: str, password: str):
@@ -506,7 +510,7 @@ def get_dataset_tags(self) -> DatasetTags:

     def list_models(
         self,
-        filter: Union[str, Iterable[str], None] = None,
+        filter: Union[ModelFilter, str, Iterable[str], None] = None,
         author: Optional[str] = None,
         search: Optional[str] = None,
         sort: Union[Literal["lastModified"], str, None] = None,
@@ -519,8 +523,8 @@ def list_models(
         Get the public list of all the models on huggingface.co

         Args:
-            filter (:obj:`str` or :class:`Iterable`, `optional`):
-                A string which can be used to identify models on the hub by their tags.
+            filter (:class:`ModelFilter` or :obj:`str` or :class:`Iterable`, `optional`):
+                A string or `ModelFilter` which can be used to identify models on the hub.
                 Example usage:

                     >>> from huggingface_hub import HfApi
@@ -529,17 +533,26 @@ def list_models(
                     >>> # List all models
                     >>> api.list_models()

+                    >>> # Get all valid search arguments
+                    >>> args = ModelSearchArguments()
+
                     >>> # List only the text classification models
                     >>> api.list_models(filter="text-classification")
+                    >>> # Using the `ModelFilter`
+                    >>> filt = ModelFilter(task="text-classification")
+                    >>> # With `ModelSearchArguments`
+                    >>> filt = ModelFilter(task=args.pipeline_tags.TextClassification)
+                    >>> api.list_models(filter=filt)

-                    >>> # List only the russian models compatible with pytorch
-                    >>> api.list_models(filter=("ru", "pytorch"))
+                    >>> # Using `ModelFilter` and `ModelSearchArguments` to find text classification in both PyTorch and TensorFlow
+                    >>> filt = ModelFilter(task=args.pipeline_tags.TextClassification, library=[args.library.PyTorch, args.library.TensorFlow])
+                    >>> api.list_models(filter=filt)

-                    >>> # List only the models trained on the "common_voice" dataset
-                    >>> api.list_models(filter="dataset:common_voice")
-
-                    >>> # List only the models from the AllenNLP library
+                    >>> # List only models from the AllenNLP library
                     >>> api.list_models(filter="allennlp")
+                    >>> # Using `ModelFilter` and `ModelSearchArguments`
+                    >>> filt = ModelFilter(library=args.library.allennlp)
+
             author (:obj:`str`, `optional`):
                 A string which identify the author (user or organization) of the returned models
                 Example usage:
@@ -552,6 +565,7 @@ def list_models(

                     >>> # List only the text classification models from google
                     >>> api.list_models(filter="text-classification", author="google")
+
             search (:obj:`str`, `optional`):
                 A string that will be contained in the returned models
                 Example usage:
@@ -564,6 +578,7 @@ def list_models(

                     >>> #List all models with "bert" in their name made by google
                     >>> api.list_models(search="bert", author="google")
+
             sort (:obj:`Literal["lastModified"]` or :obj:`str`, `optional`):
                 The key with which to sort the resulting models. Possible values are the properties of the `ModelInfo`
                 class.
@@ -582,7 +597,10 @@ def list_models(
         path = f"{self.endpoint}/api/models"
         params = {}
         if filter is not None:
-            params.update({"filter": filter})
+            if isinstance(filter, ModelFilter):
+                params = self._unpack_model_filter(filter)
+            else:
+                params.update({"filter": filter})
             params.update({"full": True})
         if author is not None:
             params.update({"author": author})
@@ -606,9 +624,69 @@
         d = r.json()
         return [ModelInfo(**x) for x in d]

+    def _unpack_model_filter(self, model_filter: ModelFilter):
+        """
+        Unpacks a `ModelFilter` into something readable for `list_models`
+        """
+        model_str = ""
+        tags = []
+
+        # Handling author
+        if model_filter.author is not None:
+            model_str = f"{model_filter.author}/"
+
+        # Handling model_name
+        if model_filter.model_name is not None:
+            model_str += model_filter.model_name
+
+        filter_tuple = []
+
+        # Handling tasks
+        if model_filter.task is not None:
+            filter_tuple.extend(
+                [model_filter.task]
+                if isinstance(model_filter.task, str)
+                else model_filter.task
+            )
+
+        # Handling dataset
+        if model_filter.trained_dataset is not None:
+            if not isinstance(model_filter.trained_dataset, (list, tuple)):
+                model_filter.trained_dataset = [model_filter.trained_dataset]
+            for dataset in model_filter.trained_dataset:
+                if "dataset:" not in dataset:
+                    dataset = f"dataset:{dataset}"
+                filter_tuple.append(dataset)
+
+        # Handling library
+        if model_filter.library:
+            filter_tuple.extend(
+                [model_filter.library]
+                if isinstance(model_filter.library, str)
+                else model_filter.library
+            )
+
+        # Handling tags
+        if model_filter.tags:
+            tags.extend(
+                [model_filter.tags]
+                if isinstance(model_filter.tags, str)
+                else model_filter.tags
+            )
+
+        query_dict = {}
+        if model_str is not None:
+            query_dict["search"] = model_str
+        if len(tags) > 0:
+            query_dict["tags"] = tags
+        if model_filter.language is not None:
+            filter_tuple.append(model_filter.language)
+        query_dict["filter"] = tuple(filter_tuple)
+        return query_dict
+
     def list_datasets(
         self,
-        filter: Union[str, Iterable[str], None] = None,
+        filter: Union[DatasetFilter, str, Iterable[str], None] = None,
         author: Optional[str] = None,
         search: Optional[str] = None,
         sort: Union[Literal["lastModified"], str, None] = None,
@@ -620,21 +698,36 @@ def list_datasets(
         Get the public list of all the datasets on huggingface.co

         Args:
-            filter (:obj:`str` or :class:`Iterable`, `optional`):
-                A string which can be used to identify datasets on the hub by their tags.
+            filter (:class:`DatasetFilter` or :obj:`str` or :class:`Iterable`, `optional`):
+                A string or `DatasetFilter` which can be used to identify datasets on the hub.
                 Example usage:

+
                     >>> from huggingface_hub import HfApi
                     >>> api = HfApi()

                     >>> # List all datasets
                     >>> api.list_datasets()

+                    >>> # Get all valid search arguments
+                    >>> args = DatasetSearchArguments()
+
                     >>> # List only the text classification datasets
                     >>> api.list_datasets(filter="task_categories:text-classification")
+                    >>> # Using the `DatasetFilter`
+                    >>> filt = DatasetFilter(task_categories="text-classification")
+                    >>> # With `DatasetSearchArguments`
+                    >>> filt = DatasetFilter(task=args.task_categories.text_classification)
+                    >>> api.list_models(filter=filt)

                     >>> # List only the datasets in russian for language modeling
                     >>> api.list_datasets(filter=("languages:ru", "task_ids:language-modeling"))
+                    >>> # Using the `DatasetFilter`
+                    >>> filt = DatasetFilter(languages="ru", task_ids="language-modeling")
+                    >>> # With `DatasetSearchArguments`
+                    >>> filt = DatasetFilter(languages=args.languages.ru, task_ids=args.task_ids.language_modeling)
+                    >>> api.list_datasets(filter=filt)
+
             author (:obj:`str`, `optional`):
                 A string which identify the author of the returned models
                 Example usage:
@@ -647,6 +740,7 @@ def list_datasets(

                     >>> # List only the text classification datasets from google
                     >>> api.list_datasets(filter="text-classification", author="google")
+
             search (:obj:`str`, `optional`):
                 A string that will be contained in the returned models
                 Example usage:
@@ -659,6 +753,7 @@ def list_datasets(

                     >>> #List all datasets with "text" in their name made by google
                     >>> api.list_datasets(search="text", author="google")
+
             sort (:obj:`Literal["lastModified"]` or :obj:`str`, `optional`):
                 The key with which to sort the resulting datasets. Possible values are the properties of the `DatasetInfo`
                 class.
@@ -674,7 +769,10 @@ def list_datasets(
         path = f"{self.endpoint}/api/datasets"
         params = {}
         if filter is not None:
-            params.update({"filter": filter})
+            if isinstance(filter, DatasetFilter):
+                params = self._unpack_dataset_filter(filter)
+            else:
+                params.update({"filter": filter})
         if author is not None:
             params.update({"author": author})
         if search is not None:
@@ -693,6 +791,47 @@ def list_datasets(
         d = r.json()
         return [DatasetInfo(**x) for x in d]

+    def _unpack_dataset_filter(self, dataset_filter: DatasetFilter):
+        """
+        Unpacks a `DatasetFilter` into something readable for `list_datasets`
+        """
+        dataset_str = ""
+
+        # Handling author
+        if dataset_filter.author is not None:
+            dataset_str = f"{dataset_filter.author}/"
+
+        # Handling dataset_name
+        if dataset_filter.dataset_name is not None:
+            dataset_str += dataset_filter.dataset_name
+
+        filter_tuple = []
+        data_attributes = [
+            "benchmark",
+            "language_creators",
+            "languages",
+            "multilinguality",
+            "size_categories",
+            "task_categories",
+            "task_ids",
+        ]
+
+        for attr in data_attributes:
+            curr_attr = getattr(dataset_filter, attr)
+            if curr_attr is not None:
+                if not isinstance(curr_attr, (list, tuple)):
+                    curr_attr = [curr_attr]
+                for data in curr_attr:
+                    if f"{attr}:" not in data:
+                        data = f"{attr}:{data}"
+                    filter_tuple.append(data)
+
+        query_dict = {}
+        if dataset_str is not None:
+            query_dict["search"] = dataset_str
+        query_dict["filter"] = tuple(filter_tuple)
+        return query_dict
+
     def list_metrics(self) -> List[MetricInfo]:
         """
         Get the public list of all the metrics on huggingface.co
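To make the query construction concrete, the sketch below follows the branches of `_unpack_model_filter` for one sample filter. The keyword names are the `ModelFilter` fields the helper reads above; the commented dictionary is what those branches should produce and is meant as an illustration, not a guaranteed wire format.

from huggingface_hub import HfApi, ModelFilter

api = HfApi()
filt = ModelFilter(
    author="huggingface",            # becomes the "search" prefix "huggingface/"
    task="text-classification",      # appended to the "filter" tuple as-is
    trained_dataset="common_voice",  # prefixed with "dataset:" when missing
    library="pytorch",               # appended to the "filter" tuple as-is
)

# Following the logic above, the expanded query parameters are roughly:
#   {"search": "huggingface/",
#    "filter": ("text-classification", "dataset:common_voice", "pytorch")}
params = api._unpack_model_filter(filt)

# list_models performs the same unpacking internally when given a ModelFilter
models = api.list_models(filter=filt)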
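Likewise, `_unpack_dataset_filter` prefixes each value with the attribute it came from, so a `DatasetFilter` maps onto the same "key:value" strings that the plain-string API already accepts. A small sketch, assuming `DatasetFilter` exposes the attributes listed in `data_attributes` above:

from huggingface_hub import DatasetFilter, HfApi

api = HfApi()
filt = DatasetFilter(languages="ru", task_ids="language-modeling")

# Single values are wrapped in a list and prefixed with their attribute name,
# giving roughly:
#   {"search": "", "filter": ("languages:ru", "task_ids:language-modeling")}
params = api._unpack_dataset_filter(filt)

# Equivalent to the plain-string form in the docstring above:
#   api.list_datasets(filter=("languages:ru", "task_ids:language-modeling"))
datasets = api.list_datasets(filter=filt)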

0 commit comments
