Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 34 additions & 34 deletions docs/source/en/guides/inference.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/huggingface_hub/inference/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class InferenceClient:
path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
provider (`str`, *optional*):
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, "sambanova"` or `"together"`.
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, "sambanova"` or `"together"`.
defaults to hf-inference (Hugging Face Serverless Inference API).
If model is a URL or `base_url` is passed, then `provider` is not used.
token (`str`, *optional*):
Expand Down
2 changes: 1 addition & 1 deletion src/huggingface_hub/inference/_generated/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class AsyncInferenceClient:
path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
provider (`str`, *optional*):
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, "sambanova"` or `"together"`.
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, "sambanova"` or `"together"`.
defaults to hf-inference (Hugging Face Serverless Inference API).
If model is a URL or `base_url` is passed, then `provider` is not used.
token (`str`, *optional*):
Expand Down
5 changes: 5 additions & 0 deletions src/huggingface_hub/inference/_providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from ._common import TaskProviderHelper
from .black_forest_labs import BlackForestLabsTextToImageTask
from .cohere import CohereConversationalTask
from .fal_ai import (
FalAIAutomaticSpeechRecognitionTask,
FalAITextToImageTask,
Expand All @@ -20,6 +21,7 @@

PROVIDER_T = Literal[
"black-forest-labs",
"cohere",
"fal-ai",
"fireworks-ai",
"hf-inference",
Expand All @@ -35,6 +37,9 @@
"black-forest-labs": {
"text-to-image": BlackForestLabsTextToImageTask(),
},
"cohere": {
"conversational": CohereConversationalTask(),
},
"fal-ai": {
"automatic-speech-recognition": FalAIAutomaticSpeechRecognitionTask(),
"text-to-image": FalAITextToImageTask(),
Expand Down
1 change: 1 addition & 0 deletions src/huggingface_hub/inference/_providers/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#
# Example:
# "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
"cohere": {},
"fal-ai": {},
"fireworks-ai": {},
"hf-inference": {},
Expand Down
15 changes: 15 additions & 0 deletions src/huggingface_hub/inference/_providers/cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from huggingface_hub.inference._providers._common import (
BaseConversationalTask,
)


_PROVIDER = "cohere"
_BASE_URL = "https://api.cohere.com"


class CohereConversationalTask(BaseConversationalTask):
def __init__(self):
super().__init__(provider=_PROVIDER, base_url=_BASE_URL)

def _prepare_route(self, mapped_model: str) -> str:
return "/compatibility/v1/chat/completions"
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is deep learning?"}], "model": "command-r7b-12-2024",
"stream": false}'
headers:
Accept:
- '*/*'
Accept-Encoding:
- gzip, deflate
Connection:
- keep-alive
Content-Length:
- '181'
Content-Type:
- application/json
X-Amzn-Trace-Id:
- 204391c6-92c8-4214-a394-04b025f3e86a
method: POST
uri: https://api.cohere.com/compatibility/v1/chat/completions
response:
body:
string: '{"id":"3b5751bb-10a2-4fc8-95a0-d1e6cfa788b3","choices":[{"index":0,"finish_reason":"stop","message":{"role":"assistant","content":"Deep
learning is a subfield of machine learning and artificial intelligence that
focuses on training artificial neural networks to learn and make predictions
from data. It is inspired by the structure and function of the human brain,
particularly the interconnected network of neurons.\n\nIn deep learning, artificial
neural networks are composed of multiple layers of interconnected nodes, or
\"neurons,\" which process and transform input data. These networks are designed
to automatically learn and extract hierarchical representations of data through
a process called \"training.\" The training process involves adjusting the
network''s internal parameters (weights and biases) to minimize the difference
between predicted and actual outputs.\n\nHere are some key characteristics
and concepts in deep learning:\n\n1. Neural Networks: Deep learning models
are primarily based on artificial neural networks, which are composed of layers
of nodes. These networks can have various architectures, such as convolutional
neural networks (CNNs) for image processing, recurrent neural networks (RNNs)
for sequential data, and transformer networks for natural language processing.\n\n2.
Deep Architecture: The term \"deep\" in deep learning refers to the depth
of the neural network, meaning it has multiple hidden layers between the input
and output layers. These hidden layers enable the network to learn complex
patterns and representations from the data.\n\n3. Learning and Training: Deep
learning models are trained using large amounts of labeled data and a process
called backpropagation. During training, the network adjusts its internal
parameters to minimize a loss function, which measures the difference between
predicted and actual outputs. This optimization process is typically done
using gradient descent or its variants.\n\n4. Feature Learning: One of the
key advantages of deep learning is its ability to automatically learn relevant
features from raw data. Unlike traditional machine learning, where feature
engineering is required, deep learning models can discover and extract features
at multiple levels of abstraction.\n\n5. Applications: Deep learning has been
applied to a wide range of tasks and domains, including image and speech recognition,
natural language processing, object detection, medical diagnosis, game playing
(e.g., AlphaGo), and autonomous driving.\n\nDeep learning has revolutionized
many areas of artificial intelligence due to its ability to handle complex
and large-scale data, learn hierarchical representations, and achieve state-of-the-art
performance in various tasks. It has driven significant advancements in areas
like computer vision, natural language understanding, and speech recognition."}}],"created":1740653732,"model":"command-r7b-12-2024","object":"chat.completion","usage":{"prompt_tokens":11,"completion_tokens":476,"total_tokens":487}}'
headers:
Alt-Svc:
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
Transfer-Encoding:
- chunked
Via:
- 1.1 google
access-control-expose-headers:
- X-Debug-Trace-ID
cache-control:
- no-cache, no-store, no-transform, must-revalidate, private, max-age=0
content-type:
- application/json
date:
- Thu, 27 Feb 2025 10:55:32 GMT
expires:
- Thu, 01 Jan 1970 00:00:00 UTC
num_chars:
- '2831'
num_tokens:
- '487'
pragma:
- no-cache
server:
- envoy
vary:
- Origin
x-accel-expires:
- '0'
x-api-warning:
- Please set an API version, for more information please refer to https://docs.cohere.com/versioning-reference
- Version is deprecated, for more information please refer to https://docs.cohere.com/versioning-reference
x-debug-trace-id:
- 430c1e5519b95b094771bcc36304445e
x-envoy-upstream-service-time:
- '2740'
x-trial-endpoint-call-limit:
- '100'
x-trial-endpoint-call-remaining:
- '99'
status:
code: 200
message: OK
version: 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is deep learning?"}], "model": "command-r7b-12-2024",
"max_tokens": 20, "stream": true}'
headers:
Accept:
- '*/*'
Accept-Encoding:
- gzip, deflate
Connection:
- keep-alive
Content-Length:
- '198'
Content-Type:
- application/json
X-Amzn-Trace-Id:
- 68c492d9-abbd-4d0a-8462-e598765021e4
method: POST
uri: https://api.cohere.com/compatibility/v1/chat/completions
response:
body:
string: 'data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"","role":"assistant"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}


data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"Deep"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}


