Skip to content

Commit 37f2b11

Browse files
committed
Add Cohere provider
1 parent e15c809 commit 37f2b11

File tree

11 files changed

+359
-32
lines changed

11 files changed

+359
-32
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,7 @@ jobs:
5050
HF_REPLICATE_KEY: dummy
5151
HF_SAMBANOVA_KEY: dummy
5252
HF_TOGETHER_KEY: dummy
53+
HF_COHERE_KEY: dummy
5354

5455
browser:
5556
runs-on: ubuntu-latest

README.md

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ await uploadFile({
2323
// Can work with native File in browsers
2424
file: {
2525
path: "pytorch_model.bin",
26-
content: new Blob(...)
26+
content: new Blob(...)
2727
}
2828
});
2929

@@ -39,7 +39,7 @@ await inference.chatCompletion({
3939
],
4040
max_tokens: 512,
4141
temperature: 0.5,
42-
provider: "sambanova", // or together, fal-ai, replicate, …
42+
provider: "sambanova", // or together, fal-ai, replicate, cohere
4343
});
4444

4545
await inference.textToImage({
@@ -146,12 +146,12 @@ for await (const chunk of inference.chatCompletionStream({
146146
console.log(chunk.choices[0].delta.content);
147147
}
148148

149-
/// Using a third-party provider:
149+
/// Using a third-party provider:
150150
await inference.chatCompletion({
151151
model: "meta-llama/Llama-3.1-8B-Instruct",
152152
messages: [{ role: "user", content: "Hello, nice to meet you!" }],
153153
max_tokens: 512,
154-
provider: "sambanova", // or together, fal-ai, replicate, …
154+
provider: "sambanova", // or together, fal-ai, replicate, cohere
155155
})
156156

157157
await inference.textToImage({
@@ -211,7 +211,7 @@ await uploadFile({
211211
// Can work with native File in browsers
212212
file: {
213213
path: "pytorch_model.bin",
214-
content: new Blob(...)
214+
content: new Blob(...)
215215
}
216216
});
217217

@@ -244,7 +244,7 @@ console.log(messages); // contains the data
244244

245245
// or you can run the code directly, however you can't check that the code is safe to execute this way, use at your own risk.
246246
const messages = await agent.run("Draw a picture of a cat wearing a top hat. Then caption the picture and read it out loud.")
247-
console.log(messages);
247+
console.log(messages);
248248
```
249249

250250
There are more features of course, check each library's README!

packages/inference/README.md

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,7 @@ Currently, we support the following providers:
5656
- [Sambanova](https://sambanova.ai)
5757
- [Together](https://together.xyz)
5858
- [Blackforestlabs](https://blackforestlabs.ai)
59+
- [Cohere](https://cohere.com)
5960

6061
To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
6162
```ts
@@ -80,6 +81,7 @@ Only a subset of models are supported when requesting third-party providers. You
8081
- [Replicate supported models](https://huggingface.co/api/partners/replicate/models)
8182
- [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
8283
- [Together supported models](https://huggingface.co/api/partners/together/models)
84+
- [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
8385
- [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)
8486

8587
**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@ import { NOVITA_API_BASE_URL } from "../providers/novita";
88
import { FIREWORKS_AI_API_BASE_URL } from "../providers/fireworks-ai";
99
import { HYPERBOLIC_API_BASE_URL } from "../providers/hyperbolic";
1010
import { BLACKFORESTLABS_AI_API_BASE_URL } from "../providers/black-forest-labs";
11+
import { COHERE_API_BASE_URL } from "../providers/cohere";
1112
import type { InferenceProvider } from "../types";
1213
import type { InferenceTask, Options, RequestArgs } from "../types";
1314
import { isUrl } from "./isUrl";
@@ -256,6 +257,15 @@ function makeUrl(params: {
256257
}
257258
return baseUrl;
258259
}
260+
case "cohere": {
261+
const baseUrl = shouldProxy
262+
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
263+
: COHERE_API_BASE_URL;
264+
if (params.taskHint === "text-generation") {
265+
return `${baseUrl}/v2/chat`;
266+
}
267+
return baseUrl;
268+
}
259269
default: {
260270
const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
261271
if (params.taskHint && ["feature-extraction", "sentence-similarity"].includes(params.taskHint)) {
Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,18 @@
1+
export const COHERE_API_BASE_URL = "https://api.cohere.com";
2+
3+
/**
4+
* See the registered mapping of HF model ID => Cohere model ID here:
5+
*
6+
* https://huggingface.co/api/partners/cohere/models
7+
*
8+
* This is a publicly available mapping.
9+
*
10+
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
11+
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
12+
*
13+
* - If you work at Cohere and want to update this mapping, please use the model mapping API we provide on huggingface.co
14+
* - If you're a community member and want to add a new supported HF model to Cohere, please open an issue on the present repo
15+
* and we will tag Cohere team members.
16+
*
17+
* Thanks!
18+
*/

packages/inference/src/providers/consts.ts

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,4 +26,5 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
2626
sambanova: {},
2727
together: {},
2828
novita: {},
29+
cohere: {},
2930
};

packages/inference/src/tasks/nlp/chatCompletion.ts

Lines changed: 146 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -3,33 +3,158 @@ import type { BaseArgs, Options } from "../../types";
33
import { request } from "../custom/request";
44
import type { ChatCompletionInput, ChatCompletionOutput } from "@huggingface/tasks";
55

6+
export type CohereTextGenerationOutputFinishReason =
7+
| "COMPLETE"
8+
| "STOP_SEQUENCE"
9+
| "MAX_TOKENS"
10+
| "TOOL_CALL"
11+
| "ERROR";
12+
13+
interface CohereChatCompletionOutput {
14+
id: string;
15+
finish_reason: CohereTextGenerationOutputFinishReason;
16+
message: CohereMessage;
17+
usage: CohereChatCompletionOutputUsage;
18+
logprobs?: CohereLogprob[]; // Optional field for log probabilities
19+
}
20+
21+
interface CohereMessage {
22+
role: string;
23+
content: Array<{
24+
type: string;
25+
text: string;
26+
}>;
27+
tool_calls?: CohereToolCall[]; // Optional field for tool calls
28+
}
29+
30+
interface CohereChatCompletionOutputUsage {
31+
billed_units: CohereInputOutputTokens;
32+
tokens: CohereInputOutputTokens;
33+
}
34+
35+
interface CohereInputOutputTokens {
36+
input_tokens: number;
37+
output_tokens: number;
38+
}
39+
40+
interface CohereLogprob {
41+
logprob: number;
42+
token: string;
43+
top_logprobs: CohereTopLogprob[];
44+
}
45+
46+
interface CohereTopLogprob {
47+
logprob: number;
48+
token: string;
49+
}
50+
51+
interface CohereToolCall {
52+
function: CohereFunctionDefinition;
53+
id: string;
54+
type: string;
55+
}
56+
57+
interface CohereFunctionDefinition {
58+
arguments: unknown;
59+
description?: string;
60+
name: string;
61+
}
62+
63+
function convertCohereToChatCompletionOutput(res: CohereChatCompletionOutput): ChatCompletionOutput {
64+
// Create a ChatCompletionOutput object from the CohereChatCompletionOutput
65+
return {
66+
id: res.id,
67+
created: Date.now(),
68+
model: "cohere-model",
69+
system_fingerprint: "cohere-fingerprint",
70+
usage: {
71+
completion_tokens: res.usage.tokens.output_tokens,
72+
prompt_tokens: res.usage.tokens.input_tokens,
73+
total_tokens: res.usage.tokens.input_tokens + res.usage.tokens.output_tokens,
74+
},
75+
choices: [
76+
{
77+
finish_reason: res.finish_reason,
78+
index: 0,
79+
message: {
80+
role: res.message.role,
81+
content: res.message.content.map((c) => c.text).join(" "),
82+
tool_calls: res.message.tool_calls?.map((toolCall) => ({
83+
function: {
84+
arguments: toolCall.function.arguments,
85+
description: toolCall.function.description,
86+
name: toolCall.function.name,
87+
},
88+
id: toolCall.id,
89+
type: toolCall.type,
90+
})),
91+
},
92+
logprobs: res.logprobs
93+
? {
94+
content: res.logprobs.map((logprob) => ({
95+
logprob: logprob.logprob,
96+
token: logprob.token,
97+
top_logprobs: logprob.top_logprobs.map((topLogprob) => ({
98+
logprob: topLogprob.logprob,
99+
token: topLogprob.token,
100+
})),
101+
})),
102+
}
103+
: undefined,
104+
},
105+
],
106+
};
107+
}
108+
6109
/**
7110
* Use the chat completion endpoint to generate a response to a prompt, using OpenAI message completion API no stream
8111
*/
9112
export async function chatCompletion(
10113
args: BaseArgs & ChatCompletionInput,
11114
options?: Options
12115
): Promise<ChatCompletionOutput> {
13-
const res = await request<ChatCompletionOutput>(args, {
14-
...options,
15-
taskHint: "text-generation",
16-
chatCompletion: true,
17-
});
18-
19-
const isValidOutput =
20-
typeof res === "object" &&
21-
Array.isArray(res?.choices) &&
22-
typeof res?.created === "number" &&
23-
typeof res?.id === "string" &&
24-
typeof res?.model === "string" &&
25-
/// Together.ai and Nebius do not output a system_fingerprint
26-
(res.system_fingerprint === undefined ||
27-
res.system_fingerprint === null ||
28-
typeof res.system_fingerprint === "string") &&
29-
typeof res?.usage === "object";
30-
31-
if (!isValidOutput) {
32-
throw new InferenceOutputError("Expected ChatCompletionOutput");
116+
if (args.provider === "cohere") {
117+
const res = await request<CohereChatCompletionOutput>(args, {
118+
...options,
119+
taskHint: "text-generation",
120+
chatCompletion: true,
121+
});
122+
123+
const isValidOutput =
124+
typeof res === "object" &&
125+
typeof res?.id === "string" &&
126+
typeof res?.finish_reason === "string" &&
127+
typeof res?.message === "object" &&
128+
Array.isArray(res?.message.content) &&
129+
typeof res?.usage === "object";
130+
131+
if (!isValidOutput) {
132+
throw new InferenceOutputError("Expected CohereChatCompletionOutput");
133+
}
134+
135+
return convertCohereToChatCompletionOutput(res);
136+
} else {
137+
const res = await request<ChatCompletionOutput>(args, {
138+
...options,
139+
taskHint: "text-generation",
140+
chatCompletion: true,
141+
});
142+
143+
const isValidOutput =
144+
typeof res === "object" &&
145+
Array.isArray(res?.choices) &&
146+
typeof res?.created === "number" &&
147+
typeof res?.id === "string" &&
148+
typeof res?.model === "string" &&
149+
/// Together.ai and Nebius do not output a system_fingerprint
150+
(res.system_fingerprint === undefined ||
151+
res.system_fingerprint === null ||
152+
typeof res.system_fingerprint === "string") &&
153+
typeof res?.usage === "object";
154+
155+
if (!isValidOutput) {
156+
throw new InferenceOutputError("Expected ChatCompletionOutput");
157+
}
158+
return res;
33159
}
34-
return res;
35160
}

0 commit comments

Comments (0)