Skip to content

Commit 429e81f

Browse files
committed
Add Cohere provider
1 parent 5f7f423 commit 429e81f

File tree

11 files changed

+316
-7
lines changed

11 files changed

+316
-7
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ jobs:
5050
HF_REPLICATE_KEY: dummy
5151
HF_SAMBANOVA_KEY: dummy
5252
HF_TOGETHER_KEY: dummy
53+
HF_COHERE_KEY: dummy
5354

5455
browser:
5556
runs-on: ubuntu-latest

README.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ await uploadFile({
2323
// Can work with native File in browsers
2424
file: {
2525
path: "pytorch_model.bin",
26-
content: new Blob(...)
26+
content: new Blob(...)
2727
}
2828
});
2929

@@ -39,7 +39,7 @@ await inference.chatCompletion({
3939
],
4040
max_tokens: 512,
4141
temperature: 0.5,
42-
provider: "sambanova", // or together, fal-ai, replicate, …
42+
provider: "sambanova", // or together, fal-ai, replicate, cohere, …
4343
});
4444

4545
await inference.textToImage({
@@ -146,12 +146,12 @@ for await (const chunk of inference.chatCompletionStream({
146146
console.log(chunk.choices[0].delta.content);
147147
}
148148

149-
/// Using a third-party provider:
149+
/// Using a third-party provider:
150150
await inference.chatCompletion({
151151
model: "meta-llama/Llama-3.1-8B-Instruct",
152152
messages: [{ role: "user", content: "Hello, nice to meet you!" }],
153153
max_tokens: 512,
154-
provider: "sambanova", // or together, fal-ai, replicate, …
154+
provider: "sambanova", // or together, fal-ai, replicate, cohere, …
155155
})
156156

157157
await inference.textToImage({
@@ -211,7 +211,7 @@ await uploadFile({
211211
// Can work with native File in browsers
212212
file: {
213213
path: "pytorch_model.bin",
214-
content: new Blob(...)
214+
content: new Blob(...)
215215
}
216216
});
217217

@@ -244,7 +244,7 @@ console.log(messages); // contains the data
244244

245245
// Or you can run the code directly. However, you can't check that the code is safe to execute this way, so use at your own risk.
246246
const messages = await agent.run("Draw a picture of a cat wearing a top hat. Then caption the picture and read it out loud.")
247-
console.log(messages);
247+
console.log(messages);
248248
```
249249

250250
There are more features of course, check each library's README!

packages/inference/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ Currently, we support the following providers:
5656
- [Sambanova](https://sambanova.ai)
5757
- [Together](https://together.xyz)
5858
- [Blackforestlabs](https://blackforestlabs.ai)
59+
- [Cohere](https://cohere.com)
5960

6061
To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
6162
```ts
@@ -80,6 +81,7 @@ Only a subset of models are supported when requesting third-party providers. You
8081
- [Replicate supported models](https://huggingface.co/api/partners/replicate/models)
8182
- [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
8283
- [Together supported models](https://huggingface.co/api/partners/together/models)
84+
- [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
8385
- [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)
8486

8587
**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { HF_HUB_URL, HF_ROUTER_URL } from "../config";
22
import { BLACK_FOREST_LABS_CONFIG } from "../providers/black-forest-labs";
3+
import { COHERE_CONFIG } from "../providers/cohere";
34
import { FAL_AI_CONFIG } from "../providers/fal-ai";
45
import { FIREWORKS_AI_CONFIG } from "../providers/fireworks-ai";
56
import { HF_INFERENCE_CONFIG } from "../providers/hf-inference";
@@ -27,6 +28,7 @@ let tasks: Record<string, { models: { id: string }[] }> | null = null;
2728
*/
2829
const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
2930
"black-forest-labs": BLACK_FOREST_LABS_CONFIG,
31+
"cohere": COHERE_CONFIG,
3032
"fal-ai": FAL_AI_CONFIG,
3133
"fireworks-ai": FIREWORKS_AI_CONFIG,
3234
"hf-inference": HF_INFERENCE_CONFIG,
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/**
2+
* See the registered mapping of HF model ID => Cohere model ID here:
3+
*
4+
* https://huggingface.co/api/partners/cohere/models
5+
*
6+
* This is a publicly available mapping.
7+
*
8+
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
9+
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
10+
*
11+
* - If you work at Cohere and want to update this mapping, please use the model mapping API we provide on huggingface.co
12+
* - If you're a community member and want to add a new supported HF model to Cohere, please open an issue on the present repo
13+
* and we will tag Cohere team members.
14+
*
15+
* Thanks!
16+
*/
17+
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
18+
19+
const COHERE_API_BASE_URL = "https://api.cohere.com";
20+
21+
22+
/**
 * Builds the request payload sent to Cohere.
 *
 * The caller-supplied `args` are copied as-is, then `model` is set from
 * `params.model`, so any `model` key already present in `args` is overridden.
 */
const makeBody = (params: BodyParams): Record<string, unknown> => ({
	...params.args,
	model: params.model,
});
28+
29+
/**
 * Builds the HTTP headers for a Cohere request: a single bearer
 * `Authorization` header carrying `params.accessToken`.
 */
const makeHeaders = (params: HeaderParams): Record<string, string> => {
	const authorization = `Bearer ${params.accessToken}`;
	return { Authorization: authorization };
};
32+
33+
/**
 * Builds the request URL by appending the chat-completions path of the
 * compatibility API to `params.baseUrl`.
 */
const makeUrl = (params: UrlParams): string => {
	const { baseUrl } = params;
	return `${baseUrl}/compatibility/v1/chat/completions`;
};
36+
37+
/**
 * Provider configuration for Cohere: the API base URL plus the three
 * request builders (body, headers, URL) defined in this module.
 */
export const COHERE_CONFIG: ProviderConfig = {
	baseUrl: COHERE_API_BASE_URL,
	makeBody: makeBody,
	makeHeaders: makeHeaders,
	makeUrl: makeUrl,
};

packages/inference/src/providers/consts.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
1717
* "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
1818
*/
1919
"black-forest-labs": {},
20+
cohere: {},
2021
"fal-ai": {},
2122
"fireworks-ai": {},
2223
"hf-inference": {},

packages/inference/src/tasks/nlp/chatCompletion.ts

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,109 @@ import type { BaseArgs, Options } from "../../types";
33
import { request } from "../custom/request";
44
import type { ChatCompletionInput, ChatCompletionOutput } from "@huggingface/tasks";
55

6+
/** Finish reasons a Cohere chat response may carry, as modelled by this module. */
export type CohereTextGenerationOutputFinishReason = "COMPLETE" | "STOP_SEQUENCE" | "MAX_TOKENS" | "TOOL_CALL" | "ERROR";

/** One entry of a message's `content` array. */
interface CohereTextContent {
	type: string;
	text: string;
}

/** Function part of a tool call. `arguments` is left untyped (`unknown`). */
interface CohereFunctionDefinition {
	arguments: unknown;
	description?: string;
	name: string;
}

/** A tool call attached to a message. */
interface CohereToolCall {
	function: CohereFunctionDefinition;
	id: string;
	type: string;
}

/** A candidate token with its log probability. */
interface CohereTopLogprob {
	logprob: number;
	token: string;
}

/** Log probability of a token together with its candidate alternatives. */
interface CohereLogprob {
	logprob: number;
	token: string;
	top_logprobs: CohereTopLogprob[];
}

/** A pair of input/output token counts. */
interface CohereInputOutputTokens {
	input_tokens: number;
	output_tokens: number;
}

/** Usage section of a response: billed units and token counts. */
interface CohereChatCompletionOutputUsage {
	billed_units: CohereInputOutputTokens;
	tokens: CohereInputOutputTokens;
}

/** Message of a non-streaming response. */
interface CohereMessage {
	role: string;
	content: Array<CohereTextContent>;
	/** Optional field for tool calls. */
	tool_calls?: CohereToolCall[];
}

/** Non-streaming Cohere chat response, as consumed by the converter in this module. */
interface CohereChatCompletionOutput {
	id: string;
	finish_reason: CohereTextGenerationOutputFinishReason;
	message: CohereMessage;
	usage: CohereChatCompletionOutputUsage;
	/** Optional field for log probabilities. */
	logprobs?: CohereLogprob[];
}
62+
63+
/**
 * Converts a non-streaming Cohere chat response into the `ChatCompletionOutput` shape.
 *
 * Behaviour worth knowing:
 * - `created` is the conversion time in whole Unix seconds (the Cohere payload
 *   modelled here has no creation timestamp of its own).
 * - The text of all content blocks is joined with a single space.
 * - A message without `content` yields an empty string instead of throwing.
 * - `system_fingerprint` is a fixed placeholder.
 *
 * @param res - The Cohere response to convert.
 * @param model - Model identifier to report in the output. Defaults to the
 *   placeholder `"cohere-model"` because `res` carries no model field.
 * @returns A `ChatCompletionOutput` with exactly one choice (index 0).
 */
function convertCohereToChatCompletionOutput(
	res: CohereChatCompletionOutput,
	model = "cohere-model"
): ChatCompletionOutput {
	const { input_tokens: promptTokens, output_tokens: completionTokens } = res.usage.tokens;

	// `content` is declared as required, but guard anyway so a message that only
	// carries tool calls does not crash the conversion.
	// NOTE(review): confirm against the Cohere API reference whether content can be omitted.
	const contentBlocks = res.message.content ?? [];

	return {
		id: res.id,
		// Unix timestamp in seconds, not the milliseconds returned by Date.now().
		created: Math.floor(Date.now() / 1000),
		model,
		system_fingerprint: "cohere-fingerprint",
		usage: {
			completion_tokens: completionTokens,
			prompt_tokens: promptTokens,
			total_tokens: promptTokens + completionTokens,
		},
		choices: [
			{
				finish_reason: res.finish_reason,
				index: 0,
				message: {
					role: res.message.role,
					content: contentBlocks.map((block) => block.text).join(" "),
					tool_calls: res.message.tool_calls?.map((toolCall) => ({
						function: {
							arguments: toolCall.function.arguments,
							description: toolCall.function.description,
							name: toolCall.function.name,
						},
						id: toolCall.id,
						type: toolCall.type,
					})),
				},
				logprobs: res.logprobs
					? {
							content: res.logprobs.map((logprob) => ({
								logprob: logprob.logprob,
								token: logprob.token,
								top_logprobs: logprob.top_logprobs.map((topLogprob) => ({
									logprob: topLogprob.logprob,
									token: topLogprob.token,
								})),
							})),
					  }
					: undefined,
			},
		],
	};
}
108+
6109
/**
7110
* Use the chat completion endpoint to generate a response to a prompt, using the OpenAI-compatible message completion API (non-streaming)
8111
*/
@@ -31,5 +134,4 @@ export async function chatCompletion(
31134
if (!isValidOutput) {
32135
throw new InferenceOutputError("Expected ChatCompletionOutput");
33136
}
34-
return res;
35137
}

packages/inference/src/tasks/nlp/chatCompletionStream.ts

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,117 @@ import type { BaseArgs, Options } from "../../types";
22
import { streamingRequest } from "../custom/streamingRequest";
33
import type { ChatCompletionInput, ChatCompletionStreamOutput } from "@huggingface/tasks";
44

5+
/** Finish reasons a Cohere stream chunk may carry, as modelled by this module. */
export type CohereTextGenerationOutputFinishReason = "COMPLETE" | "STOP_SEQUENCE" | "MAX_TOKENS" | "TOOL_CALL" | "ERROR";

/** Function part of a tool call. `arguments` is left untyped (`unknown`). */
interface CohereFunctionDefinition {
	arguments: unknown;
	description?: string;
	name: string;
}

/** A tool call attached to a message. */
interface CohereToolCall {
	function: CohereFunctionDefinition;
	id: string;
	type: string;
}

/** A candidate token with its log probability. */
interface CohereTopLogprob {
	logprob: number;
	token: string;
}

/** Log probability of a token together with its candidate alternatives. */
interface CohereLogprob {
	logprob: number;
	token: string;
	top_logprobs: CohereTopLogprob[];
}

/** A pair of input/output token counts. */
interface CohereInputOutputTokens {
	input_tokens: number;
	output_tokens: number;
}

/** Usage section of a chunk: billed units and token counts. */
interface CohereChatCompletionOutputUsage {
	billed_units: CohereInputOutputTokens;
	tokens: CohereInputOutputTokens;
}

/** Content carried by a streamed message: a single typed text fragment. */
interface CohereStreamTextContent {
	type: string;
	text: string;
}

/** Message as it appears inside a stream delta (`content` is a single object here, not an array). */
interface CohereMessage {
	role: string;
	content: CohereStreamTextContent;
	tool_calls?: CohereToolCall[];
}

/** Delta wrapper around the streamed message. */
interface CohereMessageDelta {
	message: CohereMessage;
}

/** One chunk of a streaming Cohere chat response; everything but `id` and `delta` is optional. */
interface CohereChatCompletionStreamOutput {
	id: string;
	finish_reason?: CohereTextGenerationOutputFinishReason;
	delta: CohereMessageDelta;
	usage?: CohereChatCompletionOutputUsage;
	logprobs?: CohereLogprob[];
}
65+
66+
/**
 * Serialises tool-call arguments for a stream delta.
 * Strings are passed through untouched so that already-serialised JSON is not
 * encoded a second time; any other value goes through `JSON.stringify`.
 */
function serializeCohereToolCallArguments(args: unknown): string {
	return typeof args === "string" ? args : JSON.stringify(args);
}

/**
 * Converts one Cohere stream chunk into the `ChatCompletionStreamOutput` shape.
 *
 * Behaviour worth knowing:
 * - `created` is the conversion time in whole Unix seconds (the chunk modelled
 *   here has no timestamp of its own).
 * - Only the first entry of `delta.message.tool_calls` is forwarded, with index 0.
 *   An empty or missing array yields no tool call at all.
 * - `usage` and `logprobs` are forwarded only when present on the chunk.
 * - `system_fingerprint` is a fixed placeholder.
 *
 * @param res - The Cohere stream chunk to convert.
 * @param model - Model identifier to report in the output. Defaults to the
 *   placeholder `"cohere-model"` because `res` carries no model field.
 * @returns A `ChatCompletionStreamOutput` with exactly one choice (index 0).
 */
function convertCohereToChatCompletionStreamOutput(
	res: CohereChatCompletionStreamOutput,
	model = "cohere-model"
): ChatCompletionStreamOutput {
	const message = res.delta?.message;
	const firstToolCall = message?.tool_calls?.[0];
	const tokens = res.usage?.tokens;

	return {
		id: res.id,
		// Unix timestamp in seconds, not the milliseconds returned by Date.now().
		created: Math.floor(Date.now() / 1000),
		model,
		system_fingerprint: "cohere-fingerprint",
		usage: tokens
			? {
					completion_tokens: tokens.output_tokens,
					prompt_tokens: tokens.input_tokens,
					total_tokens: tokens.input_tokens + tokens.output_tokens,
			  }
			: undefined,
		choices: [
			{
				delta: {
					role: message?.role,
					content: message?.content?.text,
					tool_calls: firstToolCall
						? {
								function: {
									arguments: serializeCohereToolCallArguments(firstToolCall.function?.arguments),
									description: firstToolCall.function?.description,
									name: firstToolCall.function?.name,
								},
								id: firstToolCall.id,
								index: 0, // only the first tool call is forwarded
								type: firstToolCall.type,
						  }
						: undefined,
				},
				finish_reason: res.finish_reason,
				index: 0, // a chunk is always mapped to a single choice
				logprobs: res.logprobs
					? {
							content: res.logprobs.map((logprob) => ({
								logprob: logprob.logprob,
								token: logprob.token,
								top_logprobs: logprob.top_logprobs.map((topLogprob) => ({
									logprob: topLogprob.logprob,
									token: topLogprob.token,
								})),
							})),
					  }
					: undefined,
			},
		],
	};
}
115+
5116
/**
6117
* Use to continue text from a prompt. Same as `textGeneration` but returns generator that can be read one token at a time
7118
*/

packages/inference/src/types.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ export type InferenceTask = Exclude<PipelineType, "other">;
3030

3131
export const INFERENCE_PROVIDERS = [
3232
"black-forest-labs",
33+
"cohere",
3334
"fal-ai",
3435
"fireworks-ai",
3536
"hf-inference",

0 commit comments

Comments
 (0)