data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
learning"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}


data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
is"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}


data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
a"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}


data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
sub"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}


data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"field"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}


data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
of"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}


data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
machine"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}


data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
learning"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}


data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
and"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}


data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
artificial"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}


data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
intelligence"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}


data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
that"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}


data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
focuses"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}


data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
on"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}


data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
training"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}


data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
artificial"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}


data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
neural"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}


data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
networks"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}


data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":"length","delta":{}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk","usage":{"prompt_tokens":11,"completion_tokens":19,"total_tokens":30}}


data: [DONE]


'
headers:
Alt-Svc:
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
Transfer-Encoding:
- chunked
Via:
- 1.1 google
access-control-expose-headers:
- X-Debug-Trace-ID
cache-control:
- no-cache, no-store, no-transform, must-revalidate, private, max-age=0
content-type:
- text/event-stream
date:
- Thu, 27 Feb 2025 10:55:33 GMT
expires:
- Thu, 01 Jan 1970 00:00:00 UTC
pragma:
- no-cache
server:
- envoy
vary:
- Origin
x-accel-expires:
- '0'
x-api-warning:
- Please set an API version, for more information please refer to https://docs.cohere.com/versioning-reference
- Version is deprecated, for more information please refer to https://docs.cohere.com/versioning-reference
x-debug-trace-id:
- 4bc0ce4bda5305b5b60ef6268db5e3a7
x-envoy-upstream-service-time:
- '88'
x-trial-endpoint-call-limit:
- '100'
x-trial-endpoint-call-remaining:
- '98'
status:
code: 200
message: OK
version: 1
3 changes: 3 additions & 0 deletions tests/test_inference_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@
"black-forest-labs": {
"text-to-image": "black-forest-labs/FLUX.1-dev",
},
"cohere": {
"conversational": "CohereForAI/c4ai-command-r7b-12-2024",
},
"together": {
"conversational": "meta-llama/Meta-Llama-3-8B-Instruct",
"text-generation": "meta-llama/Llama-2-70b-hf",
Expand Down
19 changes: 19 additions & 0 deletions tests/test_inference_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
recursive_merge,
)
from huggingface_hub.inference._providers.black_forest_labs import BlackForestLabsTextToImageTask
from huggingface_hub.inference._providers.cohere import CohereConversationalTask
from huggingface_hub.inference._providers.fal_ai import (
FalAIAutomaticSpeechRecognitionTask,
FalAITextToImageTask,
Expand Down Expand Up @@ -110,6 +111,24 @@ def test_get_response_success(self, mocker):
)


class TestCohereConversationalTask:
def test_prepare_url(self):
helper = CohereConversationalTask()
assert helper.task == "conversational"
url = helper._prepare_url("cohere_token", "username/repo_name")
assert url == "https://api.cohere.com/compatibility/v1/chat/completions"

def test_prepare_payload_as_dict(self):
helper = CohereConversationalTask()
payload = helper._prepare_payload_as_dict(
[{"role": "user", "content": "Hello!"}], {}, "CohereForAI/command-r7b-12-2024"
)
assert payload == {
"messages": [{"role": "user", "content": "Hello!"}],
"model": "CohereForAI/command-r7b-12-2024",
}


class TestFalAIProvider:
def test_prepare_headers_fal_ai_key(self):
"""When using direct call, must use Key authorization."""
Expand Down