Skip to content

Commit bf00b94

Browse files
Francesco Saverio Zuppichini, julien-c, muellerzr
authored
Search by authors and string (#531)
* token in env variables * done * done * Update src/huggingface_hub/hf_api.py Co-authored-by: Julien Chaumond <[email protected]> * search by and added in * fix in test * Update src/huggingface_hub/hf_api.py Co-authored-by: Zachary Mueller <[email protected]> * Update tests/test_hf_api.py Co-authored-by: Zachary Mueller <[email protected]> * Update tests/test_hf_api.py Co-authored-by: Zachary Mueller <[email protected]> * Update src/huggingface_hub/hf_api.py Co-authored-by: Zachary Mueller <[email protected]> * quality + test Co-authored-by: Julien Chaumond <[email protected]> Co-authored-by: Zachary Mueller <[email protected]>
1 parent bb6fec0 commit bf00b94

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

src/huggingface_hub/hf_api.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,8 @@ def get_dataset_tags(self) -> DatasetTags:
507507
def list_models(
508508
self,
509509
filter: Union[str, Iterable[str], None] = None,
510+
author: Optional[str] = None,
511+
search: Optional[str] = None,
510512
sort: Union[Literal["lastModified"], str, None] = None,
511513
direction: Optional[Literal[-1]] = None,
512514
limit: Optional[int] = None,
@@ -538,6 +540,30 @@ def list_models(
538540
539541
>>> # List only the models from the AllenNLP library
540542
>>> api.list_models(filter="allennlp")
543+
author (:obj:`str`, `optional`):
544+
A string which identifies the author (user or organization) of the returned models
545+
Example usage:
546+
547+
>>> from huggingface_hub import HfApi
548+
>>> api = HfApi()
549+
550+
>>> # List all models from google
551+
>>> api.list_models(author="google")
552+
553+
>>> # List only the text classification models from google
554+
>>> api.list_models(filter="text-classification", author="google")
555+
search (:obj:`str`, `optional`):
556+
A string that will be contained in the names of the returned models
557+
Example usage:
558+
559+
>>> from huggingface_hub import HfApi
560+
>>> api = HfApi()
561+
562+
>>> # List all models with "bert" in their name
563+
>>> api.list_models(search="bert")
564+
565+
>>> # List all models with "bert" in their name made by google
566+
>>> api.list_models(search="bert", author="google")
541567
sort (:obj:`Literal["lastModified"]` or :obj:`str`, `optional`):
542568
The key with which to sort the resulting models. Possible values are the properties of the `ModelInfo`
543569
class.
@@ -558,6 +584,10 @@ def list_models(
558584
if filter is not None:
559585
params.update({"filter": filter})
560586
params.update({"full": True})
587+
if author is not None:
588+
params.update({"author": author})
589+
if search is not None:
590+
params.update({"search": search})
561591
if sort is not None:
562592
params.update({"sort": sort})
563593
if direction is not None:
@@ -579,6 +609,8 @@ def list_models(
579609
def list_datasets(
580610
self,
581611
filter: Union[str, Iterable[str], None] = None,
612+
author: Optional[str] = None,
613+
search: Optional[str] = None,
582614
sort: Union[Literal["lastModified"], str, None] = None,
583615
direction: Optional[Literal[-1]] = None,
584616
limit: Optional[int] = None,
@@ -603,6 +635,30 @@ def list_datasets(
603635
604636
>>> # List only the datasets in russian for language modeling
605637
>>> api.list_datasets(filter=("languages:ru", "task_ids:language-modeling"))
638+
author (:obj:`str`, `optional`):
639+
A string which identifies the author (user or organization) of the returned datasets
640+
Example usage:
641+
642+
>>> from huggingface_hub import HfApi
643+
>>> api = HfApi()
644+
645+
>>> # List all datasets from google
646+
>>> api.list_datasets(author="google")
647+
648+
>>> # List only the text classification datasets from google
649+
>>> api.list_datasets(filter="text-classification", author="google")
650+
search (:obj:`str`, `optional`):
651+
A string that will be contained in the names of the returned datasets
652+
Example usage:
653+
654+
>>> from huggingface_hub import HfApi
655+
>>> api = HfApi()
656+
657+
>>> # List all datasets with "text" in their name
658+
>>> api.list_datasets(search="text")
659+
660+
>>> # List all datasets with "text" in their name made by google
661+
>>> api.list_datasets(search="text", author="google")
606662
sort (:obj:`Literal["lastModified"]` or :obj:`str`, `optional`):
607663
The key with which to sort the resulting datasets. Possible values are the properties of the `DatasetInfo`
608664
class.
@@ -619,6 +675,10 @@ def list_datasets(
619675
params = {}
620676
if filter is not None:
621677
params.update({"filter": filter})
678+
if author is not None:
679+
params.update({"author": author})
680+
if search is not None:
681+
params.update({"search": search})
622682
if sort is not None:
623683
params.update({"sort": sort})
624684
if direction is not None:

tests/test_hf_api.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,22 @@ def test_list_models(self):
472472
self.assertGreater(len(models), 100)
473473
self.assertIsInstance(models[0], ModelInfo)
474474

475+
@with_production_testing
def test_list_models_author(self):
    """Check that filtering models by ``author="google"`` only returns google models.

    Runs against the production Hub, so it relies on google publishing
    more than 10 models there.
    """
    _api = HfApi()
    models = _api.list_models(author="google")
    self.assertGreater(len(models), 10)
    self.assertIsInstance(models[0], ModelInfo)
    # Assert once per model. Handing a generator expression to assertTrue
    # always passes (a generator object is truthy), so nothing was verified.
    for model in models:
        self.assertIn("google", model.author)
482+
483+
@with_production_testing
def test_list_models_search(self):
    """Check that ``search="bert"`` only returns models whose id contains "bert".

    Runs against the production Hub; the match on the model id is
    case-insensitive.
    """
    _api = HfApi()
    models = _api.list_models(search="bert")
    self.assertGreater(len(models), 10)
    self.assertIsInstance(models[0], ModelInfo)
    # Plain loop instead of a list comprehension used only for its side
    # effects; assertIn also reports the offending model id on failure.
    for model in models:
        self.assertIn("bert", model.modelId.lower())
490+
475491
@with_production_testing
476492
def test_list_models_complex_query(self):
477493
# Let's list the 10 most recent models
@@ -549,6 +565,20 @@ def test_list_datasets_full(self):
549565
self.assertIsInstance(dataset, DatasetInfo)
550566
self.assertTrue(any(dataset.cardData for dataset in datasets))
551567

568+
@with_production_testing
def test_list_datasets_author(self):
    """Check that listing datasets by ``author="huggingface"`` returns several DatasetInfo entries."""
    found = HfApi().list_datasets(author="huggingface")
    # More than one dataset is expected for this author on the production Hub.
    self.assertGreater(len(found), 1)
    first = found[0]
    self.assertIsInstance(first, DatasetInfo)
574+
575+
@with_production_testing
def test_list_datasets_search(self):
    """Check that searching datasets for "wikipedia" returns more than 10 DatasetInfo entries."""
    found = HfApi().list_datasets(search="wikipedia")
    # The production Hub is expected to host more than 10 matching datasets.
    self.assertGreater(len(found), 10)
    first = found[0]
    self.assertIsInstance(first, DatasetInfo)
581+
552582
@with_production_testing
553583
def test_dataset_info(self):
554584
_api = HfApi()

0 commit comments

Comments
 (0)