Skip to content

Commit e1bf67a

Browse files
committed
Use new Cohere OpenAI-compatible API
1 parent eef7b1e commit e1bf67a

File tree

4 files changed

+8
-141
lines changed

4 files changed

+8
-141
lines changed

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ let tasks: Record<string, { models: { id: string }[] }> | null = null;
2828
*/
2929
const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
3030
"black-forest-labs": BLACK_FOREST_LABS_CONFIG,
31-
"cohere": COHERE_CONFIG,
31+
cohere: COHERE_CONFIG,
3232
"fal-ai": FAL_AI_CONFIG,
3333
"fireworks-ai": FIREWORKS_AI_CONFIG,
3434
"hf-inference": HF_INFERENCE_CONFIG,
Lines changed: 1 addition & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,8 @@
11
import { InferenceOutputError } from "../../lib/InferenceOutputError";
2-
import type { CohereTextGenerationOutputFinishReason, CohereMessage, CohereLogprob } from "../../providers/cohere";
32
import type { BaseArgs, Options } from "../../types";
43
import { request } from "../custom/request";
54
import type { ChatCompletionInput, ChatCompletionOutput } from "@huggingface/tasks";
65

7-
interface CohereChatCompletionOutput {
8-
id: string;
9-
finish_reason: CohereTextGenerationOutputFinishReason;
10-
message: CohereMessage;
11-
usage: {
12-
billed_units: {
13-
input_tokens: number;
14-
output_tokens: number;
15-
};
16-
tokens: {
17-
input_tokens: number;
18-
output_tokens: number;
19-
};
20-
};
21-
logprobs?: CohereLogprob[]; // Optional field for log probabilities
22-
}
23-
24-
function convertCohereToChatCompletionOutput(res: CohereChatCompletionOutput): ChatCompletionOutput {
25-
// Create a ChatCompletionOutput object from the CohereChatCompletionOutput
26-
return {
27-
id: res.id,
28-
created: Date.now(),
29-
model: "cohere-model",
30-
system_fingerprint: "cohere-fingerprint",
31-
usage: {
32-
completion_tokens: res.usage.tokens.output_tokens,
33-
prompt_tokens: res.usage.tokens.input_tokens,
34-
total_tokens: res.usage.tokens.input_tokens + res.usage.tokens.output_tokens,
35-
},
36-
choices: [
37-
{
38-
finish_reason: res.finish_reason,
39-
index: 0,
40-
message: {
41-
role: res.message.role,
42-
content: res.message.content.map((c) => c.text).join(" "),
43-
tool_calls: res.message.tool_calls?.map((toolCall) => ({
44-
function: {
45-
arguments: toolCall.function.arguments,
46-
description: toolCall.function.description,
47-
name: toolCall.function.name,
48-
},
49-
id: toolCall.id,
50-
type: toolCall.type,
51-
})),
52-
},
53-
logprobs: res.logprobs
54-
? {
55-
content: res.logprobs.map((logprob) => ({
56-
logprob: logprob.logprob,
57-
token: logprob.token,
58-
top_logprobs: logprob.top_logprobs.map((topLogprob) => ({
59-
logprob: topLogprob.logprob,
60-
token: topLogprob.token,
61-
})),
62-
})),
63-
}
64-
: undefined,
65-
},
66-
],
67-
};
68-
}
69-
706
/**
717
* Use the chat completion endpoint to generate a response to a prompt, using the OpenAI message completion API (non-streaming)
728
*/
@@ -95,4 +31,5 @@ export async function chatCompletion(
9531
if (!isValidOutput) {
9632
throw new InferenceOutputError("Expected ChatCompletionOutput");
9733
}
34+
return res;
9835
}

packages/inference/src/tasks/nlp/chatCompletionStream.ts

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,77 +1,7 @@
1-
import type { CohereLogprob, CohereMessageDelta, CohereTextGenerationOutputFinishReason } from "../../providers/cohere";
21
import type { BaseArgs, Options } from "../../types";
32
import { streamingRequest } from "../custom/streamingRequest";
43
import type { ChatCompletionInput, ChatCompletionStreamOutput } from "@huggingface/tasks";
54

