Merged

30 commits
5c1a209
Add first version of third-party providers support
hanouticelina Jan 17, 2025
c142995
add task level in model id mappings
hanouticelina Jan 17, 2025
051ee76
raise error when task is not supported by a provider + some improvements
hanouticelina Jan 17, 2025
ab2b44f
small (big) refactoring
hanouticelina Jan 17, 2025
a896f3e
multiple fixes
hanouticelina Jan 20, 2025
49d389a
add hf inference tasks
hanouticelina Jan 20, 2025
20e8d3a
Handle hf_inference in single file (#2766)
Wauplin Jan 20, 2025
ceca530
harmonize prepare_payload args and add automatic-speech-recognition task
hanouticelina Jan 21, 2025
a076eb4
backward compatibility with custom urls
hanouticelina Jan 21, 2025
aca6050
first draft of tests
hanouticelina Jan 22, 2025
c2bdcc2
InferenceClient as fixture + skip if no api_key
Wauplin Jan 22, 2025
4489069
give name to parametrized tests
Wauplin Jan 22, 2025
5f9d946
upload cassettes
hanouticelina Jan 22, 2025
e1a379c
make quali
Wauplin Jan 22, 2025
fec77a6
download sample files from prod
Wauplin Jan 22, 2025
a731eec
fix python3.8
Wauplin Jan 22, 2025
9b209b8
small improvement for better readability
hanouticelina Jan 22, 2025
28825cb
make style
Wauplin Jan 22, 2025
a0208c9
fixing more tests
hanouticelina Jan 22, 2025
8f2eb6c
Merge branch 'inference-providers-compatibility' of github.com:huggin…
hanouticelina Jan 22, 2025
456122f
test url building
hanouticelina Jan 22, 2025
d5dcf8f
fix and record async client tests
hanouticelina Jan 22, 2025
4ba4ab4
re-add cassettes
hanouticelina Jan 22, 2025
65b659d
fix
hanouticelina Jan 22, 2025
ae6f2af
add cassettes back
hanouticelina Jan 22, 2025
4d50893
fix test
hanouticelina Jan 22, 2025
9d557f0
hopefully this will fix the test
hanouticelina Jan 22, 2025
92f62fc
fix sentence similarity test
hanouticelina Jan 22, 2025
8ebbe80
Merge branch 'main' into inference-providers-compatibility
hanouticelina Jan 22, 2025
2223998
Merge branch 'main' into inference-providers-compatibility
Wauplin Jan 23, 2025
356 changes: 207 additions & 149 deletions src/huggingface_hub/inference/_client.py

Large diffs are not rendered by default.

27 changes: 0 additions & 27 deletions src/huggingface_hub/inference/_common.py
@@ -49,11 +49,8 @@
ValidationError,
)

from ..constants import ENDPOINT
from ..utils import (
build_hf_headers,
get_session,
hf_raise_for_status,
is_aiohttp_available,
is_numpy_available,
is_pillow_available,
@@ -141,30 +138,6 @@ def _import_pil_image():
return Image


## RECOMMENDED MODELS

# Will be globally fetched only once (see '_fetch_recommended_models')
_RECOMMENDED_MODELS: Optional[Dict[str, Optional[str]]] = None


def _fetch_recommended_models() -> Dict[str, Optional[str]]:
global _RECOMMENDED_MODELS
if _RECOMMENDED_MODELS is None:
response = get_session().get(f"{ENDPOINT}/api/tasks", headers=build_hf_headers())
hf_raise_for_status(response)
_RECOMMENDED_MODELS = {
task: _first_or_none(details["widgetModels"]) for task, details in response.json().items()
}
return _RECOMMENDED_MODELS


def _first_or_none(items: List[Any]) -> Optional[Any]:
try:
return items[0] or None
except IndexError:
return None


## ENCODING / DECODING UTILS


358 changes: 208 additions & 150 deletions src/huggingface_hub/inference/_generated/_async_client.py

Large diffs are not rendered by default.

85 changes: 85 additions & 0 deletions src/huggingface_hub/inference/_providers/__init__.py
@@ -0,0 +1,85 @@
# mypy: disable-error-code="dict-item"
from typing import Any, Dict, Optional, Protocol, Union

from . import fal_ai, replicate, sambanova, together
from .hf_inference import HFInferenceBinaryInputTask, HFInferenceConversational, HFInferenceTask


class TaskProviderHelper(Protocol):
"""Protocol defining the interface for task-specific provider helpers."""

def build_url(self, model: Optional[str] = None) -> str: ...
def map_model(self, model: Optional[str] = None) -> str: ...
def prepare_headers(self, headers: Dict, *, token: Optional[str] = None) -> Dict: ...
def prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]: ...
def get_response(self, response: Union[bytes, Dict]) -> Any: ...


PROVIDERS: Dict[str, Dict[str, TaskProviderHelper]] = {
"replicate": {
"text-to-image": replicate.text_to_image,
},
"fal-ai": {
"text-to-image": fal_ai.text_to_image,
"automatic-speech-recognition": fal_ai.automatic_speech_recognition,
},
"sambanova": {
"conversational": sambanova.conversational,
},
"together": {
"text-to-image": together.text_to_image,
"conversational": together.conversational,
"text-generation": together.text_generation,
},
"hf-inference": {
"text-to-image": HFInferenceTask("text-to-image"),
"conversational": HFInferenceConversational(),
"text-generation": HFInferenceTask("text-generation"),
"text-classification": HFInferenceTask("text-classification"),
"question-answering": HFInferenceTask("question-answering"),
"audio-classification": HFInferenceBinaryInputTask("audio-classification"),
"automatic-speech-recognition": HFInferenceBinaryInputTask("automatic-speech-recognition"),
"fill-mask": HFInferenceTask("fill-mask"),
"feature-extraction": HFInferenceTask("feature-extraction"),
"image-classification": HFInferenceBinaryInputTask("image-classification"),
"image-segmentation": HFInferenceBinaryInputTask("image-segmentation"),
"document-question-answering": HFInferenceTask("document-question-answering"),
"image-to-text": HFInferenceTask("image-to-text"),
"object-detection": HFInferenceBinaryInputTask("object-detection"),
"audio-to-audio": HFInferenceTask("audio-to-audio"),
"zero-shot-image-classification": HFInferenceBinaryInputTask("zero-shot-image-classification"),
"zero-shot-classification": HFInferenceTask("zero-shot-classification"),
"image-to-image": HFInferenceBinaryInputTask("image-to-image"),
"sentence-similarity": HFInferenceTask("sentence-similarity"),
"table-question-answering": HFInferenceTask("table-question-answering"),
"tabular-classification": HFInferenceTask("tabular-classification"),
"text-to-speech": HFInferenceTask("text-to-speech"),
"token-classification": HFInferenceTask("token-classification"),
"translation": HFInferenceTask("translation"),
"summarization": HFInferenceTask("summarization"),
"visual-question-answering": HFInferenceBinaryInputTask("visual-question-answering"),
},
}


def get_provider_helper(provider: str, task: str) -> TaskProviderHelper:
"""Get provider helper instance by name and task.

Args:
provider (str): Name of the provider
task (str): Name of the task

Returns:
TaskProviderHelper: Helper instance for the specified provider and task

Raises:
ValueError: If provider or task is not supported
"""
if provider not in PROVIDERS:
raise ValueError(f"Provider '{provider}' not supported. Available providers: {list(PROVIDERS.keys())}")
if task not in PROVIDERS[provider]:
raise ValueError(
f"Task '{task}' not supported for provider '{provider}'. "
f"Available tasks: {list(PROVIDERS[provider].keys())}"
)
return PROVIDERS[provider][task]
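For orientation, a minimal usage sketch of the resolution flow defined above (illustrative, not part of the diff; the model ID comes from the fal-ai mapping and the token placeholder is an assumption):

from huggingface_hub.inference._providers import get_provider_helper

# Resolve the helper for a (provider, task) pair; raises ValueError if either is unsupported.
helper = get_provider_helper("fal-ai", "text-to-image")

# The client can then delegate model mapping, URL building and headers to the helper.
model = helper.map_model("black-forest-labs/FLUX.1-schnell")  # -> "fal-ai/flux/schnell"
url = helper.build_url(model)  # -> "https://fal.run/fal-ai/flux/schnell"
headers = helper.prepare_headers({}, token="<provider-api-key>")  # adds "Authorization: Key <provider-api-key>"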
2 changes: 2 additions & 0 deletions src/huggingface_hub/inference/_providers/fal_ai/__init__.py
@@ -0,0 +1,2 @@
# ruff: noqa: F401
from . import automatic_speech_recognition, text_to_image
63 changes: 63 additions & 0 deletions src/huggingface_hub/inference/_providers/fal_ai/automatic_speech_recognition.py
@@ -0,0 +1,63 @@
import base64
import json
from typing import Any, Dict, Optional, Union


BASE_URL = "https://fal.run"

SUPPORTED_MODELS = {
"openai/whisper-large-v3": "fal-ai/whisper",
}


def build_url(model: Optional[str] = None) -> str:
return f"{BASE_URL}/{model}"


def map_model(model: str) -> str:
mapped_model = SUPPORTED_MODELS.get(model)
if mapped_model is None:
raise ValueError(f"Model {model} is not supported for Fal.AI automatic-speech-recognition task")
return mapped_model


def prepare_headers(headers: Dict, *, token: Optional[str] = None) -> Dict:
return {
**headers,
"authorization": f"Key {token}",
}


def prepare_payload(
inputs: Any,
parameters: Dict[str, Any],
) -> Dict[str, Any]:
if isinstance(inputs, str) and (inputs.startswith("http://") or inputs.startswith("https://")):
# If input is a URL, pass it directly
audio_url = inputs
else:
# If input is a file path, read it first
if isinstance(inputs, str):
with open(inputs, "rb") as f:
inputs = f.read()

audio_b64 = base64.b64encode(inputs).decode()
content_type = "audio/mpeg"
audio_url = f"data:{content_type};base64,{audio_b64}"

return {
"json": {
"audio_url": audio_url,
**{k: v for k, v in parameters.items() if v is not None},
}
}


def get_response(response: Union[bytes, Dict]) -> Any:
if isinstance(response, bytes):
response_dict = json.loads(response)
else:
response_dict = response
if not isinstance(response_dict["text"], str):
raise ValueError("Unexpected output format from API. Expected string.")
return response_dict["text"]
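To make the two input branches of prepare_payload above concrete, a small sketch (illustrative values, not part of the diff):

from huggingface_hub.inference._providers.fal_ai import automatic_speech_recognition as asr

# A remote URL is forwarded as-is...
payload = asr.prepare_payload("https://example.com/sample.mp3", {})
# payload["json"]["audio_url"] == "https://example.com/sample.mp3"

# ...while raw bytes (or a local file path, which is read first) are inlined
# as a base64 data URL, with the content type assumed to be audio/mpeg.
payload = asr.prepare_payload(b"raw mp3 bytes", {})
# payload["json"]["audio_url"].startswith("data:audio/mpeg;base64,")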
52 changes: 52 additions & 0 deletions src/huggingface_hub/inference/_providers/fal_ai/text_to_image.py
@@ -0,0 +1,52 @@
import json
from typing import Any, Dict, Optional, Union

from huggingface_hub.utils import get_session


BASE_URL = "https://fal.run"

SUPPORTED_MODELS = {
"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
"black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",
}


def build_url(model: Optional[str] = None) -> str:
return f"{BASE_URL}/{model}"


def map_model(model: str) -> str:
mapped_model = SUPPORTED_MODELS.get(model)
if mapped_model is None:
raise ValueError(f"Model {model} is not supported for Fal.AI text-to-image task")
return mapped_model


def prepare_headers(headers: Dict, *, token: Optional[str] = None) -> Dict:
return {
**headers,
"Authorization": f"Key {token}",
}


def prepare_payload(inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]:
parameters = {k: v for k, v in parameters.items() if v is not None}
if "image_size" not in parameters and "width" in parameters and "height" in parameters:
parameters["image_size"] = {
"width": parameters.pop("width"),
"height": parameters.pop("height"),
}

return {
"json": {"prompt": inputs, **parameters},
}


def get_response(response: Union[bytes, Dict]) -> Any:
if isinstance(response, bytes):
response_dict = json.loads(response) # type: ignore
else:
response_dict = response
url = response_dict["images"][0]["url"]
return get_session().get(url).content
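A short sketch of the width/height handling above (illustrative values, not part of the diff):

from huggingface_hub.inference._providers.fal_ai import text_to_image

# None-valued parameters are dropped, and separate width/height values are
# folded into the "image_size" object expected by fal.
payload = text_to_image.prepare_payload("a photo of a cat", {"width": 512, "height": 512, "seed": None})
# payload == {"json": {"prompt": "a photo of a cat", "image_size": {"width": 512, "height": 512}}}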
142 changes: 142 additions & 0 deletions src/huggingface_hub/inference/_providers/hf_inference.py
@@ -0,0 +1,142 @@
import logging
from pathlib import Path
from typing import Any, BinaryIO, Dict, List, Optional, Union

from huggingface_hub.constants import ENDPOINT
from huggingface_hub.inference._common import _b64_encode, _open_as_binary
from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status


# TYPES
UrlT = str
PathT = Union[str, Path]
BinaryT = Union[bytes, BinaryIO]
ContentT = Union[BinaryT, PathT, UrlT]

# Used to set an "Accept: image/png" header
TASKS_EXPECTING_IMAGES = {"text-to-image", "image-to-image"}

logger = logging.getLogger(__name__)


## RECOMMENDED MODELS

# Will be globally fetched only once (see '_fetch_recommended_models')
_RECOMMENDED_MODELS: Optional[Dict[str, Optional[str]]] = None

BASE_URL = "https://api-inference.huggingface.co"


def _first_or_none(items: List[Any]) -> Optional[Any]:
try:
return items[0] or None
except IndexError:
return None


def _fetch_recommended_models() -> Dict[str, Optional[str]]:
global _RECOMMENDED_MODELS
if _RECOMMENDED_MODELS is None:
response = get_session().get(f"{ENDPOINT}/api/tasks", headers=build_hf_headers())
hf_raise_for_status(response)
_RECOMMENDED_MODELS = {
task: _first_or_none(details["widgetModels"]) for task, details in response.json().items()
}
return _RECOMMENDED_MODELS


def get_recommended_model(task: str) -> str:
"""
Get the model Hugging Face recommends for the input task.

Args:
task (`str`):
The Hugging Face task for which to get the recommended model.
All available tasks can be found [here](https://huggingface.co/tasks).

Returns:
`str`: Name of the model recommended for the input task.

Raises:
`ValueError`: If Hugging Face has no recommendation for the input task.
"""
model = _fetch_recommended_models().get(task)
if model is None:
raise ValueError(
f"Task {task} has no recommended model. Please specify a model"
" explicitly. Visit https://huggingface.co/tasks for more info."
)
return model


class HFInferenceTask:
"""Base class for HF Inference API tasks."""

def __init__(self, task: str):
self.task = task

def build_url(self, model: Optional[str] = None) -> str:
if model is None:
model = get_recommended_model(self.task)
return (
# Feature-extraction and sentence-similarity are the only cases where we handle models with several tasks.
f"{BASE_URL}/pipeline/{self.task}/{model}"
if self.task in ("feature-extraction", "sentence-similarity")
# Otherwise, we use the default endpoint
else f"{BASE_URL}/models/{model}"
)

def map_model(self, model: Optional[str] = None) -> str:
return model if model is not None else get_recommended_model(self.task)

def prepare_headers(self, headers: Dict, *, token: Optional[str] = None) -> Dict:
return headers

def prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]:
if isinstance(inputs, (bytes, Path)):
raise ValueError(f"Unexpected binary inputs. Got {inputs}") # type: ignore
_ = parameters.pop("model")  # model is not a valid parameter for hf-inference tasks
return {
"json": {
"inputs": inputs,
"parameters": {k: v for k, v in parameters.items() if v is not None},
}
}

def get_response(self, response: Union[bytes, Dict]) -> Any:
return response


class HFInferenceBinaryInputTask(HFInferenceTask):
def prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]:
parameters = {k: v for k, v in parameters.items() if v is not None}
_ = parameters.pop("model") # model is not a valid parameter for hf-inference tasks
has_parameters = len(parameters) > 0

# Raise if not a binary object or a local path or a URL.
if not isinstance(inputs, (bytes, Path, str)):
raise ValueError(f"Expected binary inputs or a local path or a URL. Got {inputs}")

# Send inputs as raw content when no parameters are provided
if not has_parameters:
with _open_as_binary(inputs) as data:
data_as_bytes = data if isinstance(data, bytes) else data.read()
return {"data": data_as_bytes}

# Otherwise encode as b64
return {"json": {"inputs": _b64_encode(inputs), "parameters": parameters}}


class HFInferenceConversational(HFInferenceTask):
def __init__(self):
super().__init__("conversational")

def build_url(self, model: Optional[str] = None) -> str:
if model is None:
model = get_recommended_model("text-generation")
return f"{BASE_URL}/models/{model}/v1/chat/completions"

def prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]:
parameters = {key: value for key, value in parameters.items() if value is not None}
model = parameters.get("model")
return {"model": model, "messages": inputs, **parameters}
2 changes: 2 additions & 0 deletions src/huggingface_hub/inference/_providers/replicate/__init__.py
@@ -0,0 +1,2 @@
# ruff: noqa: F401
from . import text_to_image