
Commit bebc1f7

[Inference Providers] Support for LoRAs (#3005)
* add loras support
* nit
* review suggestions
* update inference provider mapping object
* fix tests
* fixes
* use the precomputed adapterWeightsPath property
* remove unnecessary function
* Update src/huggingface_hub/hf_api.py
  Co-authored-by: Lucain <[email protected]>
* style
* add comment

Co-authored-by: Lucain <[email protected]>

---------

Co-authored-by: Lucain <[email protected]>
1 parent 0709088 commit bebc1f7

File tree: 14 files changed (+315 / -90 lines)
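In user-facing terms, this commit lets `InferenceClient` run a LoRA adapter hosted on the Hub through a supporting Inference Provider just like a regular model. A minimal sketch of the intended flow, assuming a fal-ai-served text-to-image LoRA (the adapter repo ID below is purely illustrative):

from huggingface_hub import InferenceClient

# Hypothetical adapter repo; in practice, use any Hub repo whose
# inferenceProviderMapping entry carries an adapterWeightsPath.
client = InferenceClient(provider="fal-ai")
image = client.text_to_image(
    "a cute dragon reading a book",
    model="some-user/some-flux-lora",  # illustrative repo ID, not a real one
)
image.save("dragon.png")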

src/huggingface_hub/_inference_endpoints.py

Lines changed: 8 additions & 5 deletions
@@ -6,14 +6,13 @@
 
 from huggingface_hub.errors import InferenceEndpointError, InferenceEndpointTimeoutError
 
-from .inference._client import InferenceClient
-from .inference._generated._async_client import AsyncInferenceClient
 from .utils import get_session, logging, parse_datetime
 
 
 if TYPE_CHECKING:
     from .hf_api import HfApi
-
+    from .inference._client import InferenceClient
+    from .inference._generated._async_client import AsyncInferenceClient
 
 logger = logging.get_logger(__name__)
 
@@ -138,7 +137,7 @@ def __post_init__(self) -> None:
         self._populate_from_raw()
 
     @property
-    def client(self) -> InferenceClient:
+    def client(self) -> "InferenceClient":
         """Returns a client to make predictions on this Inference Endpoint.
 
         Returns:
@@ -152,13 +151,15 @@ def client(self) -> InferenceClient:
                 "Cannot create a client for this Inference Endpoint as it is not yet deployed. "
                 "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again."
             )
+        from .inference._client import InferenceClient
+
         return InferenceClient(
             model=self.url,
             token=self._token,  # type: ignore[arg-type] # boolean token shouldn't be possible. In practice it's ok.
         )
 
     @property
-    def async_client(self) -> AsyncInferenceClient:
+    def async_client(self) -> "AsyncInferenceClient":
         """Returns a client to make predictions on this Inference Endpoint.
 
         Returns:
@@ -172,6 +173,8 @@ def async_client(self) -> AsyncInferenceClient:
                 "Cannot create a client for this Inference Endpoint as it is not yet deployed. "
                 "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again."
             )
+        from .inference._generated._async_client import AsyncInferenceClient
+
         return AsyncInferenceClient(
             model=self.url,
             token=self._token,  # type: ignore[arg-type] # boolean token shouldn't be possible. In practice it's ok.
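Note: the `InferenceClient` / `AsyncInferenceClient` imports move under `TYPE_CHECKING` and are re-imported lazily inside the `client` / `async_client` properties, presumably to avoid an import cycle now that the inference code depends on `hf_api`. A minimal, self-contained sketch of the same pattern, using `decimal.Decimal` as a stand-in for the client class:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers, never executed at runtime,
    # so it cannot participate in an import cycle.
    from decimal import Decimal  # stand-in for InferenceClient


def make_value() -> "Decimal":
    # Deferred import: resolved the first time the function is called,
    # after all modules involved have finished loading.
    from decimal import Decimal

    return Decimal("1.0")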

src/huggingface_hub/hf_api.py

Lines changed: 10 additions & 1 deletion
@@ -708,14 +708,21 @@ def __init__(self, **kwargs):
 
 @dataclass
 class InferenceProviderMapping:
+    hf_model_id: str
     status: Literal["live", "staging"]
     provider_id: str
     task: str
 
+    adapter: Optional[str] = None
+    adapter_weights_path: Optional[str] = None
+
     def __init__(self, **kwargs):
+        self.hf_model_id = kwargs.pop("hf_model_id")
         self.status = kwargs.pop("status")
         self.provider_id = kwargs.pop("providerId")
         self.task = kwargs.pop("task")
+        self.adapter = kwargs.pop("adapter", None)
+        self.adapter_weights_path = kwargs.pop("adapterWeightsPath", None)
         self.__dict__.update(**kwargs)
 
 
@@ -847,7 +854,9 @@ def __init__(self, **kwargs):
         self.inference_provider_mapping = kwargs.pop("inferenceProviderMapping", None)
         if self.inference_provider_mapping:
             self.inference_provider_mapping = {
-                provider: InferenceProviderMapping(**value)
+                provider: InferenceProviderMapping(
+                    **{**value, "hf_model_id": self.id}
+                )  # little hack to simplify Inference Providers logic
                 for provider, value in self.inference_provider_mapping.items()
             }
 
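With the new fields, each provider entry of the Hub's `inferenceProviderMapping` payload can carry adapter information, and the parent model's repo ID is injected as `hf_model_id` when `ModelInfo` builds the mapping. A rough sketch of what the parsing amounts to (the field values below are made up for illustration):

from huggingface_hub.hf_api import InferenceProviderMapping

# A single provider entry roughly as it might come back from the Hub API
# (values are illustrative, not taken from a real model).
raw = {
    "status": "live",
    "providerId": "fal-ai/flux-lora",
    "task": "text-to-image",
    "adapter": "lora",
    "adapterWeightsPath": "pytorch_lora_weights.safetensors",
}

# hf_model_id is not part of the server payload; ModelInfo.__init__ injects
# the parent repo ID (the "little hack" mentioned in the diff above).
mapping = InferenceProviderMapping(**{**raw, "hf_model_id": "some-user/some-flux-lora"})
print(mapping.provider_id, mapping.adapter_weights_path)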

src/huggingface_hub/inference/_providers/_common.py

Lines changed: 34 additions & 19 deletions
@@ -2,21 +2,24 @@
 from typing import Any, Dict, Optional, Union
 
 from huggingface_hub import constants
+from huggingface_hub.hf_api import InferenceProviderMapping
 from huggingface_hub.inference._common import RequestParameters
 from huggingface_hub.utils import build_hf_headers, get_token, logging
 
 
 logger = logging.get_logger(__name__)
 
-
 # Dev purposes only.
 # If you want to try to run inference for a new model locally before it's registered on huggingface.co
 # for a given Inference Provider, you can add it to the following dictionary.
-HARDCODED_MODEL_ID_MAPPING: Dict[str, Dict[str, str]] = {
-    # "HF model ID" => "Model ID on Inference Provider's side"
+HARDCODED_MODEL_INFERENCE_MAPPING: Dict[str, Dict[str, InferenceProviderMapping]] = {
+    # "HF model ID" => InferenceProviderMapping object initialized with "Model ID on Inference Provider's side"
     #
     # Example:
-    # "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
+    # "Qwen/Qwen2.5-Coder-32B-Instruct": InferenceProviderMapping(hf_model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
+    #                                                             provider_id="Qwen2.5-Coder-32B-Instruct",
+    #                                                             task="conversational",
+    #                                                             status="live")
     "cerebras": {},
     "cohere": {},
     "fal-ai": {},
@@ -61,28 +64,30 @@ def prepare_request(
         api_key = self._prepare_api_key(api_key)
 
         # mapped model from HF model ID
-        mapped_model = self._prepare_mapped_model(model)
+        provider_mapping_info = self._prepare_mapping_info(model)
 
         # default HF headers + user headers (to customize in subclasses)
         headers = self._prepare_headers(headers, api_key)
 
         # routed URL if HF token, or direct URL (to customize in '_prepare_route' in subclasses)
-        url = self._prepare_url(api_key, mapped_model)
+        url = self._prepare_url(api_key, provider_mapping_info.provider_id)
 
         # prepare payload (to customize in subclasses)
-        payload = self._prepare_payload_as_dict(inputs, parameters, mapped_model=mapped_model)
+        payload = self._prepare_payload_as_dict(inputs, parameters, provider_mapping_info=provider_mapping_info)
         if payload is not None:
             payload = recursive_merge(payload, extra_payload or {})
 
         # body data (to customize in subclasses)
-        data = self._prepare_payload_as_bytes(inputs, parameters, mapped_model, extra_payload)
+        data = self._prepare_payload_as_bytes(inputs, parameters, provider_mapping_info, extra_payload)
 
         # check if both payload and data are set and return
         if payload is not None and data is not None:
             raise ValueError("Both payload and data cannot be set in the same request.")
         if payload is None and data is None:
             raise ValueError("Either payload or data must be set in the request.")
-        return RequestParameters(url=url, task=self.task, model=mapped_model, json=payload, data=data, headers=headers)
+        return RequestParameters(
+            url=url, task=self.task, model=provider_mapping_info.provider_id, json=payload, data=data, headers=headers
+        )
 
     def get_response(
         self,
@@ -107,16 +112,16 @@ def _prepare_api_key(self, api_key: Optional[str]) -> str:
             )
         return api_key
 
-    def _prepare_mapped_model(self, model: Optional[str]) -> str:
+    def _prepare_mapping_info(self, model: Optional[str]) -> InferenceProviderMapping:
        """Return the mapped model ID to use for the request.
 
         Usually not overwritten in subclasses."""
         if model is None:
             raise ValueError(f"Please provide an HF model ID supported by {self.provider}.")
 
         # hardcoded mapping for local testing
-        if HARDCODED_MODEL_ID_MAPPING.get(self.provider, {}).get(model):
-            return HARDCODED_MODEL_ID_MAPPING[self.provider][model]
+        if HARDCODED_MODEL_INFERENCE_MAPPING.get(self.provider, {}).get(model):
+            return HARDCODED_MODEL_INFERENCE_MAPPING[self.provider][model]
 
         provider_mapping = _fetch_inference_provider_mapping(model).get(self.provider)
         if provider_mapping is None:
@@ -132,7 +137,7 @@ def _prepare_mapped_model(self, model: Optional[str]) -> str:
             logger.warning(
                 f"Model {model} is in staging mode for provider {self.provider}. Meant for test purposes only."
             )
-        return provider_mapping.provider_id
+        return provider_mapping
 
     def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
         """Return the headers to use for the request.
@@ -168,7 +173,9 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str:
         """
         return ""
 
-    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(
+        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[Dict]:
         """Return the payload to use for the request, as a dict.
 
         Override this method in subclasses for customized payloads.
@@ -177,7 +184,11 @@ def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model:
         return None
 
     def _prepare_payload_as_bytes(
-        self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict]
+        self,
+        inputs: Any,
+        parameters: Dict,
+        provider_mapping_info: InferenceProviderMapping,
+        extra_payload: Optional[Dict],
     ) -> Optional[bytes]:
         """Return the body to use for the request, as bytes.
 
@@ -199,8 +210,10 @@ def __init__(self, provider: str, base_url: str):
     def _prepare_route(self, mapped_model: str, api_key: str) -> str:
         return "/v1/chat/completions"
 
-    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
-        return {"messages": inputs, **filter_none(parameters), "model": mapped_model}
+    def _prepare_payload_as_dict(
+        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[Dict]:
+        return {"messages": inputs, **filter_none(parameters), "model": provider_mapping_info.provider_id}
 
 
 class BaseTextGenerationTask(TaskProviderHelper):
@@ -215,8 +228,10 @@ def __init__(self, provider: str, base_url: str):
     def _prepare_route(self, mapped_model: str, api_key: str) -> str:
         return "/v1/completions"
 
-    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
-        return {"prompt": inputs, **filter_none(parameters), "model": mapped_model}
+    def _prepare_payload_as_dict(
+        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[Dict]:
+        return {"prompt": inputs, **filter_none(parameters), "model": provider_mapping_info.provider_id}
 
 
 @lru_cache(maxsize=None)
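For local testing, the hardcoded mapping now stores full `InferenceProviderMapping` objects rather than bare provider model IDs. A sketch of how a not-yet-registered model could be wired in, reusing the Qwen example from the module comment (the provider key is chosen arbitrarily for illustration):

from huggingface_hub.hf_api import InferenceProviderMapping
from huggingface_hub.inference._providers._common import HARDCODED_MODEL_INFERENCE_MAPPING

# Dev purposes only: map an HF model ID to its ID on the provider's side.
# Note that the custom __init__ pops the camelCase "providerId" key,
# matching the raw Hub API payload.
HARDCODED_MODEL_INFERENCE_MAPPING["cerebras"]["Qwen/Qwen2.5-Coder-32B-Instruct"] = InferenceProviderMapping(
    hf_model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    providerId="Qwen2.5-Coder-32B-Instruct",
    task="conversational",
    status="live",
)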

src/huggingface_hub/inference/_providers/black_forest_labs.py

Lines changed: 4 additions & 1 deletion
@@ -1,6 +1,7 @@
 import time
 from typing import Any, Dict, Optional, Union
 
+from huggingface_hub.hf_api import InferenceProviderMapping
 from huggingface_hub.inference._common import RequestParameters, _as_dict
 from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
 from huggingface_hub.utils import logging
@@ -27,7 +28,9 @@ def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
     def _prepare_route(self, mapped_model: str, api_key: str) -> str:
         return f"/v1/{mapped_model}"
 
-    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(
+        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[Dict]:
         parameters = filter_none(parameters)
         if "num_inference_steps" in parameters:
             parameters["steps"] = parameters.pop("num_inference_steps")

src/huggingface_hub/inference/_providers/fal_ai.py

Lines changed: 35 additions & 10 deletions
@@ -4,6 +4,8 @@
 from typing import Any, Dict, Optional, Union
 from urllib.parse import urlparse
 
+from huggingface_hub import constants
+from huggingface_hub.hf_api import InferenceProviderMapping
 from huggingface_hub.inference._common import RequestParameters, _as_dict
 from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
 from huggingface_hub.utils import get_session, hf_raise_for_status
@@ -34,7 +36,9 @@ class FalAIAutomaticSpeechRecognitionTask(FalAITask):
     def __init__(self):
         super().__init__("automatic-speech-recognition")
 
-    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(
+        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[Dict]:
         if isinstance(inputs, str) and inputs.startswith(("http://", "https://")):
             # If input is a URL, pass it directly
             audio_url = inputs
@@ -61,14 +65,31 @@ class FalAITextToImageTask(FalAITask):
     def __init__(self):
         super().__init__("text-to-image")
 
-    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
-        parameters = filter_none(parameters)
-        if "width" in parameters and "height" in parameters:
-            parameters["image_size"] = {
-                "width": parameters.pop("width"),
-                "height": parameters.pop("height"),
+    def _prepare_payload_as_dict(
+        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[Dict]:
+        payload: Dict[str, Any] = {
+            "prompt": inputs,
+            **filter_none(parameters),
+        }
+        if "width" in payload and "height" in payload:
+            payload["image_size"] = {
+                "width": payload.pop("width"),
+                "height": payload.pop("height"),
             }
-        return {"prompt": inputs, **parameters}
+        if provider_mapping_info.adapter_weights_path is not None:
+            lora_path = constants.HUGGINGFACE_CO_URL_TEMPLATE.format(
+                repo_id=provider_mapping_info.hf_model_id,
+                revision="main",
+                filename=provider_mapping_info.adapter_weights_path,
+            )
+            payload["loras"] = [{"path": lora_path, "scale": 1}]
+            if provider_mapping_info.provider_id == "fal-ai/lora":
+                # little hack: fal requires the base model for stable-diffusion-based loras but not for flux-based
+                # See payloads in https://fal.ai/models/fal-ai/lora/api vs https://fal.ai/models/fal-ai/flux-lora/api
+                payload["model_name"] = "stabilityai/stable-diffusion-xl-base-1.0"
+
+        return payload
 
     def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
         url = _as_dict(response)["images"][0]["url"]
@@ -79,7 +100,9 @@ class FalAITextToSpeechTask(FalAITask):
     def __init__(self):
         super().__init__("text-to-speech")
 
-    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(
+        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[Dict]:
         return {"text": inputs, **filter_none(parameters)}
 
     def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
@@ -104,7 +127,9 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str:
             return f"/{mapped_model}?_subdomain=queue"
         return f"/{mapped_model}"
 
-    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(
+        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[Dict]:
         return {"prompt": inputs, **filter_none(parameters)}
 
     def get_response(
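The substantive change is in `FalAITextToImageTask`: when the selected mapping has an `adapter_weights_path`, the payload gains a `loras` entry whose `path` is the adapter weights file resolved on the Hub via `constants.HUGGINGFACE_CO_URL_TEMPLATE` (which expands by default to https://huggingface.co/{repo_id}/resolve/{revision}/{filename}), plus a `model_name` for the stable-diffusion-based `fal-ai/lora` endpoint. As a rough illustration, a request for a hypothetical SDXL LoRA would end up with a payload along these lines (repo ID and file name are made up):

# Illustrative payload produced by FalAITextToImageTask._prepare_payload_as_dict
# for a hypothetical SDXL-based LoRA mapped to provider_id "fal-ai/lora".
payload = {
    "prompt": "a cute dragon reading a book",
    "image_size": {"width": 1024, "height": 768},
    "loras": [
        {
            "path": "https://huggingface.co/some-user/some-sdxl-lora/resolve/main/pytorch_lora_weights.safetensors",
            "scale": 1,
        }
    ],
    # Only added for stable-diffusion-based LoRAs (provider_id == "fal-ai/lora"),
    # not for flux-based ones, per the comment in the diff.
    "model_name": "stabilityai/stable-diffusion-xl-base-1.0",
}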
