
Commit ad7950b

alexrs-cohere, Wauplin, and hanouticelina committed
Add Cohere as an Inference Provider (#2888)
* Add Cohere as an Inference Provider
* Use new Cohere OpenAI compatible API
* Update docs/source/en/guides/inference.md
* Update src/huggingface_hub/inference/_providers/cohere.py

Co-authored-by: Célina <[email protected]>
Co-authored-by: Lucain <[email protected]>
1 parent 9323080 commit ad7950b

File tree: 10 files changed, +327 -36 lines changed

docs/source/en/guides/inference.md

Lines changed: 34 additions & 34 deletions
Large diffs are not rendered by default.
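The guide diff is not rendered above; as a quick illustration of what the new provider enables, a chat completion routed through Cohere might look like the following sketch (the API key placeholder is an assumption, and the model ID comes from the tests added in this commit):

    from huggingface_hub import InferenceClient

    # Sketch: send a chat completion through Cohere's OpenAI-compatible API.
    client = InferenceClient(provider="cohere", api_key="hf_***")  # placeholder token

    response = client.chat_completion(
        model="CohereForAI/c4ai-command-r7b-12-2024",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is deep learning?"},
        ],
    )
    print(response.choices[0].message.content)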

src/huggingface_hub/inference/_client.py

Lines changed: 1 addition & 1 deletion

@@ -132,7 +132,7 @@ class InferenceClient:
             path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
             documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
         provider (`str`, *optional*):
-            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, `"sambanova"` or `"together"`.
+            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, `"sambanova"` or `"together"`.
             defaults to hf-inference (Hugging Face Serverless Inference API).
             If model is a URL or `base_url` is passed, then `provider` is not used.
         token (`str` or `bool`, *optional*):
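As the docstring notes, `provider` is ignored whenever `model` is a URL or `base_url` is set; a minimal sketch of that case (the endpoint URL is a placeholder for a self-hosted deployment):

    from huggingface_hub import InferenceClient

    # provider is not used here: requests go straight to the configured endpoint.
    client = InferenceClient(base_url="http://localhost:8080")  # placeholder endpoint
    out = client.chat_completion(messages=[{"role": "user", "content": "Hello!"}])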

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 1 addition & 1 deletion

@@ -120,7 +120,7 @@ class AsyncInferenceClient:
             path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
             documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
         provider (`str`, *optional*):
-            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, `"sambanova"` or `"together"`.
+            Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, `"sambanova"` or `"together"`.
             defaults to hf-inference (Hugging Face Serverless Inference API).
             If model is a URL or `base_url` is passed, then `provider` is not used.
         token (`str` or `bool`, *optional*):
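The async client gets the same provider option; a minimal sketch of streaming a Cohere-backed completion with it (the token placeholder and model ID are assumptions based on the tests and cassettes below):

    import asyncio

    from huggingface_hub import AsyncInferenceClient

    async def main() -> None:
        client = AsyncInferenceClient(provider="cohere", api_key="hf_***")  # placeholder token
        # With stream=True, chunks are yielded as the provider produces them.
        stream = await client.chat_completion(
            model="CohereForAI/c4ai-command-r7b-12-2024",
            messages=[{"role": "user", "content": "What is deep learning?"}],
            max_tokens=20,
            stream=True,
        )
        async for chunk in stream:
            print(chunk.choices[0].delta.content or "", end="")

    asyncio.run(main())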

src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -2,6 +2,7 @@
 
 from ._common import TaskProviderHelper
 from .black_forest_labs import BlackForestLabsTextToImageTask
+from .cohere import CohereConversationalTask
 from .fal_ai import (
     FalAIAutomaticSpeechRecognitionTask,
     FalAITextToImageTask,
@@ -20,6 +21,7 @@
 
 PROVIDER_T = Literal[
     "black-forest-labs",
+    "cohere",
     "fal-ai",
     "fireworks-ai",
     "hf-inference",
@@ -35,6 +37,9 @@
     "black-forest-labs": {
         "text-to-image": BlackForestLabsTextToImageTask(),
     },
+    "cohere": {
+        "conversational": CohereConversationalTask(),
+    },
     "fal-ai": {
         "automatic-speech-recognition": FalAIAutomaticSpeechRecognitionTask(),
         "text-to-image": FalAITextToImageTask(),

src/huggingface_hub/inference/_providers/_common.py

Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@
     #
     # Example:
     # "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
+    "cohere": {},
     "fal-ai": {},
     "fireworks-ai": {},
     "hf-inference": {},
src/huggingface_hub/inference/_providers/cohere.py

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+from huggingface_hub.inference._providers._common import (
+    BaseConversationalTask,
+)
+
+
+_PROVIDER = "cohere"
+_BASE_URL = "https://api.cohere.com"
+
+
+class CohereConversationalTask(BaseConversationalTask):
+    def __init__(self):
+        super().__init__(provider=_PROVIDER, base_url=_BASE_URL)
+
+    def _prepare_route(self, mapped_model: str) -> str:
+        return "/compatibility/v1/chat/completions"
New test cassette: recorded Cohere chat completion (non-streaming)

Lines changed: 102 additions & 0 deletions

@@ -0,0 +1,102 @@
+interactions:
+- request:
+    body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."},
+      {"role": "user", "content": "What is deep learning?"}], "model": "command-r7b-12-2024",
+      "stream": false}'
+    headers:
+      Accept:
+      - '*/*'
+      Accept-Encoding:
+      - gzip, deflate
+      Connection:
+      - keep-alive
+      Content-Length:
+      - '181'
+      Content-Type:
+      - application/json
+      X-Amzn-Trace-Id:
+      - 204391c6-92c8-4214-a394-04b025f3e86a
+    method: POST
+    uri: https://api.cohere.com/compatibility/v1/chat/completions
+  response:
+    body:
+      string: '{"id":"3b5751bb-10a2-4fc8-95a0-d1e6cfa788b3","choices":[{"index":0,"finish_reason":"stop","message":{"role":"assistant","content":"Deep
+        learning is a subfield of machine learning and artificial intelligence that
+        focuses on training artificial neural networks to learn and make predictions
+        from data. It is inspired by the structure and function of the human brain,
+        particularly the interconnected network of neurons.\n\nIn deep learning, artificial
+        neural networks are composed of multiple layers of interconnected nodes, or
+        \"neurons,\" which process and transform input data. These networks are designed
+        to automatically learn and extract hierarchical representations of data through
+        a process called \"training.\" The training process involves adjusting the
+        network''s internal parameters (weights and biases) to minimize the difference
+        between predicted and actual outputs.\n\nHere are some key characteristics
+        and concepts in deep learning:\n\n1. Neural Networks: Deep learning models
+        are primarily based on artificial neural networks, which are composed of layers
+        of nodes. These networks can have various architectures, such as convolutional
+        neural networks (CNNs) for image processing, recurrent neural networks (RNNs)
+        for sequential data, and transformer networks for natural language processing.\n\n2.
+        Deep Architecture: The term \"deep\" in deep learning refers to the depth
+        of the neural network, meaning it has multiple hidden layers between the input
+        and output layers. These hidden layers enable the network to learn complex
+        patterns and representations from the data.\n\n3. Learning and Training: Deep
+        learning models are trained using large amounts of labeled data and a process
+        called backpropagation. During training, the network adjusts its internal
+        parameters to minimize a loss function, which measures the difference between
+        predicted and actual outputs. This optimization process is typically done
+        using gradient descent or its variants.\n\n4. Feature Learning: One of the
+        key advantages of deep learning is its ability to automatically learn relevant
+        features from raw data. Unlike traditional machine learning, where feature
+        engineering is required, deep learning models can discover and extract features
+        at multiple levels of abstraction.\n\n5. Applications: Deep learning has been
+        applied to a wide range of tasks and domains, including image and speech recognition,
+        natural language processing, object detection, medical diagnosis, game playing
+        (e.g., AlphaGo), and autonomous driving.\n\nDeep learning has revolutionized
+        many areas of artificial intelligence due to its ability to handle complex
+        and large-scale data, learn hierarchical representations, and achieve state-of-the-art
+        performance in various tasks. It has driven significant advancements in areas
+        like computer vision, natural language understanding, and speech recognition."}}],"created":1740653732,"model":"command-r7b-12-2024","object":"chat.completion","usage":{"prompt_tokens":11,"completion_tokens":476,"total_tokens":487}}'
+    headers:
+      Alt-Svc:
+      - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
+      Transfer-Encoding:
+      - chunked
+      Via:
+      - 1.1 google
+      access-control-expose-headers:
+      - X-Debug-Trace-ID
+      cache-control:
+      - no-cache, no-store, no-transform, must-revalidate, private, max-age=0
+      content-type:
+      - application/json
+      date:
+      - Thu, 27 Feb 2025 10:55:32 GMT
+      expires:
+      - Thu, 01 Jan 1970 00:00:00 UTC
+      num_chars:
+      - '2831'
+      num_tokens:
+      - '487'
+      pragma:
+      - no-cache
+      server:
+      - envoy
+      vary:
+      - Origin
+      x-accel-expires:
+      - '0'
+      x-api-warning:
+      - Please set an API version, for more information please refer to https://docs.cohere.com/versioning-reference
+      - Version is deprecated, for more information please refer to https://docs.cohere.com/versioning-reference
+      x-debug-trace-id:
+      - 430c1e5519b95b094771bcc36304445e
+      x-envoy-upstream-service-time:
+      - '2740'
+      x-trial-endpoint-call-limit:
+      - '100'
+      x-trial-endpoint-call-remaining:
+      - '99'
+    status:
+      code: 200
+      message: OK
+version: 1
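For reference, the interaction recorded above can be reproduced against the compatibility endpoint directly; a minimal sketch with `requests` (the cassette omits credentials, so the bearer-token header below is an assumption):

    import requests

    # Replay of the non-streaming request captured in the cassette above.
    resp = requests.post(
        "https://api.cohere.com/compatibility/v1/chat/completions",
        headers={
            "Content-Type": "application/json",
            "Authorization": "Bearer <COHERE_API_KEY>",  # assumed; not part of the recording
        },
        json={
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "What is deep learning?"},
            ],
            "model": "command-r7b-12-2024",
            "stream": False,
        },
        timeout=30,
    )
    print(resp.json()["choices"][0]["message"]["content"])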
New test cassette: recorded Cohere chat completion (streaming)

Lines changed: 146 additions & 0 deletions

@@ -0,0 +1,146 @@
+interactions:
+- request:
+    body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."},
+      {"role": "user", "content": "What is deep learning?"}], "model": "command-r7b-12-2024",
+      "max_tokens": 20, "stream": true}'
+    headers:
+      Accept:
+      - '*/*'
+      Accept-Encoding:
+      - gzip, deflate
+      Connection:
+      - keep-alive
+      Content-Length:
+      - '198'
+      Content-Type:
+      - application/json
+      X-Amzn-Trace-Id:
+      - 68c492d9-abbd-4d0a-8462-e598765021e4
+    method: POST
+    uri: https://api.cohere.com/compatibility/v1/chat/completions
+  response:
+    body:
+      string: 'data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"","role":"assistant"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}
+
+
+        data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"Deep"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}
+
+
+        data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
+        learning"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}
+
+
+        data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
+        is"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}
+
+
+        data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
+        a"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}
+
+
+        data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
+        sub"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}
+
+
+        data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"field"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}
+
+
+        data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
+        of"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}
+
+
+        data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
+        machine"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}
+
+
+        data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
+        learning"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}
+
+
+        data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
+        and"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}
+
+
+        data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
+        artificial"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}
+
+
+        data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
+        intelligence"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}
+
+
+        data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
+        that"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}
+
+
+        data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
+        focuses"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}
+
+
+        data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
+        on"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}
+
+
+        data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
+        training"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}
+
+
+        data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
+        artificial"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}
+
+
+        data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
+        neural"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}
+
+
+        data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":null,"delta":{"content":"
+        networks"}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk"}
+
+
+        data: {"id":"2bb1b33e-53d9-4fae-8958-2e54c1e60f09","choices":[{"index":0,"finish_reason":"length","delta":{}}],"created":1740653733,"model":"command-r7b-12-2024","object":"chat.completion.chunk","usage":{"prompt_tokens":11,"completion_tokens":19,"total_tokens":30}}
+
+
+        data: [DONE]
+
+
+        '
+    headers:
+      Alt-Svc:
+      - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
+      Transfer-Encoding:
+      - chunked
+      Via:
+      - 1.1 google
+      access-control-expose-headers:
+      - X-Debug-Trace-ID
+      cache-control:
+      - no-cache, no-store, no-transform, must-revalidate, private, max-age=0
+      content-type:
+      - text/event-stream
+      date:
+      - Thu, 27 Feb 2025 10:55:33 GMT
+      expires:
+      - Thu, 01 Jan 1970 00:00:00 UTC
+      pragma:
+      - no-cache
+      server:
+      - envoy
+      vary:
+      - Origin
+      x-accel-expires:
+      - '0'
+      x-api-warning:
+      - Please set an API version, for more information please refer to https://docs.cohere.com/versioning-reference
+      - Version is deprecated, for more information please refer to https://docs.cohere.com/versioning-reference
+      x-debug-trace-id:
+      - 4bc0ce4bda5305b5b60ef6268db5e3a7
+      x-envoy-upstream-service-time:
+      - '88'
+      x-trial-endpoint-call-limit:
+      - '100'
+      x-trial-endpoint-call-remaining:
+      - '98'
+    status:
+      code: 200
+      message: OK
+version: 1
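The streamed body is a series of `data:` server-sent-event lines terminated by `data: [DONE]`; a minimal sketch of consuming such a stream (again assuming a bearer token, which the cassette does not record):

    import json

    import requests

    # Replay of the streaming request captured above, printing tokens as they arrive.
    with requests.post(
        "https://api.cohere.com/compatibility/v1/chat/completions",
        headers={"Authorization": "Bearer <COHERE_API_KEY>"},  # assumed; not part of the recording
        json={
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "What is deep learning?"},
            ],
            "model": "command-r7b-12-2024",
            "max_tokens": 20,
            "stream": True,
        },
        stream=True,
        timeout=30,
    ) as resp:
        for line in resp.iter_lines():
            if not line or not line.startswith(b"data: "):
                continue
            payload = line[len(b"data: "):]
            if payload == b"[DONE]":
                break
            chunk = json.loads(payload)
            print(chunk["choices"][0]["delta"].get("content") or "", end="")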

tests/test_inference_client.py

Lines changed: 3 additions & 0 deletions

@@ -63,6 +63,9 @@
     "black-forest-labs": {
         "text-to-image": "black-forest-labs/FLUX.1-dev",
     },
+    "cohere": {
+        "conversational": "CohereForAI/c4ai-command-r7b-12-2024",
+    },
     "together": {
         "conversational": "meta-llama/Meta-Llama-3-8B-Instruct",
         "text-generation": "meta-llama/Llama-2-70b-hf",

tests/test_inference_providers.py

Lines changed: 19 additions & 0 deletions

@@ -9,6 +9,7 @@
     recursive_merge,
 )
 from huggingface_hub.inference._providers.black_forest_labs import BlackForestLabsTextToImageTask
+from huggingface_hub.inference._providers.cohere import CohereConversationalTask
 from huggingface_hub.inference._providers.fal_ai import (
     FalAIAutomaticSpeechRecognitionTask,
     FalAITextToImageTask,
@@ -110,6 +111,24 @@ def test_get_response_success(self, mocker):
         )
 
 
+class TestCohereConversationalTask:
+    def test_prepare_url(self):
+        helper = CohereConversationalTask()
+        assert helper.task == "conversational"
+        url = helper._prepare_url("cohere_token", "username/repo_name")
+        assert url == "https://api.cohere.com/compatibility/v1/chat/completions"
+
+    def test_prepare_payload_as_dict(self):
+        helper = CohereConversationalTask()
+        payload = helper._prepare_payload_as_dict(
+            [{"role": "user", "content": "Hello!"}], {}, "CohereForAI/command-r7b-12-2024"
+        )
+        assert payload == {
+            "messages": [{"role": "user", "content": "Hello!"}],
+            "model": "CohereForAI/command-r7b-12-2024",
+        }
+
+
 class TestFalAIProvider:
     def test_prepare_headers_fal_ai_key(self):
         """When using direct call, must use Key authorization."""
