diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 357b34b51b..cb6195ab07 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -42,6 +42,7 @@ jobs: env: HF_TOKEN: ${{ secrets.HF_TOKEN }} HF_BLACK_FOREST_LABS_KEY: dummy + HF_COHERE_KEY: dummy HF_FAL_KEY: dummy HF_FIREWORKS_KEY: dummy HF_HYPERBOLIC_KEY: dummy @@ -87,6 +88,7 @@ jobs: env: HF_TOKEN: ${{ secrets.HF_TOKEN }} HF_BLACK_FOREST_LABS_KEY: dummy + HF_COHERE_KEY: dummy HF_FAL_KEY: dummy HF_FIREWORKS_KEY: dummy HF_HYPERBOLIC_KEY: dummy @@ -159,6 +161,7 @@ jobs: NPM_CONFIG_REGISTRY: http://localhost:4874/ HF_TOKEN: ${{ secrets.HF_TOKEN }} HF_BLACK_FOREST_LABS_KEY: dummy + HF_COHERE_KEY: dummy HF_FAL_KEY: dummy HF_FIREWORKS_KEY: dummy HF_HYPERBOLIC_KEY: dummy diff --git a/README.md b/README.md index eeda91985b..a5ab0062c1 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ await uploadFile({ // Can work with native File in browsers file: { path: "pytorch_model.bin", - content: new Blob(...) + content: new Blob(...) } }); @@ -39,7 +39,7 @@ await inference.chatCompletion({ ], max_tokens: 512, temperature: 0.5, - provider: "sambanova", // or together, fal-ai, replicate, … + provider: "sambanova", // or together, fal-ai, replicate, cohere … }); await inference.textToImage({ @@ -146,12 +146,12 @@ for await (const chunk of inference.chatCompletionStream({ console.log(chunk.choices[0].delta.content); } -/// Using a third-party provider: +/// Using a third-party provider: await inference.chatCompletion({ model: "meta-llama/Llama-3.1-8B-Instruct", messages: [{ role: "user", content: "Hello, nice to meet you!" }], max_tokens: 512, - provider: "sambanova", // or together, fal-ai, replicate, … + provider: "sambanova", // or together, fal-ai, replicate, cohere … }) await inference.textToImage({ @@ -211,7 +211,7 @@ await uploadFile({ // Can work with native File in browsers file: { path: "pytorch_model.bin", - content: new Blob(...) + content: new Blob(...) } }); @@ -244,7 +244,7 @@ console.log(messages); // contains the data // or you can run the code directly, however you can't check that the code is safe to execute this way, use at your own risk. const messages = await agent.run("Draw a picture of a cat wearing a top hat. Then caption the picture and read it out loud.") -console.log(messages); +console.log(messages); ``` There are more features of course, check each library's README! diff --git a/packages/inference/README.md b/packages/inference/README.md index 3289fc6746..a6c0bc4a5f 100644 --- a/packages/inference/README.md +++ b/packages/inference/README.md @@ -56,6 +56,7 @@ Currently, we support the following providers: - [Sambanova](https://sambanova.ai) - [Together](https://together.xyz) - [Blackforestlabs](https://blackforestlabs.ai) +- [Cohere](https://cohere.com) To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token. ```ts @@ -80,6 +81,7 @@ Only a subset of models are supported when requesting third-party providers. You - [Replicate supported models](https://huggingface.co/api/partners/replicate/models) - [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models) - [Together supported models](https://huggingface.co/api/partners/together/models) +- [Cohere supported models](https://huggingface.co/api/partners/cohere/models) - [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending) ❗**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type. diff --git a/packages/inference/src/lib/makeRequestOptions.ts b/packages/inference/src/lib/makeRequestOptions.ts index 98fc277d79..8121938e36 100644 --- a/packages/inference/src/lib/makeRequestOptions.ts +++ b/packages/inference/src/lib/makeRequestOptions.ts @@ -1,5 +1,6 @@ import { HF_HUB_URL, HF_ROUTER_URL } from "../config"; import { BLACK_FOREST_LABS_CONFIG } from "../providers/black-forest-labs"; +import { COHERE_CONFIG } from "../providers/cohere"; import { FAL_AI_CONFIG } from "../providers/fal-ai"; import { FIREWORKS_AI_CONFIG } from "../providers/fireworks-ai"; import { HF_INFERENCE_CONFIG } from "../providers/hf-inference"; @@ -27,6 +28,7 @@ let tasks: Record | null = null; */ const providerConfigs: Record = { "black-forest-labs": BLACK_FOREST_LABS_CONFIG, + cohere: COHERE_CONFIG, "fal-ai": FAL_AI_CONFIG, "fireworks-ai": FIREWORKS_AI_CONFIG, "hf-inference": HF_INFERENCE_CONFIG, diff --git a/packages/inference/src/providers/cohere.ts b/packages/inference/src/providers/cohere.ts new file mode 100644 index 0000000000..b48e073460 --- /dev/null +++ b/packages/inference/src/providers/cohere.ts @@ -0,0 +1,42 @@ +/** + * See the registered mapping of HF model ID => Cohere model ID here: + * + * https://huggingface.co/api/partners/cohere/models + * + * This is a publicly available mapping. + * + * If you want to try to run inference for a new model locally before it's registered on huggingface.co, + * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes. + * + * - If you work at Cohere and want to update this mapping, please use the model mapping API we provide on huggingface.co + * - If you're a community member and want to add a new supported HF model to Cohere, please open an issue on the present repo + * and we will tag Cohere team members. + * + * Thanks! + */ +import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types"; + +const COHERE_API_BASE_URL = "https://api.cohere.com"; + + +const makeBody = (params: BodyParams): Record => { + return { + ...params.args, + model: params.model, + }; +}; + +const makeHeaders = (params: HeaderParams): Record => { + return { Authorization: `Bearer ${params.accessToken}` }; +}; + +const makeUrl = (params: UrlParams): string => { + return `${params.baseUrl}/compatibility/v1/chat/completions`; +}; + +export const COHERE_CONFIG: ProviderConfig = { + baseUrl: COHERE_API_BASE_URL, + makeBody, + makeHeaders, + makeUrl, +}; diff --git a/packages/inference/src/providers/consts.ts b/packages/inference/src/providers/consts.ts index b782767a11..6089a14c52 100644 --- a/packages/inference/src/providers/consts.ts +++ b/packages/inference/src/providers/consts.ts @@ -17,6 +17,7 @@ export const HARDCODED_MODEL_ID_MAPPING: Record; export const INFERENCE_PROVIDERS = [ "black-forest-labs", + "cohere", "fal-ai", "fireworks-ai", "hf-inference", diff --git a/packages/inference/test/HfInference.spec.ts b/packages/inference/test/HfInference.spec.ts index ec10bdc1d3..69bbb83f03 100644 --- a/packages/inference/test/HfInference.spec.ts +++ b/packages/inference/test/HfInference.spec.ts @@ -1350,4 +1350,51 @@ describe.concurrent("HfInference", () => { }, TIMEOUT ); + describe.concurrent( + "Cohere", + () => { + const client = new HfInference(env.HF_COHERE_KEY); + + HARDCODED_MODEL_ID_MAPPING["cohere"] = { + "CohereForAI/c4ai-command-r7b-12-2024": "command-r7b-12-2024", + "CohereForAI/aya-expanse-8b": "c4ai-aya-expanse-8b", + }; + + it("chatCompletion", async () => { + const res = await client.chatCompletion({ + model: "CohereForAI/c4ai-command-r7b-12-2024", + provider: "cohere", + messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], + }); + if (res.choices && res.choices.length > 0) { + const completion = res.choices[0].message?.content; + expect(completion).toContain("two"); + } + }); + + it("chatCompletion stream", async () => { + const stream = client.chatCompletionStream({ + model: "CohereForAI/c4ai-command-r7b-12-2024", + provider: "cohere", + messages: [{ role: "user", content: "Say 'this is a test'" }], + stream: true, + }) as AsyncGenerator; + + let fullResponse = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + const content = chunk.choices[0].delta?.content; + if (content) { + fullResponse += content; + } + } + } + + // Verify we got a meaningful response + expect(fullResponse).toBeTruthy(); + expect(fullResponse.length).toBeGreaterThan(0); + }); + }, + TIMEOUT + ); }); diff --git a/packages/inference/test/tapes.json b/packages/inference/test/tapes.json index 658ebdcc2e..9479c41ce0 100644 --- a/packages/inference/test/tapes.json +++ b/packages/inference/test/tapes.json @@ -7386,5 +7386,58 @@ "content-type": "image/jpeg" } } + }, + "cb34d07934bd210fd64da207415c49fc6e2870d3564164a2a5d541f713227fbf": { + "url": "https://api.cohere.com/compatibility/v1/chat/completions", + "init": { + "headers": { + "Content-Type": "application/json" + }, + "method": "POST", + "body": "{\"messages\":[{\"role\":\"user\",\"content\":\"Say 'this is a test'\"}],\"stream\":true,\"model\":\"command-r7b-12-2024\"}" + }, + "response": { + "body": "data: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\"\",\"role\":\"assistant\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\"This\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\" is\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\" a\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\" test\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":null,\"delta\":{\"content\":\".\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"3178eb0c-d523-4504-bb82-01b8f02da6da\",\"choices\":[{\"index\":0,\"finish_reason\":\"stop\",\"delta\":{}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion.chunk\",\"usage\":{\"prompt_tokens\":7,\"completion_tokens\":5,\"total_tokens\":12}}\n\ndata: [DONE]\n\n", + "status": 200, + "statusText": "OK", + "headers": { + "access-control-expose-headers": "X-Debug-Trace-ID", + "alt-svc": "h3=\":443\"; ma=2592000,h3-29=\":443\"; ma=2592000", + "cache-control": "no-cache, no-store, no-transform, must-revalidate, private, max-age=0", + "content-type": "text/event-stream", + "expires": "Thu, 01 Jan 1970 00:00:00 UTC", + "pragma": "no-cache", + "server": "envoy", + "transfer-encoding": "chunked", + "vary": "Origin" + } + } + }, + "8c6ffbc794573c463ed5666e3b560e5966cd975c2893c901c18adb696ba54a6a": { + "url": "https://api.cohere.com/compatibility/v1/chat/completions", + "init": { + "headers": { + "Content-Type": "application/json" + }, + "method": "POST", + "body": "{\"messages\":[{\"role\":\"user\",\"content\":\"Complete this sentence with words, one plus one is equal \"}],\"model\":\"command-r7b-12-2024\"}" + }, + "response": { + "body": "{\"id\":\"f8bf661b-c600-44e5-8412-df37c9dcd985\",\"choices\":[{\"index\":0,\"finish_reason\":\"stop\",\"message\":{\"role\":\"assistant\",\"content\":\"One plus one is equal to two.\"}}],\"created\":1740652112,\"model\":\"command-r7b-12-2024\",\"object\":\"chat.completion\",\"usage\":{\"prompt_tokens\":11,\"completion_tokens\":8,\"total_tokens\":19}}", + "status": 200, + "statusText": "OK", + "headers": { + "access-control-expose-headers": "X-Debug-Trace-ID", + "alt-svc": "h3=\":443\"; ma=2592000,h3-29=\":443\"; ma=2592000", + "cache-control": "no-cache, no-store, no-transform, must-revalidate, private, max-age=0", + "content-type": "application/json", + "expires": "Thu, 01 Jan 1970 00:00:00 UTC", + "num_chars": "2635", + "num_tokens": "19", + "pragma": "no-cache", + "server": "envoy", + "vary": "Origin" + } + } } } \ No newline at end of file diff --git a/packages/tasks/src/inference-providers.ts b/packages/tasks/src/inference-providers.ts index 82f3d808f7..49de2553ab 100644 --- a/packages/tasks/src/inference-providers.ts +++ b/packages/tasks/src/inference-providers.ts @@ -1,6 +1,7 @@ /// This list is for illustration purposes only. /// in the `tasks` sub-package, we do not need actual strong typing of the inference providers. const INFERENCE_PROVIDERS = [ + "cohere", "fal-ai", "fireworks-ai", "hf-inference",