6-
interface CohereChatCompletionStreamOutput {
7-
id: string;
8-
finish_reason?: CohereTextGenerationOutputFinishReason;
9-
delta: {
10-
message: CohereMessageDelta;
11-
};
12-
usage?: {
13-
billed_units: {
14-
input_tokens: number;
15-
output_tokens: number;
16-
};
17-
tokens: {
18-
input_tokens: number;
19-
output_tokens: number;
20-
};
21-
};
22-
logprobs?: CohereLogprob[];
23-
}
24-
25-
function convertCohereToChatCompletionStreamOutput(res: CohereChatCompletionStreamOutput): ChatCompletionStreamOutput {
26-
return {
27-
id: res.id,
28-
created: Date.now(), // Assuming the current timestamp as created time
29-
model: "cohere-model", // Assuming a placeholder model name
30-
system_fingerprint: "cohere-fingerprint", // Assuming a placeholder fingerprint
31-
usage: res.usage
32-
? {
33-
completion_tokens: res.usage.tokens.output_tokens,
34-
prompt_tokens: res.usage.tokens.input_tokens,
35-
total_tokens: res.usage.tokens.input_tokens + res.usage.tokens.output_tokens,
36-
}
37-
: undefined,
38-
choices: [
39-
{
40-
delta: {
41-
role: res.delta?.message?.role,
42-
content: res.delta?.message?.content?.text,
43-
tool_calls: res.delta?.message?.tool_calls
44-
? {
45-
function: {
46-
arguments: JSON.stringify(res.delta?.message?.tool_calls[0]?.function.arguments), // Convert arguments to string
47-
description: res.delta?.message?.tool_calls[0]?.function.description,
48-
name: res.delta?.message?.tool_calls[0]?.function.name,
49-
},
50-
id: res.delta?.message?.tool_calls[0]?.id,
51-
index: 0, // Assuming a single tool call with index 0
52-
type: res.delta?.message?.tool_calls[0]?.type,
53-
}
54-
: undefined,
55-
},
56-
finish_reason: res.finish_reason,
57-
index: 0, // Assuming a single choice with index 0
58-
logprobs: res.logprobs
59-
? {
60-
content: res.logprobs.map((logprob) => ({
61-
logprob: logprob.logprob,
62-
token: logprob.token,
63-
top_logprobs: logprob.top_logprobs.map((topLogprob) => ({
64-
logprob: topLogprob.logprob,
65-
token: topLogprob.token,
66-
})),
67-
})),
68-
}
69-
: undefined,
70-
},
71-
],
72-
};
73-
}
74-
755
/**
766
* Use to continue text from a prompt. Same as `textGeneration`, but returns a generator that can be read one token at a time
777
*/

