22from typing import Any , Dict , Optional , Union
33
44from huggingface_hub import constants
5+ from huggingface_hub .hf_api import InferenceProviderMapping
56from huggingface_hub .inference ._common import RequestParameters
67from huggingface_hub .utils import build_hf_headers , get_token , logging
78
89
910logger = logging .get_logger (__name__ )
1011
11-
1212# Dev purposes only.
1313# If you want to try to run inference for a new model locally before it's registered on huggingface.co
1414# for a given Inference Provider, you can add it to the following dictionary.
15- HARDCODED_MODEL_ID_MAPPING : Dict [str , Dict [str , str ]] = {
16- # "HF model ID" => "Model ID on Inference Provider's side"
15+ HARDCODED_MODEL_INFERENCE_MAPPING : Dict [str , Dict [str , InferenceProviderMapping ]] = {
16+ # "HF model ID" => InferenceProviderMapping object initialized with "Model ID on Inference Provider's side"
1717 #
1818 # Example:
19- # "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
19+ # "Qwen/Qwen2.5-Coder-32B-Instruct": InferenceProviderMapping(hf_model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
20+ # provider_id="Qwen2.5-Coder-32B-Instruct",
21+ # task="conversational",
22+ # status="live")
2023 "cerebras" : {},
2124 "cohere" : {},
2225 "fal-ai" : {},
@@ -61,28 +64,30 @@ def prepare_request(
6164 api_key = self ._prepare_api_key (api_key )
6265
6366 # mapped model from HF model ID
64- mapped_model = self ._prepare_mapped_model (model )
67+ provider_mapping_info = self ._prepare_mapping_info (model )
6568
6669 # default HF headers + user headers (to customize in subclasses)
6770 headers = self ._prepare_headers (headers , api_key )
6871
6972 # routed URL if HF token, or direct URL (to customize in '_prepare_route' in subclasses)
70- url = self ._prepare_url (api_key , mapped_model )
73+ url = self ._prepare_url (api_key , provider_mapping_info . provider_id )
7174
7275 # prepare payload (to customize in subclasses)
73- payload = self ._prepare_payload_as_dict (inputs , parameters , mapped_model = mapped_model )
76+ payload = self ._prepare_payload_as_dict (inputs , parameters , provider_mapping_info = provider_mapping_info )
7477 if payload is not None :
7578 payload = recursive_merge (payload , extra_payload or {})
7679
7780 # body data (to customize in subclasses)
78- data = self ._prepare_payload_as_bytes (inputs , parameters , mapped_model , extra_payload )
81+ data = self ._prepare_payload_as_bytes (inputs , parameters , provider_mapping_info , extra_payload )
7982
8083 # check if both payload and data are set and return
8184 if payload is not None and data is not None :
8285 raise ValueError ("Both payload and data cannot be set in the same request." )
8386 if payload is None and data is None :
8487 raise ValueError ("Either payload or data must be set in the request." )
85- return RequestParameters (url = url , task = self .task , model = mapped_model , json = payload , data = data , headers = headers )
88+ return RequestParameters (
89+ url = url , task = self .task , model = provider_mapping_info .provider_id , json = payload , data = data , headers = headers
90+ )
8691
8792 def get_response (
8893 self ,
@@ -107,16 +112,16 @@ def _prepare_api_key(self, api_key: Optional[str]) -> str:
107112 )
108113 return api_key
109114
110- def _prepare_mapped_model (self , model : Optional [str ]) -> str :
115+ def _prepare_mapping_info (self , model : Optional [str ]) -> InferenceProviderMapping :
111116 """Return the provider mapping info (which carries the provider-side model ID) to use for the request.
112117
113118 Usually not overwritten in subclasses."""
114119 if model is None :
115120 raise ValueError (f"Please provide an HF model ID supported by { self .provider } ." )
116121
117122 # hardcoded mapping for local testing
118- if HARDCODED_MODEL_ID_MAPPING .get (self .provider , {}).get (model ):
119- return HARDCODED_MODEL_ID_MAPPING [self .provider ][model ]
123+ if HARDCODED_MODEL_INFERENCE_MAPPING .get (self .provider , {}).get (model ):
124+ return HARDCODED_MODEL_INFERENCE_MAPPING [self .provider ][model ]
120125
121126 provider_mapping = _fetch_inference_provider_mapping (model ).get (self .provider )
122127 if provider_mapping is None :
@@ -132,7 +137,7 @@ def _prepare_mapped_model(self, model: Optional[str]) -> str:
132137 logger .warning (
133138 f"Model { model } is in staging mode for provider { self .provider } . Meant for test purposes only."
134139 )
135- return provider_mapping . provider_id
140+ return provider_mapping
136141
137142 def _prepare_headers (self , headers : Dict , api_key : str ) -> Dict :
138143 """Return the headers to use for the request.
@@ -168,7 +173,9 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str:
168173 """
169174 return ""
170175
171- def _prepare_payload_as_dict (self , inputs : Any , parameters : Dict , mapped_model : str ) -> Optional [Dict ]:
176+ def _prepare_payload_as_dict (
177+ self , inputs : Any , parameters : Dict , provider_mapping_info : InferenceProviderMapping
178+ ) -> Optional [Dict ]:
172179 """Return the payload to use for the request, as a dict.
173180
174181 Override this method in subclasses for customized payloads.
@@ -177,7 +184,11 @@ def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model:
177184 return None
178185
179186 def _prepare_payload_as_bytes (
180- self , inputs : Any , parameters : Dict , mapped_model : str , extra_payload : Optional [Dict ]
187+ self ,
188+ inputs : Any ,
189+ parameters : Dict ,
190+ provider_mapping_info : InferenceProviderMapping ,
191+ extra_payload : Optional [Dict ],
181192 ) -> Optional [bytes ]:
182193 """Return the body to use for the request, as bytes.
183194
@@ -199,8 +210,10 @@ def __init__(self, provider: str, base_url: str):
199210 def _prepare_route (self , mapped_model : str , api_key : str ) -> str :
200211 return "/v1/chat/completions"
201212
def _prepare_payload_as_dict(
    self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
) -> Optional[Dict]:
    """Return the request payload, as a dict.

    The payload is assembled in this order:
      1. "messages" set to `inputs`,
      2. every entry of `filter_none(parameters)`,
      3. "model" set to `provider_mapping_info.provider_id`.

    Later entries win on key collisions, so a "model" key present in `parameters`
    is always replaced by the provider-side model ID, while a "messages" key in
    `parameters` would replace `inputs`.

    Args:
        inputs: Value sent under the "messages" key.
        parameters: Extra request parameters; passed through `filter_none` first
            (presumably drops entries whose value is None -- confirm against the helper).
        provider_mapping_info: Mapping info whose `provider_id` is sent as "model".
    """
    return {"messages": inputs, **filter_none(parameters), "model": provider_mapping_info.provider_id}
204217
205218
206219class BaseTextGenerationTask (TaskProviderHelper ):
@@ -215,8 +228,10 @@ def __init__(self, provider: str, base_url: str):
215228 def _prepare_route (self , mapped_model : str , api_key : str ) -> str :
216229 return "/v1/completions"
217230
def _prepare_payload_as_dict(
    self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
) -> Optional[Dict]:
    """Return the request payload, as a dict.

    The payload is assembled in this order:
      1. "prompt" set to `inputs`,
      2. every entry of `filter_none(parameters)`,
      3. "model" set to `provider_mapping_info.provider_id`.

    Later entries win on key collisions, so a "model" key present in `parameters`
    is always replaced by the provider-side model ID, while a "prompt" key in
    `parameters` would replace `inputs`.

    Args:
        inputs: Value sent under the "prompt" key.
        parameters: Extra request parameters; passed through `filter_none` first
            (presumably drops entries whose value is None -- confirm against the helper).
        provider_mapping_info: Mapping info whose `provider_id` is sent as "model".
    """
    return {"prompt": inputs, **filter_none(parameters), "model": provider_mapping_info.provider_id}
220235
221236
222237@lru_cache (maxsize = None )
0 commit comments