3232 REPO_TYPES_URL_PREFIXES ,
3333 SPACES_SDK_TYPES ,
3434)
35- from .utils .tags import AttributeDictionary , DatasetTags , ModelTags
35+ from .utils .endpoint_helpers import (
36+ AttributeDictionary ,
37+ DatasetFilter ,
38+ DatasetTags ,
39+ ModelFilter ,
40+ ModelTags ,
41+ )
3642
3743
3844if sys .version_info >= (3 , 8 ):
@@ -271,7 +277,6 @@ class ModelSearchArguments(AttributeDictionary):
271277 A nested namespace object holding all possible values for properties of
272278 models currently hosted in the Hub with tab-completion.
273279 If a value starts with a number, it will only exist in the dictionary
274-
275280 Example:
276281 >>> args = ModelSearchArguments()
277282 >>> args.author.huggingface
@@ -298,15 +303,14 @@ def clean(s: str):
298303 name = model .modelId
299304 model_name_dict [name ] = clean (name )
300305 self ["model_name" ] = model_name_dict
301- self ["author_or_organization " ] = author_dict
306+ self ["author " ] = author_dict
302307
303308
304309class DatasetSearchArguments (AttributeDictionary ):
305310 """
306311 A nested namespace object holding all possible values for properties of
307312 datasets currently hosted in the Hub with tab-completion.
308313 If a value starts with a number, it will only exist in the dictionary
309-
310314 Example:
311315 >>> args = DatasetSearchArguments()
312316 >>> args.author.huggingface
@@ -333,7 +337,7 @@ def clean(s: str):
333337 name = dataset .id
334338 dataset_name_dict [name ] = clean (name )
335339 self ["dataset_name" ] = dataset_name_dict
336- self ["author_or_organization " ] = author_dict
340+ self ["author " ] = author_dict
337341
338342
339343def write_to_credential_store (username : str , password : str ):
@@ -506,7 +510,7 @@ def get_dataset_tags(self) -> DatasetTags:
506510
507511 def list_models (
508512 self ,
509- filter : Union [str , Iterable [str ], None ] = None ,
513+ filter : Union [ModelFilter , str , Iterable [str ], None ] = None ,
510514 author : Optional [str ] = None ,
511515 search : Optional [str ] = None ,
512516 sort : Union [Literal ["lastModified" ], str , None ] = None ,
@@ -519,8 +523,8 @@ def list_models(
519523 Get the public list of all the models on huggingface.co
520524
521525 Args:
522- filter (:obj:`str` or :class:`Iterable`, `optional`):
523- A string which can be used to identify models on the hub by their tags .
526+ filter (:class:`ModelFilter` or :obj:`str` or :class:`Iterable`, `optional`):
527+ A string or `ModelFilter` which can be used to identify models on the hub.
524528 Example usage:
525529
526530 >>> from huggingface_hub import HfApi
@@ -529,17 +533,26 @@ def list_models(
529533 >>> # List all models
530534 >>> api.list_models()
531535
536+ >>> # Get all valid search arguments
537+ >>> args = ModelSearchArguments()
538+
532539 >>> # List only the text classification models
533540 >>> api.list_models(filter="text-classification")
541+ >>> # Using the `ModelFilter`
542+ >>> filt = ModelFilter(task="text-classification")
543+ >>> # With `ModelSearchArguments`
544+ >>> filt = ModelFilter(task=args.pipeline_tags.TextClassification)
545+ >>> api.list_models(filter=filt)
534546
535- >>> # List only the russian models compatible with pytorch
536- >>> api.list_models(filter=("ru", "pytorch"))
547+ >>> # Using `ModelFilter` and `ModelSearchArguments` to find text classification in both PyTorch and TensorFlow
548+ >>> filt = ModelFilter(task=args.pipeline_tags.TextClassification, library=[args.library.PyTorch, args.library.TensorFlow])
549+ >>> api.list_models(filter=filt)
537550
538- >>> # List only the models trained on the "common_voice" dataset
539- >>> api.list_models(filter="dataset:common_voice")
540-
541- >>> # List only the models from the AllenNLP library
551+ >>> # List only models from the AllenNLP library
542552 >>> api.list_models(filter="allennlp")
553+ >>> # Using `ModelFilter` and `ModelSearchArguments`
554+ >>> filt = ModelFilter(library=args.library.allennlp)
555+
543556 author (:obj:`str`, `optional`):
544557 A string which identify the author (user or organization) of the returned models
545558 Example usage:
@@ -552,6 +565,7 @@ def list_models(
552565
553566 >>> # List only the text classification models from google
554567 >>> api.list_models(filter="text-classification", author="google")
568+
555569 search (:obj:`str`, `optional`):
556570 A string that will be contained in the returned models
557571 Example usage:
@@ -564,6 +578,7 @@ def list_models(
564578
565579 >>> #List all models with "bert" in their name made by google
566580 >>> api.list_models(search="bert", author="google")
581+
567582 sort (:obj:`Literal["lastModified"]` or :obj:`str`, `optional`):
568583 The key with which to sort the resulting models. Possible values are the properties of the `ModelInfo`
569584 class.
@@ -582,7 +597,10 @@ def list_models(
582597 path = f"{ self .endpoint } /api/models"
583598 params = {}
584599 if filter is not None :
585- params .update ({"filter" : filter })
600+ if isinstance (filter , ModelFilter ):
601+ params = self ._unpack_model_filter (filter )
602+ else :
603+ params .update ({"filter" : filter })
586604 params .update ({"full" : True })
587605 if author is not None :
588606 params .update ({"author" : author })
@@ -606,9 +624,69 @@ def list_models(
606624 d = r .json ()
607625 return [ModelInfo (** x ) for x in d ]
608626
def _unpack_model_filter(self, model_filter: "ModelFilter"):
    """
    Unpacks a `ModelFilter` into a dictionary of query parameters for `list_models`.

    Args:
        model_filter (:class:`ModelFilter`):
            The filter to convert. Its `author`, `model_name`, `task`,
            `trained_dataset`, `library`, `tags` and `language` attributes
            are read; the object itself is never modified.

    Returns:
        :obj:`dict`: A dictionary holding:

        - `"search"`: `"<author>/<model_name>"` built from whichever of the
          two is set. Omitted when neither is set.
        - `"tags"`: a list of the requested tags. Omitted when there are none.
        - `"filter"`: a tuple (always present, possibly empty) with, in
          order, the tasks, the `dataset:`-prefixed training datasets, the
          libraries and the languages.
    """

    def as_list(value):
        # A lone string is one value, not an iterable of characters.
        return [value] if isinstance(value, str) else list(value)

    # Handling author and model_name: both feed a single search string
    model_str = ""
    if model_filter.author is not None:
        model_str = f"{model_filter.author}/"
    if model_filter.model_name is not None:
        model_str += model_filter.model_name

    filter_values = []

    # Handling tasks
    if model_filter.task is not None:
        filter_values.extend(as_list(model_filter.task))

    # Handling dataset
    trained_dataset = model_filter.trained_dataset
    if trained_dataset is not None:
        # Normalize into a local variable so the caller's filter object is
        # left untouched and can be reused.
        if not isinstance(trained_dataset, (list, tuple)):
            trained_dataset = [trained_dataset]
        for dataset in trained_dataset:
            # Only a leading "dataset:" means the value is already qualified.
            if not dataset.startswith("dataset:"):
                dataset = f"dataset:{dataset}"
            filter_values.append(dataset)

    # Handling library
    if model_filter.library:
        filter_values.extend(as_list(model_filter.library))

    # Handling language: accepts one string or several, like task/library
    if model_filter.language is not None:
        filter_values.extend(as_list(model_filter.language))

    # Handling tags
    tags = as_list(model_filter.tags) if model_filter.tags else []

    query_dict = {}
    # Do not send an empty `search` when no author/model name was requested.
    if model_str:
        query_dict["search"] = model_str
    if tags:
        query_dict["tags"] = tags
    query_dict["filter"] = tuple(filter_values)
    return query_dict
686+
609687 def list_datasets (
610688 self ,
611- filter : Union [str , Iterable [str ], None ] = None ,
689+ filter : Union [DatasetFilter , str , Iterable [str ], None ] = None ,
612690 author : Optional [str ] = None ,
613691 search : Optional [str ] = None ,
614692 sort : Union [Literal ["lastModified" ], str , None ] = None ,
@@ -620,21 +698,36 @@ def list_datasets(
620698 Get the public list of all the datasets on huggingface.co
621699
622700 Args:
623- filter (:obj:`str` or :class:`Iterable`, `optional`):
624- A string which can be used to identify datasets on the hub by their tags .
701+ filter (:class:`DatasetFilter` or :obj:`str` or :class:`Iterable`, `optional`):
702+ A string or `DatasetFilter` which can be used to identify datasets on the hub.
625703 Example usage:
626704
705+
627706 >>> from huggingface_hub import HfApi
628707 >>> api = HfApi()
629708
630709 >>> # List all datasets
631710 >>> api.list_datasets()
632711
712+ >>> # Get all valid search arguments
713+ >>> args = DatasetSearchArguments()
714+
633715 >>> # List only the text classification datasets
634716 >>> api.list_datasets(filter="task_categories:text-classification")
717+ >>> # Using the `DatasetFilter`
718+ >>> filt = DatasetFilter(task_categories="text-classification")
719+ >>> # With `DatasetSearchArguments`
720+ >>> filt = DatasetFilter(task_categories=args.task_categories.text_classification)
721+ >>> api.list_datasets(filter=filt)
635722
636723 >>> # List only the datasets in russian for language modeling
637724 >>> api.list_datasets(filter=("languages:ru", "task_ids:language-modeling"))
725+ >>> # Using the `DatasetFilter`
726+ >>> filt = DatasetFilter(languages="ru", task_ids="language-modeling")
727+ >>> # With `DatasetSearchArguments`
728+ >>> filt = DatasetFilter(languages=args.languages.ru, task_ids=args.task_ids.language_modeling)
729+ >>> api.list_datasets(filter=filt)
730+
638731 author (:obj:`str`, `optional`):
639732 A string which identifies the author of the returned datasets
640733 Example usage:
@@ -647,6 +740,7 @@ def list_datasets(
647740
648741 >>> # List only the text classification datasets from google
649742 >>> api.list_datasets(filter="text-classification", author="google")
743+
650744 search (:obj:`str`, `optional`):
651745 A string that will be contained in the names of the returned datasets
652746 Example usage:
@@ -659,6 +753,7 @@ def list_datasets(
659753
660754 >>> #List all datasets with "text" in their name made by google
661755 >>> api.list_datasets(search="text", author="google")
756+
662757 sort (:obj:`Literal["lastModified"]` or :obj:`str`, `optional`):
663758 The key with which to sort the resulting datasets. Possible values are the properties of the `DatasetInfo`
664759 class.
@@ -674,7 +769,10 @@ def list_datasets(
674769 path = f"{ self .endpoint } /api/datasets"
675770 params = {}
676771 if filter is not None :
677- params .update ({"filter" : filter })
772+ if isinstance (filter , DatasetFilter ):
773+ params = self ._unpack_dataset_filter (filter )
774+ else :
775+ params .update ({"filter" : filter })
678776 if author is not None :
679777 params .update ({"author" : author })
680778 if search is not None :
@@ -693,6 +791,47 @@ def list_datasets(
693791 d = r .json ()
694792 return [DatasetInfo (** x ) for x in d ]
695793
def _unpack_dataset_filter(self, dataset_filter: "DatasetFilter"):
    """
    Unpacks a `DatasetFilter` into a dictionary of query parameters for `list_datasets`.

    Args:
        dataset_filter (:class:`DatasetFilter`):
            The filter to convert. Its `author`, `dataset_name`, `benchmark`,
            `language_creators`, `languages`, `multilinguality`,
            `size_categories`, `task_categories` and `task_ids` attributes
            are read; the object itself is never modified.

    Returns:
        :obj:`dict`: A dictionary holding:

        - `"search"`: `"<author>/<dataset_name>"` built from whichever of
          the two is set. Omitted when neither is set.
        - `"filter"`: a tuple (always present, possibly empty) of
          `"<attribute>:<value>"` strings, grouped by attribute in the order
          listed above.
    """
    # Handling author and dataset_name: both feed a single search string
    dataset_str = ""
    if dataset_filter.author is not None:
        dataset_str = f"{dataset_filter.author}/"
    if dataset_filter.dataset_name is not None:
        dataset_str += dataset_filter.dataset_name

    # Each of these attributes becomes "<attribute>:<value>" filter entries.
    data_attributes = (
        "benchmark",
        "language_creators",
        "languages",
        "multilinguality",
        "size_categories",
        "task_categories",
        "task_ids",
    )

    filter_values = []
    for attr in data_attributes:
        values = getattr(dataset_filter, attr)
        if values is None:
            continue
        if not isinstance(values, (list, tuple)):
            values = [values]
        prefix = f"{attr}:"
        for value in values:
            # Only a leading "<attribute>:" means the value is already
            # qualified; the same text elsewhere in the value does not count.
            if not value.startswith(prefix):
                value = f"{prefix}{value}"
            filter_values.append(value)

    query_dict = {}
    # Do not send an empty `search` when no author/dataset name was requested.
    if dataset_str:
        query_dict["search"] = dataset_str
    query_dict["filter"] = tuple(filter_values)
    return query_dict
834+
696835 def list_metrics (self ) -> List [MetricInfo ]:
697836 """
698837 Get the public list of all the metrics on huggingface.co
0 commit comments