Merged
Changes from 4 commits
Commits (30)
5c1a209
Add first version of third-party providers support
hanouticelina Jan 17, 2025
c142995
add task level in model id mappings
hanouticelina Jan 17, 2025
051ee76
raise error when task is not supported by a provider + some improvements
hanouticelina Jan 17, 2025
ab2b44f
small (big) refactoring
hanouticelina Jan 17, 2025
a896f3e
multiple fixes
hanouticelina Jan 20, 2025
49d389a
add hf inference tasks
hanouticelina Jan 20, 2025
20e8d3a
Handle hf_inference in single file (#2766)
Wauplin Jan 20, 2025
ceca530
harmonize prepare_payload args and add automatic-speech-recognition task
hanouticelina Jan 21, 2025
a076eb4
backward compatibility with custom urls
hanouticelina Jan 21, 2025
aca6050
first draft of tests
hanouticelina Jan 22, 2025
c2bdcc2
InferenceClient as fixture + skip if no api_key
Wauplin Jan 22, 2025
4489069
give name to parametrized tests
Wauplin Jan 22, 2025
5f9d946
upload cassettes
hanouticelina Jan 22, 2025
e1a379c
make quali
Wauplin Jan 22, 2025
fec77a6
download sample files from prod
Wauplin Jan 22, 2025
a731eec
fix python3.8
Wauplin Jan 22, 2025
9b209b8
small improvement for better readability
hanouticelina Jan 22, 2025
28825cb
make style
Wauplin Jan 22, 2025
a0208c9
fixing more tests
hanouticelina Jan 22, 2025
8f2eb6c
Merge branch 'inference-providers-compatibility' of github.com:huggin…
hanouticelina Jan 22, 2025
456122f
test url building
hanouticelina Jan 22, 2025
d5dcf8f
fix and record async client tests
hanouticelina Jan 22, 2025
4ba4ab4
re-add cassettes
hanouticelina Jan 22, 2025
65b659d
fix
hanouticelina Jan 22, 2025
ae6f2af
add cassettes back
hanouticelina Jan 22, 2025
4d50893
fix test
hanouticelina Jan 22, 2025
9d557f0
hopefully this will fix the test
hanouticelina Jan 22, 2025
92f62fc
fix sentence similarity test
hanouticelina Jan 22, 2025
8ebbe80
Merge branch 'main' into inference-providers-compatibility
hanouticelina Jan 22, 2025
2223998
Merge branch 'main' into inference-providers-compatibility
Wauplin Jan 23, 2025
149 changes: 63 additions & 86 deletions src/huggingface_hub/inference/_client.py
@@ -104,6 +104,7 @@
ZeroShotClassificationOutputElement,
ZeroShotImageClassificationOutputElement,
)
from huggingface_hub.inference._providers import get_provider_helper
from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
from huggingface_hub.utils._deprecation import _deprecate_arguments

@@ -123,7 +124,7 @@ class InferenceClient:
Initialize a new Inference Client.

[`InferenceClient`] aims to provide a unified experience to perform inference. The client can be used
seamlessly with either the (free) Inference API or self-hosted Inference Endpoints.
seamlessly with the (free) Inference API, self-hosted Inference Endpoints, or third-party Inference Providers.

Args:
model (`str`, `optional`):
Expand All @@ -134,6 +135,9 @@ class InferenceClient:
arguments are mutually exclusive. If using `base_url` for chat completion, the `/chat/completions` suffix
path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
provider (`str`, *optional*):
Name of the provider to use for inference. Can be `"replicate"`, `"together"`, `"fal-ai"`, `"sambanova"` or `"hf-inference"`.
Defaults to `"hf-inference"` (Hugging Face Serverless Inference API).
token (`str` or `bool`, *optional*):
Hugging Face token. Will default to the locally saved token if not provided.
Pass `token=False` if you don't want to send your token to the server.
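
For orientation, a minimal usage sketch of the new `provider` argument; the model ID and token below are illustrative placeholders, not part of this diff:

from huggingface_hub import InferenceClient

# Default: requests are routed to the Hugging Face Serverless Inference API ("hf-inference").
client = InferenceClient(model="meta-llama/Llama-3.1-8B-Instruct")

# With a third-party provider: the Hub model ID is mapped to the provider's own
# identifier and the request is sent to that provider's endpoint instead.
client = InferenceClient(provider="together", model="meta-llama/Llama-3.1-8B-Instruct", token="hf_xxx")
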
@@ -161,6 +165,7 @@ def __init__(
self,
model: Optional[str] = None,
*,
provider: Optional[str] = None,
token: Union[str, bool, None] = None,
timeout: Optional[float] = None,
headers: Optional[Dict[str, str]] = None,
@@ -191,6 +196,12 @@ def __init__(
)
if headers is not None:
self.headers.update(headers)

# Configure provider
if provider is None:
provider = "hf-inference"
self.provider = provider

self.cookies = cookies
self.timeout = timeout
self.proxies = proxies
@@ -273,7 +284,12 @@ def post(
`HTTPError`:
If the request fails with an HTTP error status code other than HTTP 503.
"""
url = self._resolve_url(model, task)
# TODO: either we do that or we pass the url as a parameter
if model is not None and (model.startswith("http://") or model.startswith("https://")):
url = model
else:
provider_helper = get_provider_helper(self.provider, task=task)
url = provider_helper.build_url(model=model)

if data is not None and json is not None:
warnings.warn("Ignoring `json` as `data` is passed as binary.")
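
For readers of this hunk, a rough sketch of the helper object that `get_provider_helper` is expected to return, inferred only from the calls visible in this diff; the class name and method bodies below are assumptions, not the actual `_providers` module:

from typing import Dict, Optional

class ProviderHelperSketch:
    # Hypothetical shape of a provider helper, reconstructed from the calls in this diff.
    # The real interface also exposes prepare_payload() and get_response(), used further below.
    def __init__(self, provider: str, task: str, base_url: str):
        self.provider = provider
        self.task = task
        self.base_url = base_url

    def map_model(self, model: str) -> str:
        # Translate a Hub model ID into the provider's own model identifier.
        return model

    def build_url(self, model: Optional[str] = None) -> str:
        # Route the request to the provider-specific endpoint for this model/task.
        return f"{self.base_url}/{model}"

    def prepare_headers(self, headers: Dict[str, str], token: Optional[str] = None) -> Dict[str, str]:
        # Attach provider-specific authentication on top of the client headers.
        return {**headers, "authorization": f"Bearer {token}"} if token else headers
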
@@ -854,68 +870,49 @@ def chat_completion(
'{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
```
"""
model_url = self._resolve_chat_completion_url(model)
# Resolve model ID with precedence: method argument > instance model > default
model_id = model or self.model or "tgi"
# Get the provider helper
provider_helper = get_provider_helper(self.provider, task="conversational")
# ?: should we update the headers here?
self.headers = provider_helper.prepare_headers(headers=self.headers, **{"token": self.token})
# Get the mapped provider model ID
model_id = provider_helper.map_model(model=model_id)
# Build the URL for the provider
model_url = provider_helper.build_url(model=model_id)

# `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
# If it's an ID on the Hub => use it. Otherwise, we use a random string.
model_id = model or self.model or "tgi"
if model_id.startswith(("http://", "https://")):
model_id = "tgi" # dummy value

payload = dict(
model=model_id,
messages=messages,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
temperature=temperature,
tool_choice=tool_choice,
tool_prompt=tool_prompt,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
stream=stream,
stream_options=stream_options,
)
payload = {key: value for key, value in payload.items() if value is not None}
# For URLs, use "tgi" as model name in payload
payload_model = "tgi" if model_id.startswith(("http://", "https://")) else model_id
parameters = {
"frequency_penalty": frequency_penalty,
"logit_bias": logit_bias,
"logprobs": logprobs,
"max_tokens": max_tokens,
"n": n,
"presence_penalty": presence_penalty,
"response_format": response_format,
"seed": seed,
"stop": stop,
"temperature": temperature,
"tool_choice": tool_choice,
"tool_prompt": tool_prompt,
"tools": tools,
"top_logprobs": top_logprobs,
"top_p": top_p,
"stream": stream,
"stream_options": stream_options,
}
# Prepare the payload
payload = provider_helper.prepare_payload(inputs=messages, parameters=parameters, model=payload_model)
data = self.post(model=model_url, json=payload, stream=stream)

if stream:
return _stream_chat_completion_response(data) # type: ignore[arg-type]

return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type]
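
As an illustration of the new chat-completion flow above, a hedged usage sketch; the provider, model ID, and token are placeholders, and the exact set of supported providers is defined elsewhere in this PR:

client = InferenceClient(provider="sambanova", token="hf_xxx")
response = client.chat_completion(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    max_tokens=64,
)
print(response.choices[0].message.content)
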

def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
# Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
# `self.base_url` and `self.model` takes precedence over 'model' argument only in `chat_completion`.
model_id_or_url = self.base_url or self.model or model or self.get_recommended_model("text-generation")

# Resolve URL if it's a model ID
model_url = (
model_id_or_url
if model_id_or_url.startswith(("http://", "https://"))
else self._resolve_url(model_id_or_url, task="text-generation")
)

# Strip trailing /
model_url = model_url.rstrip("/")

# Append /chat/completions if not already present
if model_url.endswith("/v1"):
model_url += "/chat/completions"

# Append /v1/chat/completions if not already present
if not model_url.endswith("/chat/completions"):
model_url += "/v1/chat/completions"

return model_url

def document_question_answering(
self,
image: ContentT,
@@ -2426,7 +2423,6 @@ def text_to_image(
>>> image.save("better_astronaut.png")
```
"""

parameters = {
"negative_prompt": negative_prompt,
"height": height,
@@ -2438,8 +2434,19 @@
"seed": seed,
**kwargs,
}
payload = _prepare_payload(prompt, parameters=parameters)

model = model or self.model
provider_helper = get_provider_helper(self.provider, task="text-to-image")
self.headers = provider_helper.prepare_headers(headers=self.headers, **{"token": self.token})
model = provider_helper.map_model(model=model)
payload = provider_helper.prepare_payload(
prompt,
parameters=parameters,
model=model,
)

response = self.post(**payload, model=model, task="text-to-image")
response = provider_helper.get_response(response)
return _bytes_to_image(response)
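
Similarly, a hedged usage sketch of the reworked `text_to_image` path; the provider, model ID, and token are illustrative placeholders:

client = InferenceClient(provider="fal-ai", token="hf_xxx")
image = client.text_to_image(
    "An astronaut riding a horse on the moon.",
    model="black-forest-labs/FLUX.1-dev",
)
image.save("astronaut.png")  # still returns a PIL.Image object
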

def text_to_speech(
@@ -2966,36 +2973,6 @@ def zero_shot_image_classification(
)
return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)

def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
model = model or self.model or self.base_url

# If model is already a URL, ignore `task` and return directly
if model is not None and (model.startswith("http://") or model.startswith("https://")):
return model

# If no model but task is set => fetch the recommended one for this task
if model is None:
if task is None:
raise ValueError(
"You must specify at least a model (repo_id or URL) or a task, either when instantiating"
" `InferenceClient` or when making a request."
)
model = self.get_recommended_model(task)
logger.info(
f"Using recommended model {model} for task {task}. Note that it is"
f" encouraged to explicitly set `model='{model}'` as the recommended"
" models list might get updated without prior notice."
)

# Compute InferenceAPI url
return (
# Feature-extraction and sentence-similarity are the only cases where we handle models with several tasks.
f"{INFERENCE_ENDPOINT}/pipeline/{task}/{model}"
if task in ("feature-extraction", "sentence-similarity")
# Otherwise, we use the default endpoint
else f"{INFERENCE_ENDPOINT}/models/{model}"
)

@staticmethod
def get_recommended_model(task: str) -> str:
"""