Merged
Changes from 1 commit
30 commits
5c1a209
Add first version of third-party providers support
hanouticelina Jan 17, 2025
c142995
add task level in model id mappings
hanouticelina Jan 17, 2025
051ee76
raise error when task is not supported by a provider + some improvements
hanouticelina Jan 17, 2025
ab2b44f
small (big) refactoring
hanouticelina Jan 17, 2025
a896f3e
multiple fixes
hanouticelina Jan 20, 2025
49d389a
add hf inference tasks
hanouticelina Jan 20, 2025
20e8d3a
Handle hf_inference in single file (#2766)
Wauplin Jan 20, 2025
ceca530
harmonize prepare_payload args and add automatic-speech-recognition task
hanouticelina Jan 21, 2025
a076eb4
backward compatibility with custom urls
hanouticelina Jan 21, 2025
aca6050
first draft of tests
hanouticelina Jan 22, 2025
c2bdcc2
InferenceClient as fixture + skip if no api_key
Wauplin Jan 22, 2025
4489069
give name to parametrized tests
Wauplin Jan 22, 2025
5f9d946
upload cassettes
hanouticelina Jan 22, 2025
e1a379c
make quali
Wauplin Jan 22, 2025
fec77a6
download sample files from prod
Wauplin Jan 22, 2025
a731eec
fix python3.8
Wauplin Jan 22, 2025
9b209b8
small improvement for better readability
hanouticelina Jan 22, 2025
28825cb
make style
Wauplin Jan 22, 2025
a0208c9
fixing more tests
hanouticelina Jan 22, 2025
8f2eb6c
Merge branch 'inference-providers-compatibility' of github.com:huggin…
hanouticelina Jan 22, 2025
456122f
test url building
hanouticelina Jan 22, 2025
d5dcf8f
fix and record async client tests
hanouticelina Jan 22, 2025
4ba4ab4
re-add cassettes
hanouticelina Jan 22, 2025
65b659d
fix
hanouticelina Jan 22, 2025
ae6f2af
add cassettes back
hanouticelina Jan 22, 2025
4d50893
fix test
hanouticelina Jan 22, 2025
9d557f0
hopefully this will fix the test
hanouticelina Jan 22, 2025
92f62fc
fix sentence similarity test
hanouticelina Jan 22, 2025
8ebbe80
Merge branch 'main' into inference-providers-compatibility
hanouticelina Jan 22, 2025
2223998
Merge branch 'main' into inference-providers-compatibility
Wauplin Jan 23, 2025
133 changes: 108 additions & 25 deletions src/huggingface_hub/inference/_client.py
@@ -104,6 +104,7 @@
ZeroShotClassificationOutputElement,
ZeroShotImageClassificationOutputElement,
)
from huggingface_hub.inference._providers import BaseProvider, get_provider
from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
from huggingface_hub.utils._deprecation import _deprecate_arguments

@@ -123,7 +124,7 @@ class InferenceClient:
Initialize a new Inference Client.

[`InferenceClient`] aims to provide a unified experience to perform inference. The client can be used
seamlessly with either the (free) Inference API or self-hosted Inference Endpoints.
seamlessly with the (free) Inference API, self-hosted Inference Endpoints, or third-party Inference Providers.

Args:
model (`str`, `optional`):
@@ -134,6 +135,9 @@ class InferenceClient:
arguments are mutually exclusive. If using `base_url` for chat completion, the `/chat/completions` suffix
path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
provider (`str`, *optional*):
Name of the provider to use for inference. Can be `"replicate"`, `"together"`, `"fal-ai"`, or `"sambanova"`.
Defaults to Hugging Face Inference API.
token (`str` or `bool`, *optional*):
Hugging Face token. Will default to the locally saved token if not provided.
Pass `token=False` if you don't want to send your token to the server.
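For illustration, here is a minimal usage sketch of the new `provider` argument documented above. The model ID and token are placeholders, and the model's availability on the chosen provider is an assumption:

```python
from huggingface_hub import InferenceClient

# Route requests through a third-party provider instead of the default HF Inference API.
client = InferenceClient(provider="together", token="hf_xxx")

response = client.chat_completion(
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    model="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical Hub ID, mapped internally to the provider's own ID
)
print(response.choices[0].message.content)
```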
@@ -161,6 +165,7 @@ def __init__(
self,
model: Optional[str] = None,
*,
provider: Optional[str] = None,
token: Union[str, bool, None] = None,
timeout: Optional[float] = None,
headers: Optional[Dict[str, str]] = None,
@@ -191,6 +196,10 @@ def __init__(
)
if headers is not None:
self.headers.update(headers)

# Configure provider
self.provider = provider

self.cookies = cookies
self.timeout = timeout
self.proxies = proxies
@@ -209,6 +218,7 @@ def post( # type: ignore[misc]
data: Optional[ContentT] = None,
model: Optional[str] = None,
task: Optional[str] = None,
provider_config: Optional[BaseProvider] = None,
stream: Literal[False] = ...,
) -> bytes: ...

@@ -220,6 +230,7 @@ def post( # type: ignore[misc]
data: Optional[ContentT] = None,
model: Optional[str] = None,
task: Optional[str] = None,
provider_config: Optional[BaseProvider] = None,
stream: Literal[True] = ...,
) -> Iterable[bytes]: ...

@@ -231,6 +242,7 @@ def post(
data: Optional[ContentT] = None,
model: Optional[str] = None,
task: Optional[str] = None,
provider_config: Optional[BaseProvider] = None,
stream: bool = False,
) -> Union[bytes, Iterable[bytes]]: ...

@@ -241,6 +253,7 @@ def post(
data: Optional[ContentT] = None,
model: Optional[str] = None,
task: Optional[str] = None,
provider_config: Optional[BaseProvider] = None,
stream: bool = False,
) -> Union[bytes, Iterable[bytes]]:
"""
@@ -273,7 +286,7 @@ def post(
`HTTPError`:
If the request fails with an HTTP error status code other than HTTP 503.
"""
url = self._resolve_url(model, task)
url = self._resolve_url(model=model, task=task, provider_config=provider_config)

if data is not None and json is not None:
warnings.warn("Ignoring `json` as `data` is passed as binary.")
@@ -379,6 +392,19 @@ def audio_classification(
response = self.post(**payload, model=model, task="audio-classification")
return AudioClassificationOutputElement.parse_obj_as_list(response)

def _configure_provider(self, provider_name: Optional[str] = None) -> Optional[BaseProvider]:
"""
Get the provider and update headers if needed.
"""
if provider_name is None or provider_name == "hf-inference":
return None # fallback to default provider (HF Inference API)

provider = get_provider(provider_name)
# Update headers with provider-specific ones
kwargs = {"token": self.token}
self.headers = provider.set_custom_headers(self.headers, **kwargs) # type: ignore
return provider

def audio_to_audio(
self,
audio: ContentT,
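To make the provider hooks used throughout this diff easier to follow (`set_custom_headers` in `_configure_provider` above, plus `MODEL_IDS_MAPPING`, `build_url`, `prepare_custom_payload`, and `get_response` further down), here is a rough sketch of the shape a provider object could have. This is inferred from the call sites only and is not the actual `BaseProvider` API from `huggingface_hub.inference._providers`:

```python
from typing import Any, Dict, Optional

# Illustrative sketch only: attribute and method names mirror the call sites in this diff,
# not a confirmed interface.
class ExampleProvider:
    MODEL_IDS_MAPPING: Dict[str, str] = {
        "black-forest-labs/FLUX.1-dev": "example-namespace/flux-dev",  # hypothetical mapping
    }

    def set_custom_headers(self, headers: Dict[str, str], token: Optional[str] = None) -> Dict[str, str]:
        # Providers typically expect their own auth scheme; here we simply forward the token.
        new_headers = dict(headers)
        new_headers["Authorization"] = f"Bearer {token}"
        return new_headers

    def build_url(self, model: Optional[str], task: Optional[str], chat_completion: bool = False) -> str:
        base = f"https://api.example-provider.com/v1/{model}"
        return f"{base}/chat/completions" if chat_completion else base

    def prepare_custom_payload(self, prompt: str, model: Optional[str], task: str, **kwargs: Any) -> Dict[str, Any]:
        # The returned dict is forwarded to `InferenceClient.post(**payload, ...)`.
        return {"json": {"model": model, "prompt": prompt, **kwargs}}

    def get_response(self, response: bytes, task: str) -> bytes:
        # E.g. decode a base64-encoded image for "text-to-image"; identity here.
        return response
```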
@@ -464,6 +490,7 @@ def chat_completion( # type: ignore
messages: List[Dict],
*,
model: Optional[str] = None,
provider: Optional[str] = None,
stream: Literal[False] = False,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[List[float]] = None,
@@ -489,6 +516,7 @@
messages: List[Dict],
*,
model: Optional[str] = None,
provider: Optional[str] = None,
stream: Literal[True] = True,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[List[float]] = None,
@@ -514,6 +542,7 @@ def chat_completion(
messages: List[Dict],
*,
model: Optional[str] = None,
provider: Optional[str] = None,
stream: bool = False,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[List[float]] = None,
@@ -538,6 +567,7 @@ def chat_completion(
messages: List[Dict],
*,
model: Optional[str] = None,
provider: Optional[str] = None,
stream: bool = False,
# Parameters from ChatCompletionInput (handled manually)
frequency_penalty: Optional[float] = None,
@@ -578,6 +608,9 @@ def chat_completion(
See https://huggingface.co/tasks/text-generation for more details.
If `model` is a model ID, it is passed to the server as the `model` parameter. If you want to define a
custom URL while setting `model` in the request payload, you must set `base_url` when initializing [`InferenceClient`].
provider (`str`, *optional*):
Name of the provider to use for inference. Can be `"replicate"`, `"together"`, `"fal-ai"`, or `"sambanova"`.
Defaults to Hugging Face Inference API.
frequency_penalty (`float`, *optional*):
Penalizes new tokens based on their existing frequency
in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0.
@@ -854,16 +887,27 @@ def chat_completion(
'{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
```
"""
model_url = self._resolve_chat_completion_url(model)
# Resolve model ID with precedence: method argument > instance model > default
model_id = model or self.model or "tgi"

# Get provider config
provider = provider or self.provider
provider_config = self._configure_provider(provider)
# Map model ID if using a third-party provider
if provider_config is not None:
mapped_model = provider_config.MODEL_IDS_MAPPING.get(model_id)
if not mapped_model:
raise ValueError(f"Model '{model_id}' not supported by provider '{provider}'")
model_id = mapped_model
model_url = self._resolve_chat_completion_url(model_id, provider_config=provider_config)

# `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
# If it's an ID on the Hub => use it. Otherwise, we use a dummy value.
model_id = model or self.model or "tgi"
if model_id.startswith(("http://", "https://")):
model_id = "tgi" # dummy value
# For URLs, use "tgi" as model name in payload
payload_model = "tgi" if model_id.startswith(("http://", "https://")) else model_id

payload = dict(
model=model_id,
model=payload_model,
messages=messages,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
@@ -891,7 +935,11 @@ def chat_completion(

return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type]

def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
def _resolve_chat_completion_url(
self,
model: Optional[str] = None,
provider_config: Optional[BaseProvider] = None,
) -> str:
# Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
# `self.base_url` and `self.model` take precedence over the 'model' argument only in `chat_completion`.
model_id_or_url = self.base_url or self.model or model or self.get_recommended_model("text-generation")
@@ -900,7 +948,12 @@ def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
model_url = (
model_id_or_url
if model_id_or_url.startswith(("http://", "https://"))
else self._resolve_url(model_id_or_url, task="text-generation")
else self._resolve_url(
model_id_or_url,
task="text-generation",
provider_config=provider_config,
chat_completion=True,
)
)

# Strip trailing /
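As a side note on the precedence comment above, a small illustration of how chat completion routing resolves (the model IDs are placeholders):

```python
from huggingface_hub import InferenceClient

# `self.model` wins for URL routing; the per-call `model` only ends up in the request payload.
client = InferenceClient(model="meta-llama/Llama-3.1-8B-Instruct")
client.chat_completion(
    messages=[{"role": "user", "content": "Hi"}],
    model="some-other-model",  # not used for URL resolution here, only as the payload `model` field
)
```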
@@ -2355,6 +2408,7 @@ def text_to_image(
self,
prompt: str,
*,
provider: Optional[str] = None,
negative_prompt: Optional[List[str]] = None,
height: Optional[float] = None,
width: Optional[float] = None,
@@ -2378,6 +2432,9 @@ def text_to_image(
Args:
prompt (`str`):
The prompt to generate an image from.
provider (`str`, *optional*):
Name of the provider to use for inference. Can be `"replicate"`, `"together"` or `"fal-ai"`.
Defaults to Hugging Face Inference API.
negative_prompt (`List[str]`, *optional*):
One or several prompts to guide what NOT to include in image generation.
height (`float`, *optional*):
@@ -2426,20 +2483,33 @@ def text_to_image(
>>> image.save("better_astronaut.png")
```
"""

parameters = {
"negative_prompt": negative_prompt,
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
"scheduler": scheduler,
"target_size": target_size,
"seed": seed,
**kwargs,
}
payload = _prepare_payload(prompt, parameters=parameters)
response = self.post(**payload, model=model, task="text-to-image")
model = model or self.model
provider = provider or self.provider
provider_config = self._configure_provider(provider)
if provider_config is not None:
mapped_model = provider_config.MODEL_IDS_MAPPING.get(model) # type: ignore
if not mapped_model:
raise ValueError(f"Model '{model}' not supported by provider '{provider}'")
model = mapped_model
payload = provider_config.prepare_custom_payload(
prompt=prompt, model=model, task="text-to-image", **kwargs
)
else:
parameters = {
"negative_prompt": negative_prompt,
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
"scheduler": scheduler,
"target_size": target_size,
"seed": seed,
**kwargs,
}
payload = _prepare_payload(prompt, parameters=parameters)
response = self.post(**payload, model=model, task="text-to-image", provider_config=provider_config)
if provider_config is not None:
response = provider_config.get_response(response, task="text-to-image")
return _bytes_to_image(response)

def text_to_speech(
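A hypothetical end-to-end call of `text_to_image` through a provider, assuming the model is present in that provider's mapping:

```python
from huggingface_hub import InferenceClient

client = InferenceClient(token="hf_xxx")
image = client.text_to_image(
    "An astronaut riding a horse on the moon",
    model="black-forest-labs/FLUX.1-dev",  # assumed to be in the provider's MODEL_IDS_MAPPING
    provider="fal-ai",
)
image.save("astronaut.png")
```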
@@ -2966,7 +3036,13 @@ def zero_shot_image_classification(
)
return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
def _resolve_url(
self,
model: Optional[str] = None,
task: Optional[str] = None,
chat_completion: bool = False,
provider_config: Optional[BaseProvider] = None,
) -> str:
model = model or self.model or self.base_url

# If model is already a URL, ignore `task` and return directly
@@ -2986,8 +3062,15 @@ def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None)
f" encouraged to explicitly set `model='{model}'` as the recommended"
" models list might get updated without prior notice."
)
# Delegate URL building to the provider if one is configured
if provider_config:
return provider_config.build_url(
model=model,
task=task,
chat_completion=chat_completion,
)

# Compute InferenceAPI url
# Default to HF InferenceAPI url
return (
# Feature-extraction and sentence-similarity are the only cases where we handle models with several tasks.
f"{INFERENCE_ENDPOINT}/pipeline/{task}/{model}"
14 changes: 13 additions & 1 deletion src/huggingface_hub/inference/_common.py
@@ -219,10 +219,22 @@ def _b64_encode(content: ContentT) -> str:
return base64.b64encode(data_as_bytes).decode()


def _b64_to_bytes(b64_string: str) -> bytes:
"""Convert a base64 string to bytes.

Args:
b64_string (str): The base64 encoded string

Returns:
bytes: The decoded bytes
"""
return base64.b64decode(b64_string)


def _b64_to_image(encoded_image: str) -> "Image":
"""Parse a base64-encoded string into a PIL Image."""
Image = _import_pil_image()
return Image.open(io.BytesIO(base64.b64decode(encoded_image)))
return Image.open(io.BytesIO(_b64_to_bytes(encoded_image)))


def _bytes_to_list(content: bytes) -> List:
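A quick round-trip check of the new `_b64_to_bytes` helper added above (a private helper, shown here purely for illustration):

```python
import base64

from huggingface_hub.inference._common import _b64_to_bytes

raw = b"\x89PNG\r\n\x1a\n"  # e.g. the first bytes of a PNG returned base64-encoded by a provider
encoded = base64.b64encode(raw).decode()
assert _b64_to_bytes(encoded) == raw
```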