Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ jobs:
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_BLACK_FOREST_LABS_KEY: dummy
HF_CEREBRAS_KEY: dummy
HF_COHERE_KEY: dummy
HF_FAL_KEY: dummy
HF_FIREWORKS_KEY: dummy
Expand Down Expand Up @@ -88,6 +89,7 @@ jobs:
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_BLACK_FOREST_LABS_KEY: dummy
HF_CEREBRAS_KEY: dummy
HF_COHERE_KEY: dummy
HF_FAL_KEY: dummy
HF_FIREWORKS_KEY: dummy
Expand Down Expand Up @@ -161,6 +163,7 @@ jobs:
NPM_CONFIG_REGISTRY: http://localhost:4874/
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_BLACK_FOREST_LABS_KEY: dummy
HF_CEREBRAS_KEY: dummy
HF_COHERE_KEY: dummy
HF_FAL_KEY: dummy
HF_FIREWORKS_KEY: dummy
Expand Down
2 changes: 2 additions & 0 deletions packages/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ Currently, we support the following providers:
- [Together](https://together.xyz)
- [Black Forest Labs](https://blackforestlabs.ai)
- [Cohere](https://cohere.com)
- [Cerebras](https://cerebras.ai/)

To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
```ts
Expand All @@ -82,6 +83,7 @@ Only a subset of models are supported when requesting third-party providers. You
- [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
- [Together supported models](https://huggingface.co/api/partners/together/models)
- [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
- [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
- [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)

❗**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.
Expand Down
2 changes: 2 additions & 0 deletions packages/inference/src/lib/makeRequestOptions.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { HF_HUB_URL, HF_ROUTER_URL } from "../config";
import { BLACK_FOREST_LABS_CONFIG } from "../providers/black-forest-labs";
import { CEREBRAS_CONFIG } from "../providers/cerebras";
import { COHERE_CONFIG } from "../providers/cohere";
import { FAL_AI_CONFIG } from "../providers/fal-ai";
import { FIREWORKS_AI_CONFIG } from "../providers/fireworks-ai";
Expand Down Expand Up @@ -29,6 +30,7 @@ let tasks: Record<string, { models: { id: string }[] }> | null = null;
*/
const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
"black-forest-labs": BLACK_FOREST_LABS_CONFIG,
cerebras: CEREBRAS_CONFIG,
cohere: COHERE_CONFIG,
"fal-ai": FAL_AI_CONFIG,
"fireworks-ai": FIREWORKS_AI_CONFIG,
Expand Down
41 changes: 41 additions & 0 deletions packages/inference/src/providers/cerebras.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/**
* See the registered mapping of HF model ID => Cerebras model ID here:
*
* https://huggingface.co/api/partners/cerebras/models
*
* This is a publicly available mapping.
*
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
*
* - If you work at Cerebras and want to update this mapping, please use the model mapping API we provide on huggingface.co
* - If you're a community member and want to add a new supported HF model to Cerebras, please open an issue on the present repo
* and we will tag Cerebras team members.
*
* Thanks!
*/
import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";

// Root of Cerebras' OpenAI-compatible API; the endpoint path is appended in makeUrl.
const CEREBRAS_API_BASE_URL = "https://api.cerebras.ai";

/**
 * Build the JSON payload for a Cerebras chat-completion request.
 *
 * The caller-supplied arguments are copied first, then `model` is set to the
 * provider-side model ID so the mapped ID always takes precedence over any
 * `model` the caller may have passed in `args`.
 */
const makeBody = (options: BodyParams): Record<string, unknown> => {
  const payload: Record<string, unknown> = { ...options.args };
  payload.model = options.model;
  return payload;
};

/** Produce the request headers for Cerebras: a standard `Bearer` token. */
const makeHeaders = ({ accessToken }: HeaderParams): Record<string, string> => ({
  Authorization: "Bearer " + accessToken,
});

/** Cerebras exposes an OpenAI-compatible chat-completions endpoint under /v1. */
const makeUrl = (params: UrlParams): string => {
  const endpointPath = "/v1/chat/completions";
  return params.baseUrl + endpointPath;
};

/**
 * Provider configuration for Cerebras: wires the base URL and the
 * request builders above into the shape consumed by makeRequestOptions.
 */
export const CEREBRAS_CONFIG: ProviderConfig = {
baseUrl: CEREBRAS_API_BASE_URL,
makeBody,
makeHeaders,
makeUrl,
};
1 change: 1 addition & 0 deletions packages/inference/src/providers/consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
* "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
*/
"black-forest-labs": {},
cerebras: {},
cohere: {},
"fal-ai": {},
"fireworks-ai": {},
Expand Down
1 change: 1 addition & 0 deletions packages/inference/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ export type InferenceTask = Exclude<PipelineType, "other">;

export const INFERENCE_PROVIDERS = [
"black-forest-labs",
"cerebras",
"cohere",
"fal-ai",
"fireworks-ai",
Expand Down
46 changes: 46 additions & 0 deletions packages/inference/test/HfInference.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1406,4 +1406,50 @@ describe.concurrent("HfInference", () => {
},
TIMEOUT
);
// Cerebras provider integration tests.
// NOTE(review): CI sets HF_CEREBRAS_KEY to "dummy" (see .github/workflows/test.yml),
// so these presumably replay recorded responses from test/tapes.json rather than
// hitting the live API — confirm against the test harness.
describe.concurrent(
"Cerebras",
() => {
const client = new HfInference(env.HF_CEREBRAS_KEY);

// Register the HF -> Cerebras model mapping locally for this test run;
// HARDCODED_MODEL_ID_MAPPING is the dev-purpose escape hatch documented in
// providers/consts.ts for models not yet registered on huggingface.co.
HARDCODED_MODEL_ID_MAPPING["cerebras"] = {
"meta-llama/llama-3.1-8b-instruct": "llama3.1-8b",
};

// NOTE(review): the content assertion below is skipped entirely when
// `choices` is empty — the test still passes with no choices; consider
// asserting on res.choices.length as well.
it("chatCompletion", async () => {
const res = await client.chatCompletion({
model: "meta-llama/llama-3.1-8b-instruct",
provider: "cerebras",
messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
});
if (res.choices && res.choices.length > 0) {
const completion = res.choices[0].message?.content;
expect(completion).toContain("two");
}
});

it("chatCompletion stream", async () => {
const stream = client.chatCompletionStream({
model: "meta-llama/llama-3.1-8b-instruct",
provider: "cerebras",
messages: [{ role: "user", content: "Say 'this is a test'" }],
stream: true,
}) as AsyncGenerator<ChatCompletionStreamOutput>;

// Accumulate the delta contents from every streamed chunk.
let fullResponse = "";
for await (const chunk of stream) {
if (chunk.choices && chunk.choices.length > 0) {
const content = chunk.choices[0].delta?.content;
if (content) {
fullResponse += content;
}
}
}

// Verify we got a meaningful response
expect(fullResponse).toBeTruthy();
expect(fullResponse.length).toBeGreaterThan(0);
});
},
TIMEOUT
);
});
36 changes: 36 additions & 0 deletions packages/inference/test/tapes.json
Original file line number Diff line number Diff line change
Expand Up @@ -7470,5 +7470,41 @@
"vary": "Origin"
}
}
},
"10bec4daddf2346c7a9f864941e1867cd523b44640476be3ce44740823f8e115": {
"url": "https://api.cerebras.ai/v1/chat/completions",
"init": {
"headers": {
"Content-Type": "application/json"
},
"method": "POST",
"body": "{\"messages\":[{\"role\":\"user\",\"content\":\"Complete this sentence with words, one plus one is equal \"}],\"model\":\"llama3.1-8b\"}"
},
"response": {
"body": "{\"id\":\"chatcmpl-081dc230-a18f-4c4a-b2b0-efe6a5d8767d\",\"choices\":[{\"finish_reason\":\"stop\",\"index\":0,\"message\":{\"content\":\"two.\",\"role\":\"assistant\"}}],\"created\":1740721365,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion\",\"usage\":{\"prompt_tokens\":46,\"completion_tokens\":3,\"total_tokens\":49},\"time_info\":{\"queue_time\":0.000080831,\"prompt_time\":0.002364294,\"completion_time\":0.001345785,\"total_time\":0.005622386932373047,\"created\":1740721365}}",
"status": 200,
"statusText": "OK",
"headers": {
"content-type": "application/json"
}
}
},
"b3cad22ff43c9ca503ba3ec2cf3301e935679652e5512d942d12ae060465d2dd": {
"url": "https://api.cerebras.ai/v1/chat/completions",
"init": {
"headers": {
"Content-Type": "application/json"
},
"method": "POST",
"body": "{\"messages\":[{\"role\":\"user\",\"content\":\"Say 'this is a test'\"}],\"stream\":true,\"model\":\"llama3.1-8b\"}"
},
"response": {
"body": "data: {\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{\"role\":\"assistant\"},\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{\"content\":\"This\"},\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{\"content\":\" is\"},\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{\"content\":\" a\"},\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{\"content\":\" test\"},\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{\"content\":\".\"},\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{},\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\"}\n\ndata: 
{\"id\":\"chatcmpl-f75a0132-920e-4700-b366-bfe92b75f3d6\",\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\",\"index\":0}],\"created\":1740751481,\"model\":\"llama3.1-8b\",\"system_fingerprint\":\"fp_d0e90c449e\",\"object\":\"chat.completion.chunk\",\"usage\":{\"prompt_tokens\":42,\"completion_tokens\":6,\"total_tokens\":48},\"time_info\":{\"queue_time\":0.000093189,\"prompt_time\":0.002155987,\"completion_time\":0.002688504,\"total_time\":0.0070416927337646484,\"created\":1740751481}}\"",
"status": 200,
"statusText": "OK",
"headers": {
"content-type": "text/event-stream"
}
}
}
}
1 change: 1 addition & 0 deletions packages/tasks/src/inference-providers.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/// This list is for illustration purposes only.
/// in the `tasks` sub-package, we do not need actual strong typing of the inference providers.
const INFERENCE_PROVIDERS = [
"cerebras",
"cohere",
"fal-ai",
"fireworks-ai",
Expand Down