Skip to content

Commit 44e6bd4

Browse files
julien-cLysandreJik
authored andcommitted
list_models: more advanced API capabilities
1 parent 9a4042b commit 44e6bd4

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def get_version() -> str:
1515
"filelock",
1616
"requests",
1717
"tqdm",
18+
"typing-extensions",
1819
"importlib_metadata;python_version<'3.8'",
1920
]
2021

src/huggingface_hub/hf_api.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import warnings
2020
from io import BufferedIOBase, RawIOBase
2121
from os.path import expanduser
22-
from typing import BinaryIO, Dict, List, Optional, Tuple, Union
22+
from typing import BinaryIO, Dict, Iterable, List, Literal, Optional, Tuple, Union
2323

2424
import requests
2525
from requests.exceptions import HTTPError
@@ -63,6 +63,7 @@ def __init__(
6363
self,
6464
modelId: Optional[str] = None, # id of model
6565
sha: Optional[str] = None, # commit sha at the specified revision
66+
lastModified: Optional[str] = None, # date of last commit to repo
6667
tags: List[str] = [],
6768
pipeline_tag: Optional[str] = None,
6869
siblings: Optional[
@@ -72,6 +73,7 @@ def __init__(
7273
):
7374
self.modelId = modelId
7475
self.sha = sha
76+
self.lastModified = lastModified
7577
self.tags = tags
7678
self.pipeline_tag = pipeline_tag
7779
self.siblings = (
@@ -129,7 +131,14 @@ def logout(self, token: str) -> None:
129131
r = requests.post(path, headers={"authorization": "Bearer {}".format(token)})
130132
r.raise_for_status()
131133

132-
def list_models(self, filter: Optional[str] = None) -> List[ModelInfo]:
134+
def list_models(
135+
self,
136+
filter: Union[str, Iterable[str], None] = None,
137+
sort: Optional[str] = None,
138+
direction: Optional[Literal[-1]] = None,
139+
limit: Optional[int] = None,
140+
full: Optional[bool] = None,
141+
) -> List[ModelInfo]:
133142
"""
134143
Get the public list of all the models on huggingface.co
135144
@@ -147,8 +156,8 @@ def list_models(self, filter: Optional[str] = None) -> List[ModelInfo]:
147156
>>> # List only the text classification models
148157
>>> api.list_models(filter="text-classification")
149158
150-
>>> # List only the russian models
151-
>>> api.list_models(filter="ru")
159+
>>> # List only the russian models compatible with pytorch
160+
>>> api.list_models(filter=("ru", "pytorch"))
152161
153162
>>> # List only the models trained on the "common_voice" dataset
154163
>>> api.list_models(filter="dataset:common_voice")
@@ -157,7 +166,18 @@ def list_models(self, filter: Optional[str] = None) -> List[ModelInfo]:
157166
>>> api.list_models(filter="allennlp")
158167
"""
159168
path = "{}/api/models".format(self.endpoint)
160-
params = {"filter": filter, "full": True} if filter is not None else None
169+
params = {}
170+
if filter is not None:
171+
params.update({"filter": filter})
172+
params.update({"full": True})
173+
if sort is not None:
174+
params.update({"sort": sort})
175+
if direction is not None:
176+
params.update({"direction": direction})
177+
if limit is not None:
178+
params.update({"limit": limit})
179+
if full is not None:
180+
params.update({"full": full})
161181
r = requests.get(path, params=params)
162182
r.raise_for_status()
163183
d = r.json()

0 commit comments

Comments
 (0)