Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ jobs:
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_BLACK_FOREST_LABS_KEY: dummy
HF_COHERE_KEY: dummy
HF_FAL_KEY: dummy
HF_FIREWORKS_KEY: dummy
HF_HYPERBOLIC_KEY: dummy
Expand Down Expand Up @@ -87,6 +88,7 @@ jobs:
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_BLACK_FOREST_LABS_KEY: dummy
HF_COHERE_KEY: dummy
HF_FAL_KEY: dummy
HF_FIREWORKS_KEY: dummy
HF_HYPERBOLIC_KEY: dummy
Expand Down Expand Up @@ -159,6 +161,7 @@ jobs:
NPM_CONFIG_REGISTRY: http://localhost:4874/
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_BLACK_FOREST_LABS_KEY: dummy
HF_COHERE_KEY: dummy
HF_FAL_KEY: dummy
HF_FIREWORKS_KEY: dummy
HF_HYPERBOLIC_KEY: dummy
Expand Down
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ await uploadFile({
// Can work with native File in browsers
file: {
path: "pytorch_model.bin",
content: new Blob(...)
content: new Blob(...)
}
});

Expand All @@ -39,7 +39,7 @@ await inference.chatCompletion({
],
max_tokens: 512,
temperature: 0.5,
provider: "sambanova", // or together, fal-ai, replicate, …
provider: "sambanova", // or together, fal-ai, replicate, cohere
});

