Skip to content

Commit 85250e7

Browse files
Add Cerebras as a provider (#1242)
## What Adds Cerebras as an inference provider. ## Test Plan Added new tests for Cerebras both with and without streaming. ## What Should Reviewers Focus On? I used the Cohere PR as an example. --------- Co-authored-by: Lucain <[email protected]>
1 parent 6e65421 commit 85250e7

File tree

8 files changed

+130
-0
lines changed

8 files changed

+130
-0
lines changed

packages/inference/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ Currently, we support the following providers:
5757
- [Together](https://together.xyz)
5858
- [Blackforestlabs](https://blackforestlabs.ai)
5959
- [Cohere](https://cohere.com)
60+
- [Cerebras](https://cerebras.ai/)
6061

6162
To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
6263
```ts
@@ -82,6 +83,7 @@ Only a subset of models are supported when requesting third-party providers. You
8283
- [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
8384
- [Together supported models](https://huggingface.co/api/partners/together/models)
8485
- [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
86+
- [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
8587
- [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)
8688

8789
**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { HF_HUB_URL, HF_ROUTER_URL } from "../config";
22
import { BLACK_FOREST_LABS_CONFIG } from "../providers/black-forest-labs";
3+
import { CEREBRAS_CONFIG } from "../providers/cerebras";
34
import { COHERE_CONFIG } from "../providers/cohere";
45
import { FAL_AI_CONFIG } from "../providers/fal-ai";
56
import { FIREWORKS_AI_CONFIG } from "../providers/fireworks-ai";
@@ -29,6 +30,7 @@ let tasks: Record<string, { models: { id: string }[] }> | null = null;
2930
*/
3031
const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
3132
"black-forest-labs": BLACK_FOREST_LABS_CONFIG,
33+
cerebras: CEREBRAS_CONFIG,
3234
cohere: COHERE_CONFIG,
3335
"fal-ai": FAL_AI_CONFIG,
3436
"fireworks-ai": FIREWORKS_AI_CONFIG,
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/**
2+
* See the registered mapping of HF model ID => Cerebras model ID here:
3+
*
4+
* https://huggingface.co/api/partners/cerebras/models
5+
*
6+
* This is a publicly available mapping.
7+
*
8+
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
9+
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
10+
*
11+
* - If you work at Cerebras and want to update this mapping, please use the model mapping API we provide on huggingface.co
12+
* - If you're a community member and want to add a new supported HF model to Cerebras, please open an issue on the present repo
13+
* and we will tag Cerebras team members.
14+
*
15+
* Thanks!
16+
*/
17+
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
18+
19+
const CEREBRAS_API_BASE_URL = "https://api.cerebras.ai";
20+
21+
const makeBody = (params: BodyParams): Record<string, unknown> => {
22+
return {
23+
...params.args,
24+
model: params.model,
25+
};
26+
};
27+
28+
const makeHeaders = (params: HeaderParams): Record<string, string> => {
29+
return { Authorization: `Bearer ${params.accessToken}` };
30+
};
31+
32+
const makeUrl = (params: UrlParams): string => {
33+
return `${params.baseUrl}/v1/chat/completions`;
34+
};
35+
36+
export const CEREBRAS_CONFIG: ProviderConfig = {
37+
baseUrl: CEREBRAS_API_BASE_URL,
38+
makeBody,
39+
makeHeaders,
40+
makeUrl,
41+
};

packages/inference/src/providers/consts.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
1717
* "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
1818
*/
1919
"black-forest-labs": {},
20+
cerebras: {},
2021
cohere: {},
2122
"fal-ai": {},
2223
"fireworks-ai": {},

packages/inference/src/types.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ export type InferenceTask = Exclude<PipelineType, "other">;
3030

3131
export const INFERENCE_PROVIDERS = [
3232
"black-forest-labs",
33+
"cerebras",
3334
"cohere",
3435
"fal-ai",
3536
"fireworks-ai",

packages/inference/test/HfInference.spec.ts

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1406,4 +1406,50 @@ describe.concurrent("HfInference", () => {
14061406
},
14071407
TIMEOUT
14081408
);
1409+
describe.concurrent(
1410+
"Cerebras",
1411+
() => {
1412+
const client = new HfInference(env.HF_CEREBRAS_KEY ?? "dummy");
1413+
1414+
HARDCODED_MODEL_ID_MAPPING["cerebras"] = {
1415+
"meta-llama/llama-3.1-8b-instruct": "llama3.1-8b",
1416+
};
1417+
1418+
it("chatCompletion", async () => {
1419+
const res = await client.chatCompletion({
1420+
model: "meta-llama/llama-3.1-8b-instruct",
1421+
provider: "cerebras",
1422+
messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
1423+
});
1424+
if (res.choices && res.choices.length > 0) {
1425+
const completion = res.choices[0].message?.content;
1426+
expect(completion).toContain("two");
1427+
}
1428+
});
1429+
1430+
it("chatCompletion stream", async () => {
1431+
const stream = client.chatCompletionStream({
1432+
model: "meta-llama/llama-3.1-8b-instruct",
1433+
provider: "cerebras",
1434+
messages: [{ role: "user", content: "Say 'this is a test'" }],
1435+
stream: true,
1436+
}) as AsyncGenerator<ChatCompletionStreamOutput>;
1437+
1438+
let fullResponse = "";
1439+
for await (const chunk of stream) {
1440+
if (chunk.choices && chunk.choices.length > 0) {
1441+
const content = chunk.choices[0].delta?.content;
1442+
if (content) {
1443+
fullResponse += content;
1444+
}
1445+
}
1446+
}
1447+
1448+
// Verify we got a meaningful response
1449+
expect(fullResponse).toBeTruthy();
1450+
expect(fullResponse.length).toBeGreaterThan(0);
1451+
});
1452+
},
1453+
TIMEOUT
1454+
);
14091455
});

packages/inference/test/tapes.json

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7470,5 +7470,41 @@
74707470
"vary": "Origin"
74717471
}
74727472
}
7473+
},
7474+
"10bec4daddf2346c7a9f864941e1867cd523b44640476be3ce44740823f8e115": {
7475+
"url": "https://api.cerebras.ai/v1/chat/completions",
7476+
"init": {
7477+
"headers": {
7478+
"Content-Type": "application/json"
7479+
},
7480+
"method": "POST",
7481+
"body": "{\"messages\":[{\"role\":\"user\",\"content\":\"Complete this sentence with words, one plus one is equal \"}],\"model\":\"llama3.1-8b\"}"
7482+
},
7483+
"response": {
7484+
"body": "{\"id\":\"chatcmpl-081dc230-a18f-4c4a-b2b0-efe6a5d8767d\",\"choices\":[{\"finish_reason\":\"stop\",\"index\":0,\"message\":{\"content\":\"two.\",\"role\":\"assistant\"}}],\"created\":1740721365,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion\",\"usage\":{\"prompt_tokens\":46,\"completion_tokens\":3,\"total_tokens\":49},\"time_info\":{\"queue_time\":0.000080831,\"prompt_time\":0.002364294,\"completion_time\":0.001345785,\"total_time\":0.005622386932373047,\"created\":1740721365}}",
7485+
"status": 200,
7486+
"statusText": "OK",
7487+
"headers": {
7488+
"content-type": "application/json"
7489+
}
7490+
}
7491+
},
7492+
"b3cad22ff43c9ca503ba3ec2cf3301e935679652e5512d942d12ae060465d2dd": {
7493+
"url": "https://api.cerebras.ai/v1/chat/completions",
7494+
"init": {
7495+
"headers": {
7496+
"Content-Type": "application/json"
7497+
},
7498+
"method": "POST",
7499+
"body": "{\"messages\":[{\"role\":\"user\",\"content\":\"Say 'this is a test'\"}],\"stream\":true,\"model\":\"llama3.1-8b\"}"
7500+
},
7501+
"response": {
7502+
"body": "data: {\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{\"role\":\"assistant\"},\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{\"content\":\"This\"},\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{\"content\":\" is\"},\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{\"content\":\" a\"},\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{\"content\":\" test\"},\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{\"content\":\".\"},\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{},\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\"}\n\ndata: 
{\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\",\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\",\"usage\":{\"prompt_tokens\":42,\"completion_tokens\":6,\"total_tokens\":48},\"time_info\":{\"queue_time\":0.000093189,\"prompt_time\":0.002155987,\"completion_time\":0.002688504,\"total_time\":0.0070416927337646484,\"created\":1740751481}}\"",
7503+
"status": 200,
7504+
"statusText": "OK",
7505+
"headers": {
7506+
"content-type": "text/event-stream"
7507+
}
7508+
}
74737509
}
74747510
}

packages/tasks/src/inference-providers.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
/// This list is for illustration purposes only.
22
/// in the `tasks` sub-package, we do not need actual strong typing of the inference providers.
33
const INFERENCE_PROVIDERS = [
4+
"cerebras",
45
"cohere",
56
"fal-ai",
67
"fireworks-ai",

0 commit comments

Comments (0)