1919import warnings
2020from io import BufferedIOBase , RawIOBase
2121from os .path import expanduser
22- from typing import BinaryIO , Dict , List , Optional , Tuple , Union
22+ from typing import BinaryIO , Dict , Iterable , List , Literal , Optional , Tuple , Union
2323
2424import requests
2525from requests .exceptions import HTTPError
@@ -63,6 +63,7 @@ def __init__(
6363 self ,
6464 modelId : Optional [str ] = None , # id of model
6565 sha : Optional [str ] = None , # commit sha at the specified revision
66+ lastModified : Optional [str ] = None , # date of last commit to repo
6667 tags : List [str ] = [],
6768 pipeline_tag : Optional [str ] = None ,
6869 siblings : Optional [
@@ -72,6 +73,7 @@ def __init__(
7273 ):
7374 self .modelId = modelId
7475 self .sha = sha
76+ self .lastModified = lastModified
7577 self .tags = tags
7678 self .pipeline_tag = pipeline_tag
7779 self .siblings = (
@@ -129,7 +131,14 @@ def logout(self, token: str) -> None:
129131 r = requests .post (path , headers = {"authorization" : "Bearer {}" .format (token )})
130132 r .raise_for_status ()
131133
132- def list_models (self , filter : Optional [str ] = None ) -> List [ModelInfo ]:
134+ def list_models (
135+ self ,
136+ filter : Union [str , Iterable [str ], None ] = None ,
137+ sort : Optional [str ] = None ,
138+ direction : Optional [Literal [- 1 ]] = None ,
139+ limit : Optional [int ] = None ,
140+ full : Optional [bool ] = None ,
141+ ) -> List [ModelInfo ]:
133142 """
134143 Get the public list of all the models on huggingface.co
135144
@@ -147,8 +156,8 @@ def list_models(self, filter: Optional[str] = None) -> List[ModelInfo]:
147156 >>> # List only the text classification models
148157 >>> api.list_models(filter="text-classification")
149158
150- >>> # List only the russian models
151- >>> api.list_models(filter="ru")
159+ >>> # List only the russian models compatible with pytorch
160+ >>> api.list_models(filter=( "ru", "pytorch") )
152161
153162 >>> # List only the models trained on the "common_voice" dataset
154163 >>> api.list_models(filter="dataset:common_voice")
@@ -157,7 +166,18 @@ def list_models(self, filter: Optional[str] = None) -> List[ModelInfo]:
157166 >>> api.list_models(filter="allennlp")
158167 """
159168 path = "{}/api/models" .format (self .endpoint )
160- params = {"filter" : filter , "full" : True } if filter is not None else None
169+ params = {}
170+ if filter is not None :
171+ params .update ({"filter" : filter })
172+ params .update ({"full" : True })
173+ if sort is not None :
174+ params .update ({"sort" : sort })
175+ if direction is not None :
176+ params .update ({"direction" : direction })
177+ if limit is not None :
178+ params .update ({"limit" : limit })
179+ if full is not None :
180+ params .update ({"full" : full })
161181 r = requests .get (path , params = params )
162182 r .raise_for_status ()
163183 d = r .json ()
0 commit comments