packages/inference/test/tapes.json

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7387,8 +7387,8 @@
73877387
}
73887388
}
73897389
},
7390-
"772e481d98640490fca3aab8e7fed5b771ea213f2b76b2f8858ce7bc90acb16b": {
7391-
"url": "https://api.cohere.com/v2/chat",
7390+
"cb34d07934bd210fd64da207415c49fc6e2870d3564164a2a5d541f713227fbf": {
7391+
"url": "https://api.cohere.com/compatibility/v1/chat/completions",
73927392
"init": {
73937393
"headers": {
73947394
"Content-Type": "application/json"
@@ -7397,7 +7397,7 @@
73977397
"body": "{\"messages\":[{\"role\":\"user\",\"content\":\"Say 'this is a test'\"}],\"stream\":true,\"model\":\"command-r7b-12-2024\"}"
73987398
},
73997399
"response": {
7400-
"body": "event: message-start\ndata: {\"id\":\"b9c2d3f2-4532-473d-a8a1-9236d159ab26\",\"type\":\"message-start\",\"delta\":{\"message\":{\"role\":\"assistant\",\"content\":[],\"tool_plan\":\"\",\"tool_calls\":[],\"citations\":[]}}}\n\nevent: content-start\ndata: {\"type\":\"content-start\",\"index\":0,\"delta\":{\"message\":{\"content\":{\"type\":\"text\",\"text\":\"\"}}}}\n\nevent: content-delta\ndata: {\"type\":\"content-delta\",\"index\":0,\"delta\":{\"message\":{\"content\":{\"text\":\"This\"}}}}\n\nevent: content-delta\ndata: {\"type\":\"content-delta\",\"index\":0,\"delta\":{\"message\":{\"content\":{\"text\":\" is\"}}}}\n\nevent: content-delta\ndata: {\"type\":\"content-delta\",\"index\":0,\"delta\":{\"message\":{\"content\":{\"text\":\" a\"}}}}\n\nevent: content-delta\ndata: {\"type\":\"content-delta\",\"index\":0,\"delta\":{\"message\":{\"content\":{\"text\":\" test\"}}}}\n\nevent: content-delta\ndata: {\"type\":\"content-delta\",\"index\":0,\"delta\":{\"message\":{\"content\":{\"text\":\".\"}}}}\n\nevent: content-end\ndata: {\"type\":\"content-end\",\"index\":0}\n\nevent: message-end\ndata: {\"type\":\"message-end\",\"delta\":{\"finish_reason\":\"COMPLETE\",\"usage\":{\"billed_units\":{\"input_tokens\":7,\"output_tokens\":5},\"tokens\":{\"input_tokens\":502,\"output_tokens\":7}}}}\n\ndata: [DONE]\n\n",
7400+
"body": "data: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\"\",\"role\":\"assistant\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\"This\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\" is\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\" a\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\" test\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\".\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":\"stop\",\"delta\":{}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\",\"usage\":{\"prompt_tokens\":7,\"completion_tokens\":5,\"total_tokens\":12}}\n\ndata: [DONE]\n\n",
74017401
"status": 200,
74027402
"statusText": "OK",
74037403
"headers": {
@@ -7413,8 +7413,8 @@
74137413
}
74147414
}
74157415
},
7416-
"545bf4e8393bc07dedb7c66d13846ff8264a49e909117c3c93ae35e30e705cbb": {
7417-
"url": "https://api.cohere.com/v2/chat",
7416+
"8c6ffbc794573c463ed5666e3b560e5966cd975c2893c901c18adb696ba54a6a": {
7417+
"url": "https://api.cohere.com/compatibility/v1/chat/completions",
74187418
"init": {
74197419
"headers": {
74207420
"Content-Type": "application/json"
@@ -7423,7 +7423,7 @@
74237423
"body": "{\"messages\":[{\"role\":\"user\",\"content\":\"Complete this sentence with words, one plus one is equal \"}],\"model\":\"command-r7b-12-2024\"}"
74247424
},
74257425
"response": {
7426-
"body": "{\"id\":\"cd9a6a0f-5e4c-411f-9604-fd4bdffc3052\",\"message\":{\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"One plus one is equal to two.\"}]},\"finish_reason\":\"COMPLETE\",\"usage\":{\"billed_units\":{\"input_tokens\":11,\"output_tokens\":8},\"tokens\":{\"input_tokens\":507,\"output_tokens\":10}}}",
7426+
"body": "{\"id\":\"f8bf661b-c600-44e5-8412-df37c9dcd985\",\"choices\":[{\"index\":0,\"finish_reason\":\"stop\",\"message\":{\"role\":\"assistant\",\"content\":\"One plus one is equal to two.\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion\",\"usage\":{\"prompt_tokens\":11,\"completion_tokens\":8,\"total_tokens\":19}}",
74277427
"status": 200,
74287428
"statusText": "OK",
74297429
"headers": {

0 commit comments

Comments
 (0)