Skip to content

Commit 6d360e3

Browse files
authored
Make ModelSearchArguments and DatasetSearchArguments more robust (#1300)
* Make ModelTags and DatasetTags more robust to server-side changes * add warnings about DatasetSearchArguments and ModelSearchArguments inefficiency * add another warning in doc
1 parent 070d5d6 commit 6d360e3

File tree

4 files changed

+74
-36
lines changed

4 files changed

+74
-36
lines changed

docs/source/searching-the-hub.mdx

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,15 @@ The `huggingface_hub` provides a user-friendly interface to know what exactly ca
4545

4646
These are nested namespace objects that have **every single option** available on the Hub and that will return what should be passed to `filter`. Best of all: it has tab completion 🎊.
4747

48+
<Tip warning={true}>
49+
50+
[`ModelSearchArguments`] and [`DatasetSearchArguments`] are legacy helpers meant for exploratory
51+
purposes only. Their initialization requires listing all models and datasets on the Hub, which
52+
makes them increasingly slower as the number of repos on the Hub increases. For production-ready code,
53+
consider passing raw strings when making a filtered search on the Hub.
54+
55+
</Tip>
56+
4857
## Searching for a Model
4958

5059
Let's pose a problem that would be complicated to solve without access to this information:

src/huggingface_hub/hf_api.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -583,9 +583,21 @@ class ModelSearchArguments(AttributeDictionary):
583583
584584
```python
585585
>>> args = ModelSearchArguments()
586-
>>> args.author_or_organization.huggingface
586+
587+
>>> args.author.huggingface
588+
'huggingface'
589+
587590
>>> args.language.en
591+
'en'
588592
```
593+
594+
<Tip warning={true}>
595+
596+
`ModelSearchArguments` is a legacy class meant for exploratory purposes only. Its
597+
initialization requires listing all models on the Hub, which makes it increasingly
598+
slower as the number of repos on the Hub increases.
599+
600+
</Tip>
589601
"""
590602

591603
def __init__(self, api: Optional["HfApi"] = None):
@@ -621,9 +633,21 @@ class DatasetSearchArguments(AttributeDictionary):
621633
622634
```python
623635
>>> args = DatasetSearchArguments()
624-
>>> args.author_or_organization.huggingface
636+
637+
>>> args.author.huggingface
638+
'huggingface'
639+
625640
>>> args.language.en
641+
'language:en'
626642
```
643+
644+
<Tip warning={true}>
645+
646+
`DatasetSearchArguments` is a legacy class meant for exploratory purposes only. Its
647+
initialization requires listing all datasets on the Hub, which makes it increasingly
648+
slower as the number of repos on the Hub increases.
649+
650+
</Tip>
627651
"""
628652

629653
def __init__(self, api: Optional["HfApi"] = None):

src/huggingface_hub/utils/endpoint_helpers.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -300,14 +300,13 @@ def __init__(self, tag_dictionary: dict, keys: Optional[list] = None):
300300
self._unpack_and_assign_dictionary(key)
301301

302302
def _unpack_and_assign_dictionary(self, key: str):
303-
"Assignes nested attributes to `self.key` containing information as an `AttributeDictionary`"
304-
setattr(self, key, AttributeDictionary())
305-
for item in self._tag_dictionary[key]:
306-
ref = getattr(self, key)
307-
item["label"] = (
308-
item["label"].replace(" ", "").replace("-", "_").replace(".", "_")
309-
)
310-
setattr(ref, item["label"], item["id"])
303+
"Assign nested attributes to `self.key` containing information as an `AttributeDictionary`"
304+
ref = AttributeDictionary()
305+
setattr(self, key, ref)
306+
for item in self._tag_dictionary.get(key, []):
307+
label = item["label"].replace(" ", "").replace("-", "_").replace(".", "_")
308+
ref[label] = item["id"]
309+
self[key] = ref
311310

312311

313312
class ModelTags(GeneralTags):

tests/test_endpoint_helpers.py

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -144,38 +144,44 @@ def test_filter(self):
144144
class ModelTagsTest(unittest.TestCase):
145145
@with_production_testing
146146
def test_tags(self):
147-
_api = HfApi()
148-
path = f"{_api.endpoint}/api/models-tags-by-type"
149-
r = requests.get(path)
150-
r.raise_for_status()
151-
d = r.json()
152-
o = ModelTags(d)
153-
for kind in ["library", "language", "license", "dataset", "pipeline_tag"]:
154-
self.assertTrue(len(getattr(o, kind).keys()) > 0)
147+
# ModelTags instantiation must not fail!
148+
res = requests.get(f"{HfApi().endpoint}/api/models-tags-by-type")
149+
res.raise_for_status()
150+
tags = ModelTags(res.json())
151+
152+
# Check existing keys to get notified about server-side changes
153+
for existing_key in [
154+
"dataset",
155+
"language",
156+
"library",
157+
"license",
158+
"pipeline_tag",
159+
]:
160+
self.assertGreater(len(getattr(tags, existing_key).keys()), 0)
155161

156162

157163
class DatasetTagsTest(unittest.TestCase):
158-
@unittest.skip(
159-
"DatasetTags is currently broken. See"
160-
" https://github.com/huggingface/huggingface_hub/pull/1250. Skip test until"
161-
" it's fixed."
162-
)
163164
@with_production_testing
164165
def test_tags(self):
165-
_api = HfApi()
166-
path = f"{_api.endpoint}/api/datasets-tags-by-type"
167-
r = requests.get(path)
168-
r.raise_for_status()
169-
d = r.json()
170-
o = DatasetTags(d)
171-
for kind in [
172-
"language",
173-
"multilinguality",
166+
# DatasetTags instantiation must not fail!
167+
res = requests.get(f"{HfApi().endpoint}/api/datasets-tags-by-type")
168+
res.raise_for_status()
169+
tags = DatasetTags(res.json())
170+
171+
# Some keys existed before but have been removed server-side
172+
for missing_key in (
174173
"language_creators",
175-
"task_categories",
176-
"size_categories",
174+
"multilinguality",
175+
):
176+
self.assertEqual(len(getattr(tags, missing_key).keys()), 0)
177+
178+
# Check existing keys to get notified about server-side changes
179+
for existing_key in [
177180
"benchmark",
178-
"task_ids",
181+
"language",
179182
"license",
183+
"size_categories",
184+
"task_categories",
185+
"task_ids",
180186
]:
181-
self.assertTrue(len(getattr(o, kind).keys()) > 0)
187+
self.assertGreater(len(getattr(tags, existing_key).keys()), 0)

0 commit comments

Comments
 (0)