await inference.textToImage({
Expand Down Expand Up @@ -146,12 +146,12 @@ for await (const chunk of inference.chatCompletionStream({
console.log(chunk.choices[0].delta.content);
}

/// Using a third-party provider:
/// Using a third-party provider:
await inference.chatCompletion({
model: "meta-llama/Llama-3.1-8B-Instruct",
messages: [{ role: "user", content: "Hello, nice to meet you!" }],
max_tokens: 512,
provider: "sambanova", // or together, fal-ai, replicate, …
provider: "sambanova", // or together, fal-ai, replicate, cohere
})

await inference.textToImage({
Expand Down Expand Up @@ -211,7 +211,7 @@ await uploadFile({
// Can work with native File in browsers
file: {
path: "pytorch_model.bin",
content: new Blob(...)
content: new Blob(...)
}
});

Expand Down Expand Up @@ -244,7 +244,7 @@ console.log(messages); // contains the data

// or you can run the code directly, however you can't check that the code is safe to execute this way, use at your own risk.
const messages = await agent.run("Draw a picture of a cat wearing a top hat. Then caption the picture and read it out loud.")
console.log(messages);
console.log(messages);
```

There are more features of course, check each library's README!
Expand Down
2 changes: 2 additions & 0 deletions packages/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ Currently, we support the following providers:
- [Sambanova](https://sambanova.ai)
- [Together](https://together.xyz)
- [Black Forest Labs](https://blackforestlabs.ai)
- [Cohere](https://cohere.com)

To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
```ts
Expand All @@ -80,6 +81,7 @@ Only a subset of models are supported when requesting third-party providers. You
- [Replicate supported models](https://huggingface.co/api/partners/replicate/models)
- [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
- [Together supported models](https://huggingface.co/api/partners/together/models)
- [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
- [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)

❗**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.
Expand Down
2 changes: 2 additions & 0 deletions packages/inference/src/lib/makeRequestOptions.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { HF_HUB_URL, HF_ROUTER_URL } from "../config";
import { BLACK_FOREST_LABS_CONFIG } from "../providers/black-forest-labs";
import { COHERE_CONFIG } from "../providers/cohere";
import { FAL_AI_CONFIG } from "../providers/fal-ai";
import { FIREWORKS_AI_CONFIG } from "../providers/fireworks-ai";
import { HF_INFERENCE_CONFIG } from "../providers/hf-inference";
Expand Down Expand Up @@ -27,6 +28,7 @@ let tasks: Record<string, { models: { id: string }[] }> | null = null;
*/
const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
"black-forest-labs": BLACK_FOREST_LABS_CONFIG,
cohere: COHERE_CONFIG,
"fal-ai": FAL_AI_CONFIG,
"fireworks-ai": FIREWORKS_AI_CONFIG,
"hf-inference": HF_INFERENCE_CONFIG,
Expand Down
42 changes: 42 additions & 0 deletions packages/inference/src/providers/cohere.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/**
* See the registered mapping of HF model ID => Cohere model ID here:
*
* https://huggingface.co/api/partners/cohere/models
*
* This is a publicly available mapping.
*
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
*
* - If you work at Cohere and want to update this mapping, please use the model mapping API we provide on huggingface.co
* - If you're a community member and want to add a new supported HF model to Cohere, please open an issue on the present repo
* and we will tag Cohere team members.
*
* Thanks!
*/
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";

const COHERE_API_BASE_URL = "https://api.cohere.com";


/**
 * Build the JSON request body for Cohere's OpenAI-compatible endpoint.
 * The caller's arguments are passed through as-is, with the provider-side
 * model ID always winning over any `model` present in the args.
 */
const makeBody = (params: BodyParams): Record<string, unknown> => {
	const { args, model } = params;
	return { ...args, model };
};

/** Build the auth headers: Cohere expects a standard Bearer token. */
const makeHeaders = (params: HeaderParams): Record<string, string> => {
	const headers: Record<string, string> = {};
	headers.Authorization = `Bearer ${params.accessToken}`;
	return headers;
};

/** Resolve the full chat-completions URL on Cohere's OpenAI-compatibility layer. */
const makeUrl = (params: UrlParams): string => {
	const path = "compatibility/v1/chat/completions";
	return `${params.baseUrl}/${path}`;
};

/**
 * Provider configuration for Cohere, wired into the `providerConfigs`
 * registry in makeRequestOptions.ts. Requests are routed to Cohere's
 * OpenAI-compatible chat-completions endpoint.
 */
export const COHERE_CONFIG: ProviderConfig = {
	baseUrl: COHERE_API_BASE_URL,
	makeBody,
	makeHeaders,
	makeUrl,
};
1 change: 1 addition & 0 deletions packages/inference/src/providers/consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
* "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
*/
"black-forest-labs": {},
cohere: {},
"fal-ai": {},
"fireworks-ai": {},
"hf-inference": {},
Expand Down
1 change: 1 addition & 0 deletions packages/inference/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ export type InferenceTask = Exclude<PipelineType, "other">;

export const INFERENCE_PROVIDERS = [
"black-forest-labs",
"cohere",
"fal-ai",
"fireworks-ai",
"hf-inference",
Expand Down
47 changes: 47 additions & 0 deletions packages/inference/test/HfInference.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1350,4 +1350,51 @@ describe.concurrent("HfInference", () => {
},
TIMEOUT
);
describe.concurrent(
	"Cohere",
	() => {
		const client = new HfInference(env.HF_COHERE_KEY);

		// Register the HF model ID => Cohere model ID mapping locally so the
		// tests do not depend on the live mapping published on huggingface.co.
		HARDCODED_MODEL_ID_MAPPING["cohere"] = {
			"CohereForAI/c4ai-command-r7b-12-2024": "command-r7b-12-2024",
			"CohereForAI/aya-expanse-8b": "c4ai-aya-expanse-8b",
		};

		it("chatCompletion", async () => {
			const res = await client.chatCompletion({
				model: "CohereForAI/c4ai-command-r7b-12-2024",
				provider: "cohere",
				messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
			});
			// Assert unconditionally: the previous `if (res.choices …)` guard let
			// the test pass vacuously when the provider returned no choices.
			expect(res.choices).toBeDefined();
			expect(res.choices?.length).toBeGreaterThan(0);
			const completion = res.choices?.[0].message?.content;
			expect(completion).toContain("two");
		});

		it("chatCompletion stream", async () => {
			const stream = client.chatCompletionStream({
				model: "CohereForAI/c4ai-command-r7b-12-2024",
				provider: "cohere",
				messages: [{ role: "user", content: "Say 'this is a test'" }],
				stream: true,
			}) as AsyncGenerator<ChatCompletionStreamOutput>;

			// Accumulate the streamed delta contents into one string.
			let fullResponse = "";
			for await (const chunk of stream) {
				if (chunk.choices && chunk.choices.length > 0) {
					const content = chunk.choices[0].delta?.content;
					if (content) {
						fullResponse += content;
					}
				}
			}

			// Verify we got a meaningful, non-empty streamed response.
			expect(fullResponse).toBeTruthy();
			expect(fullResponse.length).toBeGreaterThan(0);
		});
	},
	TIMEOUT
);
});
53 changes: 53 additions & 0 deletions packages/inference/test/tapes.json
Original file line number Diff line number Diff line change
Expand Up @@ -7386,5 +7386,58 @@
"content-type": "image/jpeg"
}
}
},
"cb34d07934bd210fd64da207415c49fc6e2870d3564164a2a5d541f713227fbf": {
"url": "https://api.cohere.com/compatibility/v1/chat/completions",
"init": {
"headers": {
"Content-Type": "application/json"
},
"method": "POST",
"body": "{\"messages\":[{\"role\":\"user\",\"content\":\"Say 'this is a test'\"}],\"stream\":true,\"model\":\"command-r7b-12-2024\"}"
},
"response": {
"body": "data: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\"\",\"role\":\"assistant\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\"This\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\" is\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\" a\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\" test\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\".\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":\"stop\",\"delta\":{}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\",\"usage\":{\"prompt_tokens\":7,\"completion_tokens\":5,\"total_tokens\":12}}\n\ndata: [DONE]\n\n",
"status": 200,
"statusText": "OK",
"headers": {
"access-control-expose-headers": "X-Debug-Trace-ID",
"alt-svc": "h3=\":443\"; ma=2592000,h3-29=\":443\"; ma=2592000",
"cache-control": "no-cache, no-store, no-transform, must-revalidate, private, max-age=0",
"content-type": "text/event-stream",
"expires": "Thu, 01 Jan 1970 00:00:00 UTC",
"pragma": "no-cache",
"server": "envoy",
"transfer-encoding": "chunked",
"vary": "Origin"
}
}
},
"8c6ffbc794573c463ed5666e3b560e5966cd975c2893c901c18adb696ba54a6a": {
"url": "https://api.cohere.com/compatibility/v1/chat/completions",
"init": {
"headers": {
"Content-Type": "application/json"
},
"method": "POST",
"body": "{\"messages\":[{\"role\":\"user\",\"content\":\"Complete this sentence with words, one plus one is equal \"}],\"model\":\"command-r7b-12-2024\"}"
},
"response": {
"body": "{\"id\":\"f8bf661b-c600-44e5-8412-df37c9dcd985\",\"choices\":[{\"index\":0,\"finish_reason\":\"stop\",\"message\":{\"role\":\"assistant\",\"content\":\"One plus one is equal to two.\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion\",\"usage\":{\"prompt_tokens\":11,\"completion_tokens\":8,\"total_tokens\":19}}",
"status": 200,
"statusText": "OK",
"headers": {
"access-control-expose-headers": "X-Debug-Trace-ID",
"alt-svc": "h3=\":443\"; ma=2592000,h3-29=\":443\"; ma=2592000",
"cache-control": "no-cache, no-store, no-transform, must-revalidate, private, max-age=0",
"content-type": "application/json",
"expires": "Thu, 01 Jan 1970 00:00:00 UTC",
"num_chars": "2635",
"num_tokens": "19",
"pragma": "no-cache",
"server": "envoy",
"vary": "Origin"
}
}
}
}
1 change: 1 addition & 0 deletions packages/tasks/src/inference-providers.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/// This list is for illustration purposes only.
/// in the `tasks` sub-package, we do not need actual strong typing of the inference providers.
const INFERENCE_PROVIDERS = [
"cohere",
"fal-ai",
"fireworks-ai",
"hf-inference",
Expand Down