diff --git a/packages/inference/README.md b/packages/inference/README.md index a6c0bc4a5f..3907b43340 100644 --- a/packages/inference/README.md +++ b/packages/inference/README.md @@ -57,6 +57,7 @@ Currently, we support the following providers: - [Together](https://together.xyz) - [Blackforestlabs](https://blackforestlabs.ai) - [Cohere](https://cohere.com) +- [Cerebras](https://cerebras.ai/) To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token. ```ts @@ -82,6 +83,7 @@ Only a subset of models are supported when requesting third-party providers. You - [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models) - [Together supported models](https://huggingface.co/api/partners/together/models) - [Cohere supported models](https://huggingface.co/api/partners/cohere/models) +- [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models) - [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending) ❗**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type. 
diff --git a/packages/inference/src/lib/makeRequestOptions.ts b/packages/inference/src/lib/makeRequestOptions.ts index ba16ab3105..a6472d99db 100644 --- a/packages/inference/src/lib/makeRequestOptions.ts +++ b/packages/inference/src/lib/makeRequestOptions.ts @@ -1,5 +1,6 @@ import { HF_HUB_URL, HF_ROUTER_URL } from "../config"; import { BLACK_FOREST_LABS_CONFIG } from "../providers/black-forest-labs"; +import { CEREBRAS_CONFIG } from "../providers/cerebras"; import { COHERE_CONFIG } from "../providers/cohere"; import { FAL_AI_CONFIG } from "../providers/fal-ai"; import { FIREWORKS_AI_CONFIG } from "../providers/fireworks-ai"; @@ -29,6 +30,7 @@ let tasks: Record | null = null; */ const providerConfigs: Record = { "black-forest-labs": BLACK_FOREST_LABS_CONFIG, + cerebras: CEREBRAS_CONFIG, cohere: COHERE_CONFIG, "fal-ai": FAL_AI_CONFIG, "fireworks-ai": FIREWORKS_AI_CONFIG, diff --git a/packages/inference/src/providers/cerebras.ts b/packages/inference/src/providers/cerebras.ts new file mode 100644 index 0000000000..fbbeae92e7 --- /dev/null +++ b/packages/inference/src/providers/cerebras.ts @@ -0,0 +1,41 @@ +/** + * See the registered mapping of HF model ID => Cerebras model ID here: + * + * https://huggingface.co/api/partners/cerebras/models + * + * This is a publicly available mapping. + * + * If you want to try to run inference for a new model locally before it's registered on huggingface.co, + * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes. + * + * - If you work at Cerebras and want to update this mapping, please use the model mapping API we provide on huggingface.co + * - If you're a community member and want to add a new supported HF model to Cerebras, please open an issue on the present repo + * and we will tag Cerebras team members. + * + * Thanks! 
+ */ +import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types"; + +const CEREBRAS_API_BASE_URL = "https://api.cerebras.ai"; + +const makeBody = (params: BodyParams): Record<string, unknown> => { + return { + ...params.args, + model: params.model, + }; +}; + +const makeHeaders = (params: HeaderParams): Record<string, string> => { + return { Authorization: `Bearer ${params.accessToken}` }; +}; + +const makeUrl = (params: UrlParams): string => { + return `${params.baseUrl}/v1/chat/completions`; +}; + +export const CEREBRAS_CONFIG: ProviderConfig = { + baseUrl: CEREBRAS_API_BASE_URL, + makeBody, + makeHeaders, + makeUrl, +}; diff --git a/packages/inference/src/providers/consts.ts b/packages/inference/src/providers/consts.ts index f1c0fd1eae..184a5d2425 100644 --- a/packages/inference/src/providers/consts.ts +++ b/packages/inference/src/providers/consts.ts @@ -17,6 +17,7 @@ export const HARDCODED_MODEL_ID_MAPPING: Record; export const INFERENCE_PROVIDERS = [ "black-forest-labs", + "cerebras", "cohere", "fal-ai", "fireworks-ai", diff --git a/packages/inference/test/HfInference.spec.ts b/packages/inference/test/HfInference.spec.ts index 2f225efc9f..c6c5350408 100644 --- a/packages/inference/test/HfInference.spec.ts +++ b/packages/inference/test/HfInference.spec.ts @@ -1406,4 +1406,50 @@ describe.concurrent("HfInference", () => { }, TIMEOUT ); + describe.concurrent( + "Cerebras", + () => { + const client = new HfInference(env.HF_CEREBRAS_KEY ?? 
"dummy"); + + HARDCODED_MODEL_ID_MAPPING["cerebras"] = { + "meta-llama/llama-3.1-8b-instruct": "llama3.1-8b", + }; + + it("chatCompletion", async () => { + const res = await client.chatCompletion({ + model: "meta-llama/llama-3.1-8b-instruct", + provider: "cerebras", + messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], + }); + if (res.choices && res.choices.length > 0) { + const completion = res.choices[0].message?.content; + expect(completion).toContain("two"); + } + }); + + it("chatCompletion stream", async () => { + const stream = client.chatCompletionStream({ + model: "meta-llama/llama-3.1-8b-instruct", + provider: "cerebras", + messages: [{ role: "user", content: "Say 'this is a test'" }], + stream: true, + }) as AsyncGenerator<ChatCompletionStreamOutput>; + + let fullResponse = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + const content = chunk.choices[0].delta?.content; + if (content) { + fullResponse += content; + } + } + } + + // Verify we got a meaningful response + expect(fullResponse).toBeTruthy(); + expect(fullResponse.length).toBeGreaterThan(0); + }); + }, + TIMEOUT + ); }); diff --git a/packages/inference/test/tapes.json b/packages/inference/test/tapes.json index 09307395b9..a9c490c147 100644 --- a/packages/inference/test/tapes.json +++ b/packages/inference/test/tapes.json @@ -7470,5 +7470,41 @@ "vary": "Origin" } } + }, + "10bec4daddf2346c7a9f864941e1867cd523b44640476be3ce44740823f8e115": { + "url": "https://api.cerebras.ai/v1/chat/completions", + "init": { + "headers": { + "Content-Type": "application/json" + }, + "method": "POST", + "body": "{\"messages\":[{\"role\":\"user\",\"content\":\"Complete this sentence with words, one plus one is equal \"}],\"model\":\"llama3.1-8b\"}" + }, + "response": { + "body": 
"{\"id\":\"chatcmpl-081dc230-a18f-4c4a-b2b0-efe6a5d8767d\",\"choices\":[{\"finish_reason\":\"stop\",\"index\":0,\"message\":{\"content\":\"two.\",\"role\":\"assistant\"}}],\"created\":1740721365,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion\",\"usage\":{\"prompt_tokens\":46,\"completion_tokens\":3,\"total_tokens\":49},\"time_info\":{\"queue_time\":0.000080831,\"prompt_time\":0.002364294,\"completion_time\":0.001345785,\"total_time\":0.005622386932373047,\"created\":1740721365}}", + "status": 200, + "statusText": "OK", + "headers": { + "content-type": "application/json" + } + } + }, + "b3cad22ff43c9ca503ba3ec2cf3301e935679652e5512d942d12ae060465d2dd": { + "url": "https://api.cerebras.ai/v1/chat/completions", + "init": { + "headers": { + "Content-Type": "application/json" + }, + "method": "POST", + "body": "{\"messages\":[{\"role\":\"user\",\"content\":\"Say 'this is a test'\"}],\"stream\":true,\"model\":\"llama3.1-8b\"}" + }, + "response": { + "body": "data: {\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{\"role\":\"assistant\"},\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{\"content\":\"This\"},\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{\"content\":\" is\"},\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{\"content\":\" 
a\"},\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{\"content\":\" test\"},\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{\"content\":\".\"},\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{},\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\",\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\",\"usage\":{\"prompt_tokens\":42,\"completion_tokens\":6,\"total_tokens\":48},\"time_info\":{\"queue_time\":0.000093189,\"prompt_time\":0.002155987,\"completion_time\":0.002688504,\"total_time\":0.0070416927337646484,\"created\":1740751481}}", + "status": 200, + "statusText": "OK", + "headers": { + "content-type": "text/event-stream" + } + } } } diff --git a/packages/tasks/src/inference-providers.ts b/packages/tasks/src/inference-providers.ts index 49de2553ab..41f5b8d729 100644 --- a/packages/tasks/src/inference-providers.ts +++ b/packages/tasks/src/inference-providers.ts @@ -1,6 +1,7 @@ /// This list is for illustration purposes only. /// in the `tasks` sub-package, we do not need actual strong typing of the inference providers. const INFERENCE_PROVIDERS = [ + "cerebras", "cohere", "fal-ai", "fireworks-ai",