@@ -28,6 +28,7 @@
 from itertools import islice
 from pathlib import Path
 from typing import (
+    TYPE_CHECKING,
     Any,
     BinaryIO,
     Callable,
@@ -135,6 +136,9 @@
 from .utils.endpoint_helpers import _is_emission_within_threshold


+if TYPE_CHECKING:
+    from .inference._providers import PROVIDER_T
+
 R = TypeVar("R")  # Return type
 CollectionItemType_T = Literal["model", "dataset", "space", "paper", "collection"]

@@ -709,21 +713,26 @@ def __init__(self, **kwargs):

 @dataclass
 class InferenceProviderMapping:
-    hf_model_id: str
+    provider: "PROVIDER_T"  # Provider name
+    hf_model_id: str  # ID of the model on the Hugging Face Hub
+    provider_id: str  # ID of the model on the provider's side
     status: Literal["live", "staging"]
-    provider_id: str
     task: str

     adapter: Optional[str] = None
     adapter_weights_path: Optional[str] = None
+    type: Optional[Literal["single-model", "tag-filter"]] = None

     def __init__(self, **kwargs):
+        self.provider = kwargs.pop("provider")
         self.hf_model_id = kwargs.pop("hf_model_id")
-        self.status = kwargs.pop("status")
         self.provider_id = kwargs.pop("providerId")
+        self.status = kwargs.pop("status")
         self.task = kwargs.pop("task")
+
         self.adapter = kwargs.pop("adapter", None)
         self.adapter_weights_path = kwargs.pop("adapterWeightsPath", None)
+        self.type = kwargs.pop("type", None)
         self.__dict__.update(**kwargs)

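For illustration, here is a minimal sketch of how the reworked constructor consumes a raw API entry. The payload values below are made up; the camelCase key `providerId` mirrors the `kwargs.pop` calls in the diff, which remap it to the snake_case attribute.

```python
from huggingface_hub.hf_api import InferenceProviderMapping

# Hypothetical payload in the shape the Hub API returns for one provider entry
# (all values are illustrative, not real API data).
payload = {
    "provider": "together",
    "hf_model_id": "meta-llama/Llama-3.1-8B-Instruct",
    "providerId": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
    "status": "live",
    "task": "conversational",
    "type": "single-model",
}

mapping = InferenceProviderMapping(**payload)
print(mapping.provider_id)  # camelCase "providerId" is remapped to snake_case
print(mapping.type)         # new optional field; None when the API omits it
```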
@@ -765,12 +774,10 @@ class ModelInfo:
             If so, whether there is manual or automatic approval.
         gguf (`Dict`, *optional*):
             GGUF information of the model.
-        inference (`Literal["cold", "frozen", "warm"]`, *optional*):
-            Status of the model on the inference API.
-            Warm models are available for immediate use. Cold models will be loaded on first inference call.
-            Frozen models are not available in Inference API.
-        inference_provider_mapping (`Dict`, *optional*):
-            Model's inference provider mapping.
+        inference (`Literal["warm"]`, *optional*):
+            Status of the model on Inference Providers. "warm" if the model is served by at least one provider.
+        inference_provider_mapping (`List[InferenceProviderMapping]`, *optional*):
+            A list of [`InferenceProviderMapping`] objects, ordered according to the user's provider order.
         likes (`int`):
             Number of likes of the model.
         library_name (`str`, *optional*):
@@ -815,8 +822,8 @@ class ModelInfo:
     downloads_all_time: Optional[int]
     gated: Optional[Literal["auto", "manual", False]]
     gguf: Optional[Dict]
-    inference: Optional[Literal["warm", "cold", "frozen"]]
-    inference_provider_mapping: Optional[Dict[str, InferenceProviderMapping]]
+    inference: Optional[Literal["warm"]]
+    inference_provider_mapping: Optional[List[InferenceProviderMapping]]
     likes: Optional[int]
     library_name: Optional[str]
     tags: Optional[List[str]]
@@ -852,14 +859,25 @@ def __init__(self, **kwargs):
         self.gguf = kwargs.pop("gguf", None)

         self.inference = kwargs.pop("inference", None)
-        self.inference_provider_mapping = kwargs.pop("inferenceProviderMapping", None)
-        if self.inference_provider_mapping:
-            self.inference_provider_mapping = {
-                provider: InferenceProviderMapping(
-                    **{**value, "hf_model_id": self.id}
-                )  # little hack to simplify Inference Providers logic
-                for provider, value in self.inference_provider_mapping.items()
-            }
+
+        # Little hack to simplify Inference Providers logic and keep it backward and forward compatible:
+        # the API currently returns a dict on model_info and a list on list_models. Harmonize to a list.
+        mapping = kwargs.pop("inferenceProviderMapping", None)
+        if isinstance(mapping, list):
+            self.inference_provider_mapping = [
+                InferenceProviderMapping(**{**value, "hf_model_id": self.id}) for value in mapping
+            ]
+        elif isinstance(mapping, dict):
+            self.inference_provider_mapping = [
+                InferenceProviderMapping(**{**value, "hf_model_id": self.id, "provider": provider})
+                for provider, value in mapping.items()
+            ]
+        elif mapping is None:
+            self.inference_provider_mapping = None
+        else:
+            raise ValueError(
+                f"Unexpected type for `inferenceProviderMapping`. Expecting `dict` or `list`. Got {mapping}."
+            )

         self.tags = kwargs.pop("tags", None)
         self.pipeline_tag = kwargs.pop("pipeline_tag", None)
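To make the dict/list harmonization concrete, a small sketch constructing `ModelInfo` from both payload shapes. The model ID and provider name are illustrative, and this assumes all `ModelInfo` fields other than `id` default to `None` when omitted:

```python
from huggingface_hub.hf_api import ModelInfo

entry = {
    "providerId": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
    "status": "live",
    "task": "conversational",
}

# Dict shape (model_info endpoint): the provider name is the key.
from_dict = ModelInfo(
    id="meta-llama/Llama-3.1-8B-Instruct",
    inferenceProviderMapping={"together": entry},
)

# List shape (list_models endpoint): the provider name is embedded in each item.
from_list = ModelInfo(
    id="meta-llama/Llama-3.1-8B-Instruct",
    inferenceProviderMapping=[{"provider": "together", **entry}],
)

# Both shapes normalize to the same List[InferenceProviderMapping].
assert from_dict.inference_provider_mapping[0].provider == "together"
assert from_list.inference_provider_mapping[0].provider == "together"
```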
@@ -1836,7 +1854,8 @@ def list_models(
         filter: Union[str, Iterable[str], None] = None,
         author: Optional[str] = None,
         gated: Optional[bool] = None,
-        inference: Optional[Literal["cold", "frozen", "warm"]] = None,
+        inference: Optional[Literal["warm"]] = None,
+        inference_provider: Optional[Union[Literal["all"], "PROVIDER_T", List["PROVIDER_T"]]] = None,
         library: Optional[Union[str, List[str]]] = None,
         language: Optional[Union[str, List[str]]] = None,
         model_name: Optional[str] = None,
@@ -1870,10 +1889,11 @@ def list_models(
                 A boolean to filter models on the Hub that are gated or not. By default, all models are returned.
                 If `gated=True` is passed, only gated models are returned.
                 If `gated=False` is passed, only non-gated models are returned.
-            inference (`Literal["cold", "frozen", "warm"]`, *optional*):
-                A string to filter models on the Hub by their state on the Inference API.
-                Warm models are available for immediate use. Cold models will be loaded on first inference call.
-                Frozen models are not available in Inference API.
+            inference (`Literal["warm"]`, *optional*):
+                If `"warm"`, only include models currently served by at least one provider.
+            inference_provider (`Literal["all"]`, `str`, or `List[str]`, *optional*):
+                A provider name, or a list of provider names, to filter models on the Hub served by those providers.
+                Pass `"all"` to get models served by at least one provider.
             library (`str` or `List`, *optional*):
                 A string or list of strings of foundational libraries models were
                 originally trained from, such as pytorch, tensorflow, or allennlp.
@@ -1933,7 +1953,7 @@ def list_models(
         Returns:
             `Iterable[ModelInfo]`: an iterable of [`huggingface_hub.hf_api.ModelInfo`] objects.

-        Example usage with the `filter` argument:
+        Example:

         ```python
         >>> from huggingface_hub import HfApi
@@ -1943,24 +1963,19 @@ def list_models(
         # List all models
         >>> api.list_models()

-        # List only the text classification models
+        # List text classification models
         >>> api.list_models(filter="text-classification")

-        # List only models from the AllenNLP library
-        >>> api.list_models(filter="allennlp")
-        ```
-
-        Example usage with the `search` argument:
+        # List models from the KerasHub library
+        >>> api.list_models(filter="keras-hub")

-        ```python
-        >>> from huggingface_hub import HfApi
-
-        >>> api = HfApi()
+        # List models served by Cohere
+        >>> api.list_models(inference_provider="cohere")

-        # List all models with "bert" in their name
+        # List models with "bert" in their name
         >>> api.list_models(search="bert")

-        # List all models with "bert" in their name made by google
+        # List models with "bert" in their name and pushed by google
         >>> api.list_models(search="bert", author="google")
         ```
         """
@@ -2003,6 +2018,8 @@ def list_models(
             params["gated"] = gated
         if inference is not None:
             params["inference"] = inference
+        if inference_provider is not None:
+            params["inference_provider"] = inference_provider
         if pipeline_tag:
             params["pipeline_tag"] = pipeline_tag
         search_list = []
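End to end, the new query parameter can be exercised as below. Provider names are examples, and whether `inference_provider_mapping` comes back populated on listed models depends on which fields the endpoint returns for a listing call:

```python
from huggingface_hub import HfApi

api = HfApi()

# Models served by at least one provider (sends `inference_provider=all`).
for model in api.list_models(inference_provider="all", limit=5):
    print(model.id)

# Models served by specific providers; the signature also accepts a list.
for model in api.list_models(inference_provider=["cohere", "together"], limit=5):
    for mapping in model.inference_provider_mapping or []:
        print(model.id, mapping.provider, mapping.status)
```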