Commit 826f654

[InferenceClient] Add third-party providers support (#2757)
* Add first version of third-party providers support
* add task level in model id mappings
* raise error when task is not supported by a provider + some improvements
* small (big) refactoring
* multiple fixes
* add hf inference tasks
* Handle hf_inference in single file (#2766)
* harmonize prepare_payload args and add automatic-speech-recognition task
* backward compatibility with custom urls
* first draft of tests
* InferenceClient as fixture + skip if no api_key
* give name to parametrized tests
* upload cassettes
* make quali
* download sample files from prod
* fix python3.8
* small improvement for better readability

Co-authored-by: Lucain <[email protected]>

* make style
* fixing more tests
* test url building
* fix and record async client tests
* re-add cassettes
* fix
* add cassettes back
* fix test
* hopefully this will fix the test
* fix sentence similarity test

---------

Co-authored-by: Lucain <[email protected]>
Co-authored-by: Lucain Pouget <[email protected]>
1 parent b7abb35 commit 826f654
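
In user-facing terms, the commit lets InferenceClient route requests to third-party providers (fal-ai, replicate, sambanova, together) instead of only the HF Inference API. A rough sketch of the intended call pattern follows; the exact constructor signature lives in the _client.py diff below (not rendered), so treat the provider= and api_key= arguments as illustrative:

from huggingface_hub import InferenceClient

# Illustrative only: route a text-to-image call through fal.ai.
# Provider and model names match the PROVIDERS registry added in this commit.
client = InferenceClient(provider="fal-ai", api_key="<provider-key>")
image_bytes = client.text_to_image(
    "An astronaut riding a horse",
    model="black-forest-labs/FLUX.1-schnell",  # mapped to "fal-ai/flux/schnell" by the provider helper
)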

File tree

104 files changed: +66284 additions, -40908 deletions


src/huggingface_hub/inference/_client.py

Lines changed: 207 additions & 149 deletions
Large diffs are not rendered by default.

src/huggingface_hub/inference/_common.py

Lines changed: 0 additions & 27 deletions
@@ -49,11 +49,8 @@
     ValidationError,
 )
 
-from ..constants import ENDPOINT
 from ..utils import (
-    build_hf_headers,
     get_session,
-    hf_raise_for_status,
     is_aiohttp_available,
     is_numpy_available,
     is_pillow_available,
@@ -141,30 +138,6 @@ def _import_pil_image():
     return Image
 
 
-## RECOMMENDED MODELS
-
-# Will be globally fetched only once (see '_fetch_recommended_models')
-_RECOMMENDED_MODELS: Optional[Dict[str, Optional[str]]] = None
-
-
-def _fetch_recommended_models() -> Dict[str, Optional[str]]:
-    global _RECOMMENDED_MODELS
-    if _RECOMMENDED_MODELS is None:
-        response = get_session().get(f"{ENDPOINT}/api/tasks", headers=build_hf_headers())
-        hf_raise_for_status(response)
-        _RECOMMENDED_MODELS = {
-            task: _first_or_none(details["widgetModels"]) for task, details in response.json().items()
-        }
-    return _RECOMMENDED_MODELS
-
-
-def _first_or_none(items: List[Any]) -> Optional[Any]:
-    try:
-        return items[0] or None
-    except IndexError:
-        return None
-
-
 ## ENCODING / DECODING UTILS
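(The helpers deleted here are relocated, not dropped: the same _fetch_recommended_models and _first_or_none bodies reappear below in the new hf-inference provider module.)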

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 208 additions & 150 deletions
Large diffs are not rendered by default.
src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# mypy: disable-error-code="dict-item"
2+
from typing import Any, Dict, Optional, Protocol, Union
3+
4+
from . import fal_ai, replicate, sambanova, together
5+
from .hf_inference import HFInferenceBinaryInputTask, HFInferenceConversational, HFInferenceTask
6+
7+
8+
class TaskProviderHelper(Protocol):
9+
"""Protocol defining the interface for task-specific provider helpers."""
10+
11+
def build_url(self, model: Optional[str] = None) -> str: ...
12+
def map_model(self, model: Optional[str] = None) -> str: ...
13+
def prepare_headers(self, headers: Dict, *, token: Optional[str] = None) -> Dict: ...
14+
def prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]: ...
15+
def get_response(self, response: Union[bytes, Dict]) -> Any: ...
16+
17+
18+
PROVIDERS: Dict[str, Dict[str, TaskProviderHelper]] = {
19+
"replicate": {
20+
"text-to-image": replicate.text_to_image,
21+
},
22+
"fal-ai": {
23+
"text-to-image": fal_ai.text_to_image,
24+
"automatic-speech-recognition": fal_ai.automatic_speech_recognition,
25+
},
26+
"sambanova": {
27+
"conversational": sambanova.conversational,
28+
},
29+
"together": {
30+
"text-to-image": together.text_to_image,
31+
"conversational": together.conversational,
32+
"text-generation": together.text_generation,
33+
},
34+
"hf-inference": {
35+
"text-to-image": HFInferenceTask("text-to-image"),
36+
"conversational": HFInferenceConversational(),
37+
"text-generation": HFInferenceTask("text-generation"),
38+
"text-classification": HFInferenceTask("text-classification"),
39+
"question-answering": HFInferenceTask("question-answering"),
40+
"audio-classification": HFInferenceBinaryInputTask("audio-classification"),
41+
"automatic-speech-recognition": HFInferenceBinaryInputTask("automatic-speech-recognition"),
42+
"fill-mask": HFInferenceTask("fill-mask"),
43+
"feature-extraction": HFInferenceTask("feature-extraction"),
44+
"image-classification": HFInferenceBinaryInputTask("image-classification"),
45+
"image-segmentation": HFInferenceBinaryInputTask("image-segmentation"),
46+
"document-question-answering": HFInferenceTask("document-question-answering"),
47+
"image-to-text": HFInferenceTask("image-to-text"),
48+
"object-detection": HFInferenceBinaryInputTask("object-detection"),
49+
"audio-to-audio": HFInferenceTask("audio-to-audio"),
50+
"zero-shot-image-classification": HFInferenceBinaryInputTask("zero-shot-image-classification"),
51+
"zero-shot-classification": HFInferenceTask("zero-shot-classification"),
52+
"image-to-image": HFInferenceBinaryInputTask("image-to-image"),
53+
"sentence-similarity": HFInferenceTask("sentence-similarity"),
54+
"table-question-answering": HFInferenceTask("table-question-answering"),
55+
"tabular-classification": HFInferenceTask("tabular-classification"),
56+
"text-to-speech": HFInferenceTask("text-to-speech"),
57+
"token-classification": HFInferenceTask("token-classification"),
58+
"translation": HFInferenceTask("translation"),
59+
"summarization": HFInferenceTask("summarization"),
60+
"visual-question-answering": HFInferenceBinaryInputTask("visual-question-answering"),
61+
},
62+
}
63+
64+
65+
def get_provider_helper(provider: str, task: str) -> TaskProviderHelper:
66+
"""Get provider helper instance by name and task.
67+
68+
Args:
69+
provider (str): Name of the provider
70+
task (str): Name of the task
71+
72+
Returns:
73+
TaskProviderHelper: Helper instance for the specified provider and task
74+
75+
Raises:
76+
ValueError: If provider or task is not supported
77+
"""
78+
if provider not in PROVIDERS:
79+
raise ValueError(f"Provider '{provider}' not supported. Available providers: {list(PROVIDERS.keys())}")
80+
if task not in PROVIDERS[provider]:
81+
raise ValueError(
82+
f"Task '{task}' not supported for provider '{provider}'. "
83+
f"Available tasks: {list(PROVIDERS[provider].keys())}"
84+
)
85+
return PROVIDERS[provider][task]
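
A minimal sketch of how this registry is consumed (the actual wiring inside _client.py is in the large diff above and not rendered here; the calls below just follow the TaskProviderHelper protocol):

from huggingface_hub.inference._providers import get_provider_helper

# Resolve the helper for a provider/task pair, then let it shape the request.
helper = get_provider_helper("fal-ai", "text-to-image")
model = helper.map_model("black-forest-labs/FLUX.1-dev")  # -> "fal-ai/flux/dev"
url = helper.build_url(model)                             # -> "https://fal.run/fal-ai/flux/dev"
headers = helper.prepare_headers({}, token="<fal-key>")   # -> {"Authorization": "Key <fal-key>"}
payload = helper.prepare_payload("a cat in a hat", {})    # -> {"json": {"prompt": "a cat in a hat"}}
# The HTTP call itself and get_response() post-processing stay in InferenceClient.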
src/huggingface_hub/inference/_providers/fal_ai/__init__.py

Lines changed: 2 additions & 0 deletions
# ruff: noqa: F401
from . import automatic_speech_recognition, text_to_image
src/huggingface_hub/inference/_providers/fal_ai/automatic_speech_recognition.py

Lines changed: 63 additions & 0 deletions
import base64
import json
from typing import Any, Dict, Optional, Union


BASE_URL = "https://fal.run"

SUPPORTED_MODELS = {
    "openai/whisper-large-v3": "fal-ai/whisper",
}


def build_url(model: Optional[str] = None) -> str:
    return f"{BASE_URL}/{model}"


def map_model(model: str) -> str:
    mapped_model = SUPPORTED_MODELS.get(model)
    if mapped_model is None:
        raise ValueError(f"Model {model} is not supported for Fal.AI automatic-speech-recognition task")
    return mapped_model


def prepare_headers(headers: Dict, *, token: Optional[str] = None) -> Dict:
    return {
        **headers,
        "authorization": f"Key {token}",
    }


def prepare_payload(
    inputs: Any,
    parameters: Dict[str, Any],
) -> Dict[str, Any]:
    if isinstance(inputs, str) and (inputs.startswith("http://") or inputs.startswith("https://")):
        # If input is a URL, pass it directly
        audio_url = inputs
    else:
        # If input is a file path, read it first
        if isinstance(inputs, str):
            with open(inputs, "rb") as f:
                inputs = f.read()

        audio_b64 = base64.b64encode(inputs).decode()
        content_type = "audio/mpeg"
        audio_url = f"data:{content_type};base64,{audio_b64}"

    return {
        "json": {
            "audio_url": audio_url,
            **{k: v for k, v in parameters.items() if v is not None},
        }
    }


def get_response(response: Union[bytes, Dict]) -> Any:
    if isinstance(response, bytes):
        response_dict = json.loads(response)
    else:
        response_dict = response
    if not isinstance(response_dict["text"], str):
        raise ValueError("Unexpected output format from API. Expected string.")
    return response_dict["text"]
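
prepare_payload above accepts three input shapes; a small sketch of the resulting payloads (the import path assumes the package layout shown in this commit):

from huggingface_hub.inference._providers.fal_ai import automatic_speech_recognition as asr

# A URL input is forwarded untouched:
asr.prepare_payload("https://example.com/clip.mp3", parameters={})
# -> {"json": {"audio_url": "https://example.com/clip.mp3"}}

# Raw bytes (or a local file path, which is read first) become a base64 data URL.
# Note the content type is hard-coded to audio/mpeg regardless of the actual format:
asr.prepare_payload(b"ID3...", parameters={})
# -> {"json": {"audio_url": "data:audio/mpeg;base64,SUQzLi4u"}}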
src/huggingface_hub/inference/_providers/fal_ai/text_to_image.py

Lines changed: 52 additions & 0 deletions
import json
from typing import Any, Dict, Optional, Union

from huggingface_hub.utils import get_session


BASE_URL = "https://fal.run"

SUPPORTED_MODELS = {
    "black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
    "black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",
}


def build_url(model: Optional[str] = None) -> str:
    return f"{BASE_URL}/{model}"


def map_model(model: str) -> str:
    mapped_model = SUPPORTED_MODELS.get(model)
    if mapped_model is None:
        raise ValueError(f"Model {model} is not supported for Fal.AI text-to-image task")
    return mapped_model


def prepare_headers(headers: Dict, *, token: Optional[str] = None) -> Dict:
    return {
        **headers,
        "Authorization": f"Key {token}",
    }


def prepare_payload(inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]:
    parameters = {k: v for k, v in parameters.items() if v is not None}
    if "image_size" not in parameters and "width" in parameters and "height" in parameters:
        parameters["image_size"] = {
            "width": parameters.pop("width"),
            "height": parameters.pop("height"),
        }

    return {
        "json": {"prompt": inputs, **parameters},
    }


def get_response(response: Union[bytes, Dict]) -> Any:
    if isinstance(response, bytes):
        response_dict = json.loads(response)  # type: ignore
    else:
        response_dict = response
    url = response_dict["images"][0]["url"]
    return get_session().get(url).content
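
Worth noting: prepare_payload translates the client's flat width/height parameters into fal.ai's nested image_size object and drops None-valued parameters. A quick sketch (same import-path assumption as above):

from huggingface_hub.inference._providers.fal_ai import text_to_image as t2i

t2i.prepare_payload("a lighthouse at dusk", {"width": 512, "height": 768, "seed": None})
# "seed" is dropped (None); width/height fold into image_size:
# -> {"json": {"prompt": "a lighthouse at dusk", "image_size": {"width": 512, "height": 768}}}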
src/huggingface_hub/inference/_providers/hf_inference.py

Lines changed: 142 additions & 0 deletions
import logging
from pathlib import Path
from typing import Any, BinaryIO, Dict, List, Optional, Union

from huggingface_hub.constants import ENDPOINT
from huggingface_hub.inference._common import _b64_encode, _open_as_binary
from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status


# TYPES
UrlT = str
PathT = Union[str, Path]
BinaryT = Union[bytes, BinaryIO]
ContentT = Union[BinaryT, PathT, UrlT]

# Use to set a Accept: image/png header
TASKS_EXPECTING_IMAGES = {"text-to-image", "image-to-image"}

logger = logging.getLogger(__name__)


## RECOMMENDED MODELS

# Will be globally fetched only once (see '_fetch_recommended_models')
_RECOMMENDED_MODELS: Optional[Dict[str, Optional[str]]] = None

BASE_URL = "https://api-inference.huggingface.co"


def _first_or_none(items: List[Any]) -> Optional[Any]:
    try:
        return items[0] or None
    except IndexError:
        return None


def _fetch_recommended_models() -> Dict[str, Optional[str]]:
    global _RECOMMENDED_MODELS
    if _RECOMMENDED_MODELS is None:
        response = get_session().get(f"{ENDPOINT}/api/tasks", headers=build_hf_headers())
        hf_raise_for_status(response)
        _RECOMMENDED_MODELS = {
            task: _first_or_none(details["widgetModels"]) for task, details in response.json().items()
        }
    return _RECOMMENDED_MODELS


def get_recommended_model(task: str) -> str:
    """
    Get the model Hugging Face recommends for the input task.

    Args:
        task (`str`):
            The Hugging Face task to get which model Hugging Face recommends.
            All available tasks can be found [here](https://huggingface.co/tasks).

    Returns:
        `str`: Name of the model recommended for the input task.

    Raises:
        `ValueError`: If Hugging Face has no recommendation for the input task.
    """
    model = _fetch_recommended_models().get(task)
    if model is None:
        raise ValueError(
            f"Task {task} has no recommended model. Please specify a model"
            " explicitly. Visit https://huggingface.co/tasks for more info."
        )
    return model


class HFInferenceTask:
    """Base class for HF Inference API tasks."""

    def __init__(self, task: str):
        self.task = task

    def build_url(self, model: Optional[str] = None) -> str:
        if model is None:
            model = get_recommended_model(self.task)
        return (
            # Feature-extraction and sentence-similarity are the only cases where we handle models with several tasks.
            f"{BASE_URL}/pipeline/{self.task}/{model}"
            if self.task in ("feature-extraction", "sentence-similarity")
            # Otherwise, we use the default endpoint
            else f"{BASE_URL}/models/{model}"
        )

    def map_model(self, model: str) -> str:
        return model if model is not None else get_recommended_model(self.task)

    def prepare_headers(self, headers: Dict, *, token: Optional[str] = None) -> Dict:
        return headers

    def prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]:
        if isinstance(inputs, (bytes, Path)):
            raise ValueError(f"Unexpected binary inputs. Got {inputs}")  # type: ignore
        _ = parameters.pop("model")
        return {
            "json": {
                "inputs": inputs,
                "parameters": {k: v for k, v in parameters.items() if v is not None},
            }
        }

    def get_response(self, response: Union[bytes, Dict]) -> Any:
        return response


class HFInferenceBinaryInputTask(HFInferenceTask):
    def prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]:
        parameters = {k: v for k, v in parameters.items() if v is not None}
        _ = parameters.pop("model")  # model is not a valid parameter for hf-inference tasks
        has_parameters = len(parameters) > 0

        # Raise if not a binary object or a local path or a URL.
        if not isinstance(inputs, (bytes, Path)) and not isinstance(inputs, str):
            raise ValueError(f"Expected binary inputs or a local path or a URL. Got {inputs}")

        # Send inputs as raw content when no parameters are provided
        if not has_parameters:
            with _open_as_binary(inputs) as data:
                data_as_bytes = data if isinstance(data, bytes) else data.read()
                return {"data": data_as_bytes}

        # Otherwise encode as b64
        return {"json": {"inputs": _b64_encode(inputs), "parameters": parameters}}


class HFInferenceConversational(HFInferenceTask):
    def __init__(self):
        super().__init__("conversational")

    def build_url(self, model: Optional[str] = None) -> str:
        if model is None:
            model = get_recommended_model("text-generation")
        return f"{BASE_URL}/models/{model}/v1/chat/completions"

    def prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]:
        parameters = {key: value for key, value in parameters.items() if value is not None}
        model = parameters.get("model")
        return {"model": model, "messages": inputs, **parameters}
src/huggingface_hub/inference/_providers/replicate/__init__.py

Lines changed: 2 additions & 0 deletions
# ruff: noqa: F401
from . import text_to_image
