
Commit 2f02968

Add Cohere provider (#1202)
### What

Adds Cohere as an inference provider.

### Test Plan

Added new tests for Cohere, both with and without streaming.

### What Should Reviewers Focus On?

Is the implementation correct? Anything important that I missed? Also happy to get feedback on the code, I am a bit rusty with my JS!

---------

Co-authored-by: SBrandeis <[email protected]>
1 parent 3857938 commit 2f02968

File tree

10 files changed: +158 −6 lines changed

.github/workflows/test.yml

Lines changed: 3 additions & 0 deletions

@@ -42,6 +42,7 @@ jobs:
       env:
         HF_TOKEN: ${{ secrets.HF_TOKEN }}
         HF_BLACK_FOREST_LABS_KEY: dummy
+        HF_COHERE_KEY: dummy
         HF_FAL_KEY: dummy
         HF_FIREWORKS_KEY: dummy
         HF_HYPERBOLIC_KEY: dummy
@@ -87,6 +88,7 @@ jobs:
       env:
         HF_TOKEN: ${{ secrets.HF_TOKEN }}
         HF_BLACK_FOREST_LABS_KEY: dummy
+        HF_COHERE_KEY: dummy
         HF_FAL_KEY: dummy
         HF_FIREWORKS_KEY: dummy
         HF_HYPERBOLIC_KEY: dummy
@@ -159,6 +161,7 @@ jobs:
         NPM_CONFIG_REGISTRY: http://localhost:4874/
         HF_TOKEN: ${{ secrets.HF_TOKEN }}
         HF_BLACK_FOREST_LABS_KEY: dummy
+        HF_COHERE_KEY: dummy
         HF_FAL_KEY: dummy
         HF_FIREWORKS_KEY: dummy
         HF_HYPERBOLIC_KEY: dummy

README.md

Lines changed: 6 additions & 6 deletions

@@ -23,7 +23,7 @@ await uploadFile({
   // Can work with native File in browsers
   file: {
     path: "pytorch_model.bin",
-    content: new Blob(...)
+    content: new Blob(...)
   }
 });

@@ -39,7 +39,7 @@ await inference.chatCompletion({
   ],
   max_tokens: 512,
   temperature: 0.5,
-  provider: "sambanova", // or together, fal-ai, replicate, …
+  provider: "sambanova", // or together, fal-ai, replicate, cohere
 });

 await inference.textToImage({
@@ -146,12 +146,12 @@ for await (const chunk of inference.chatCompletionStream({
   console.log(chunk.choices[0].delta.content);
 }

-/// Using a third-party provider:
+/// Using a third-party provider:
 await inference.chatCompletion({
   model: "meta-llama/Llama-3.1-8B-Instruct",
   messages: [{ role: "user", content: "Hello, nice to meet you!" }],
   max_tokens: 512,
-  provider: "sambanova", // or together, fal-ai, replicate, …
+  provider: "sambanova", // or together, fal-ai, replicate, cohere
 })

 await inference.textToImage({
@@ -211,7 +211,7 @@ await uploadFile({
   // Can work with native File in browsers
   file: {
     path: "pytorch_model.bin",
-    content: new Blob(...)
+    content: new Blob(...)
   }
 });

@@ -244,7 +244,7 @@ console.log(messages); // contains the data

 // or you can run the code directly, however you can't check that the code is safe to execute this way, use at your own risk.
 const messages = await agent.run("Draw a picture of a cat wearing a top hat. Then caption the picture and read it out loud.")
-console.log(messages);
+console.log(messages);
 ```

 There are more features of course, check each library's README!

packages/inference/README.md

Lines changed: 2 additions & 0 deletions

@@ -56,6 +56,7 @@ Currently, we support the following providers:
 - [Sambanova](https://sambanova.ai)
 - [Together](https://together.xyz)
 - [Blackforestlabs](https://blackforestlabs.ai)
+- [Cohere](https://cohere.com)

 To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
 ```ts
@@ -80,6 +81,7 @@ Only a subset of models are supported when requesting third-party providers. You
 - [Replicate supported models](https://huggingface.co/api/partners/replicate/models)
 - [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
 - [Together supported models](https://huggingface.co/api/partners/together/models)
+- [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
 - [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)

 **Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 2 additions & 0 deletions

@@ -1,5 +1,6 @@
 import { HF_HUB_URL, HF_ROUTER_URL } from "../config";
 import { BLACK_FOREST_LABS_CONFIG } from "../providers/black-forest-labs";
+import { COHERE_CONFIG } from "../providers/cohere";
 import { FAL_AI_CONFIG } from "../providers/fal-ai";
 import { FIREWORKS_AI_CONFIG } from "../providers/fireworks-ai";
 import { HF_INFERENCE_CONFIG } from "../providers/hf-inference";
@@ -27,6 +28,7 @@ let tasks: Record<string, { models: { id: string }[] }> | null = null;
  */
 const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
 	"black-forest-labs": BLACK_FOREST_LABS_CONFIG,
+	cohere: COHERE_CONFIG,
 	"fal-ai": FAL_AI_CONFIG,
 	"fireworks-ai": FIREWORKS_AI_CONFIG,
 	"hf-inference": HF_INFERENCE_CONFIG,
packages/inference/src/providers/cohere.ts

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+/**
+ * See the registered mapping of HF model ID => Cohere model ID here:
+ *
+ * https://huggingface.co/api/partners/cohere/models
+ *
+ * This is a publicly available mapping.
+ *
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
+ *
+ * - If you work at Cohere and want to update this mapping, please use the model mapping API we provide on huggingface.co
+ * - If you're a community member and want to add a new supported HF model to Cohere, please open an issue on the present repo
+ *   and we will tag Cohere team members.
+ *
+ * Thanks!
+ */
+import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
+
+const COHERE_API_BASE_URL = "https://api.cohere.com";
+
+const makeBody = (params: BodyParams): Record<string, unknown> => {
+	return {
+		...params.args,
+		model: params.model,
+	};
+};
+
+const makeHeaders = (params: HeaderParams): Record<string, string> => {
+	return { Authorization: `Bearer ${params.accessToken}` };
+};
+
+const makeUrl = (params: UrlParams): string => {
+	return `${params.baseUrl}/compatibility/v1/chat/completions`;
+};
+
+export const COHERE_CONFIG: ProviderConfig = {
+	baseUrl: COHERE_API_BASE_URL,
+	makeBody,
+	makeHeaders,
+	makeUrl,
+};
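The three helpers in this new file compose into one HTTP request against Cohere's OpenAI-compatible endpoint. A self-contained sketch of how they combine into fetch options, using the model ID and prompt recorded in the test tapes; the parameter shapes are simplified assumptions rather than the real `UrlParams`/`HeaderParams`/`BodyParams` types, and `hf_xxx` is a placeholder token:

```typescript
// Standalone sketch of how makeUrl / makeHeaders / makeBody from
// cohere.ts combine into fetch options. Param shapes are simplified.
const COHERE_API_BASE_URL = "https://api.cohere.com";

const makeBody = (params: { args: Record<string, unknown>; model: string }) => ({
  ...params.args,
  model: params.model,
});

const makeHeaders = (params: { accessToken: string }) => ({
  Authorization: `Bearer ${params.accessToken}`,
});

const makeUrl = (params: { baseUrl: string }) =>
  `${params.baseUrl}/compatibility/v1/chat/completions`;

// Assemble a chat-completion request the way the provider config would.
// "hf_xxx" is a dummy access token for illustration.
const url = makeUrl({ baseUrl: COHERE_API_BASE_URL });
const init = {
  method: "POST",
  headers: {
    ...makeHeaders({ accessToken: "hf_xxx" }),
    "Content-Type": "application/json",
  },
  body: JSON.stringify(
    makeBody({
      args: { messages: [{ role: "user", content: "Say 'this is a test'" }] },
      model: "command-r7b-12-2024",
    })
  ),
};

console.log(url); // https://api.cohere.com/compatibility/v1/chat/completions
```

Note that `makeBody` spreads the caller's args first and sets `model` last, so the provider-side model ID always wins over anything in `args`.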

packages/inference/src/providers/consts.ts

Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
 	 * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
 	 */
 	"black-forest-labs": {},
+	cohere: {},
 	"fal-ai": {},
 	"fireworks-ai": {},
 	"hf-inference": {},
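`HARDCODED_MODEL_ID_MAPPING` is the dev-time escape hatch mentioned in the cohere.ts header comment: it maps HF model IDs to provider-side IDs before a model is registered on huggingface.co. A sketch of the lookup, where the cohere entries mirror the ones the test suite registers; the `toProviderModelId` helper is illustrative, not the library's actual API:

```typescript
// Dev-time model ID mapping sketch: HF model ID -> provider model ID.
type ModelId = string;
type ProviderId = string;

const HARDCODED_MODEL_ID_MAPPING: Record<string, Record<ModelId, ProviderId>> = {
  cohere: {
    "CohereForAI/c4ai-command-r7b-12-2024": "command-r7b-12-2024",
    "CohereForAI/aya-expanse-8b": "c4ai-aya-expanse-8b",
  },
};

// Hypothetical lookup helper: fall back to the HF ID when no
// override is registered for this provider.
function toProviderModelId(provider: string, hfModelId: ModelId): ProviderId {
  return HARDCODED_MODEL_ID_MAPPING[provider]?.[hfModelId] ?? hfModelId;
}

console.log(toProviderModelId("cohere", "CohereForAI/aya-expanse-8b"));
// c4ai-aya-expanse-8b
```

In production the same mapping comes from the public endpoint https://huggingface.co/api/partners/cohere/models; the hardcoded table only matters for local development.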

packages/inference/src/types.ts

Lines changed: 1 addition & 0 deletions

@@ -30,6 +30,7 @@ export type InferenceTask = Exclude<PipelineType, "other">;

 export const INFERENCE_PROVIDERS = [
 	"black-forest-labs",
+	"cohere",
 	"fal-ai",
 	"fireworks-ai",
 	"hf-inference",

packages/inference/test/HfInference.spec.ts

Lines changed: 47 additions & 0 deletions

@@ -1350,4 +1350,51 @@ describe.concurrent("HfInference", () => {
 		},
 		TIMEOUT
 	);
+	describe.concurrent(
+		"Cohere",
+		() => {
+			const client = new HfInference(env.HF_COHERE_KEY);
+
+			HARDCODED_MODEL_ID_MAPPING["cohere"] = {
+				"CohereForAI/c4ai-command-r7b-12-2024": "command-r7b-12-2024",
+				"CohereForAI/aya-expanse-8b": "c4ai-aya-expanse-8b",
+			};
+
+			it("chatCompletion", async () => {
+				const res = await client.chatCompletion({
+					model: "CohereForAI/c4ai-command-r7b-12-2024",
+					provider: "cohere",
+					messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+				});
+				if (res.choices && res.choices.length > 0) {
+					const completion = res.choices[0].message?.content;
+					expect(completion).toContain("two");
+				}
+			});
+
+			it("chatCompletion stream", async () => {
+				const stream = client.chatCompletionStream({
+					model: "CohereForAI/c4ai-command-r7b-12-2024",
+					provider: "cohere",
+					messages: [{ role: "user", content: "Say 'this is a test'" }],
+					stream: true,
+				}) as AsyncGenerator<ChatCompletionStreamOutput>;
+
+				let fullResponse = "";
+				for await (const chunk of stream) {
+					if (chunk.choices && chunk.choices.length > 0) {
+						const content = chunk.choices[0].delta?.content;
+						if (content) {
+							fullResponse += content;
+						}
+					}
+				}
+
+				// Verify we got a meaningful response
+				expect(fullResponse).toBeTruthy();
+				expect(fullResponse.length).toBeGreaterThan(0);
+			});
+		},
+		TIMEOUT
+	);
 });

packages/inference/test/tapes.json

Lines changed: 53 additions & 0 deletions

@@ -7386,5 +7386,58 @@
 			"content-type": "image/jpeg"
 		}
 	}
+	},
+	"cb34d07934bd210fd64da207415c49fc6e2870d3564164a2a5d541f713227fbf": {
+		"url": "https://api.cohere.com/compatibility/v1/chat/completions",
+		"init": {
+			"headers": {
+				"Content-Type": "application/json"
+			},
+			"method": "POST",
+			"body": "{\"messages\":[{\"role\":\"user\",\"content\":\"Say 'this is a test'\"}],\"stream\":true,\"model\":\"command-r7b-12-2024\"}"
+		},
+		"response": {
+			"body": "data: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\"\",\"role\":\"assistant\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\"This\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\" is\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\" a\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\" test\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\".\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":\"stop\",\"delta\":{}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\",\"usage\":{\"prompt_tokens\":7,\"completion_tokens\":5,\"total_tokens\":12}}\n\ndata: [DONE]\n\n",
+			"status": 200,
+			"statusText": "OK",
+			"headers": {
+				"access-control-expose-headers": "X-Debug-Trace-ID",
+				"alt-svc": "h3=\":443\"; ma=2592000,h3-29=\":443\"; ma=2592000",
+				"cache-control": "no-cache, no-store, no-transform, must-revalidate, private, max-age=0",
+				"content-type": "text/event-stream",
+				"expires": "Thu, 01 Jan 1970 00:00:00 UTC",
+				"pragma": "no-cache",
+				"server": "envoy",
+				"transfer-encoding": "chunked",
+				"vary": "Origin"
+			}
+		}
+	},
+	"8c6ffbc794573c463ed5666e3b560e5966cd975c2893c901c18adb696ba54a6a": {
+		"url": "https://api.cohere.com/compatibility/v1/chat/completions",
+		"init": {
+			"headers": {
+				"Content-Type": "application/json"
+			},
+			"method": "POST",
+			"body": "{\"messages\":[{\"role\":\"user\",\"content\":\"Complete this sentence with words, one plus one is equal \"}],\"model\":\"command-r7b-12-2024\"}"
+		},
+		"response": {
+			"body": "{\"id\":\"f8bf661b-c600-44e5-8412-df37c9dcd985\",\"choices\":[{\"index\":0,\"finish_reason\":\"stop\",\"message\":{\"role\":\"assistant\",\"content\":\"One plus one is equal to two.\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion\",\"usage\":{\"prompt_tokens\":11,\"completion_tokens\":8,\"total_tokens\":19}}",
+			"status": 200,
+			"statusText": "OK",
+			"headers": {
+				"access-control-expose-headers": "X-Debug-Trace-ID",
+				"alt-svc": "h3=\":443\"; ma=2592000,h3-29=\":443\"; ma=2592000",
+				"cache-control": "no-cache, no-store, no-transform, must-revalidate, private, max-age=0",
+				"content-type": "application/json",
+				"expires": "Thu, 01 Jan 1970 00:00:00 UTC",
+				"num_chars": "2635",
+				"num_tokens": "19",
+				"pragma": "no-cache",
+				"server": "envoy",
+				"vary": "Origin"
+			}
+		}
 	}
 }
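The streaming tape above records a standard server-sent-events payload: `data: {json}` frames separated by blank lines and terminated by `data: [DONE]`. A minimal parser that reassembles the assistant text from such a payload; the sample frames below are abbreviated from the recorded response, keeping only the `choices[0].delta` fields the parser reads:

```typescript
// Minimal SSE reassembly for OpenAI-style chat.completion.chunk streams.
interface StreamChunk {
  choices: { delta: { content?: string } }[];
}

function collectStreamedText(ssePayload: string): string {
  let text = "";
  // Frames are separated by blank lines ("\n\n").
  for (const frame of ssePayload.split("\n\n")) {
    const line = frame.trim();
    if (!line.startsWith("data:")) continue;
    const data = line.slice("data:".length).trim();
    if (data === "[DONE]") break; // end-of-stream sentinel
    const chunk = JSON.parse(data) as StreamChunk;
    const content = chunk.choices[0]?.delta?.content;
    if (content) text += content;
  }
  return text;
}

// Abbreviated frames from the recorded tape:
const payload =
  'data: {"choices":[{"delta":{"content":"This"}}]}\n\n' +
  'data: {"choices":[{"delta":{"content":" is"}}]}\n\n' +
  'data: {"choices":[{"delta":{"content":" a"}}]}\n\n' +
  'data: {"choices":[{"delta":{"content":" test."}}]}\n\n' +
  "data: [DONE]\n\n";

console.log(collectStreamedText(payload)); // This is a test.
```

This mirrors what the `chatCompletion stream` test does with the async generator: accumulate each `delta.content` until the stream ends.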

packages/tasks/src/inference-providers.ts

Lines changed: 1 addition & 0 deletions

@@ -1,6 +1,7 @@
 /// This list is for illustration purposes only.
 /// in the `tasks` sub-package, we do not need actual strong typing of the inference providers.
 const INFERENCE_PROVIDERS = [
+	"cohere",
 	"fal-ai",
 	"fireworks-ai",
 	"hf-inference",

0 commit comments
