Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Currently, we support the following providers:
- [Blackforestlabs](https://blackforestlabs.ai)
- [Cohere](https://cohere.com)
- [Cerebras](https://cerebras.ai/)
- [CentML](https://centml.ai)
- [Groq](https://groq.com)

To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
Expand Down Expand Up @@ -89,6 +90,7 @@ Only a subset of models are supported when requesting third-party providers. You
- [Together supported models](https://huggingface.co/api/partners/together/models)
- [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
- [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
- [CentML supported models](https://huggingface.co/api/partners/centml/models)
- [Groq supported models](https://console.groq.com/docs/models)
- [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)

Expand Down
5 changes: 5 additions & 0 deletions packages/inference/src/lib/getProviderHelper.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import * as BlackForestLabs from "../providers/black-forest-labs";
import * as Cerebras from "../providers/cerebras";
import * as CentML from "../providers/centml";
import * as Cohere from "../providers/cohere";
import * as FalAI from "../providers/fal-ai";
import * as FeatherlessAI from "../providers/featherless-ai";
Expand Down Expand Up @@ -55,6 +56,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
cerebras: {
conversational: new Cerebras.CerebrasConversationalTask(),
},
centml: {
conversational: new CentML.CentMLConversationalTask(),
"text-generation": new CentML.CentMLTextGenerationTask(),
},
cohere: {
conversational: new Cohere.CohereConversationalTask(),
},
Expand Down
71 changes: 71 additions & 0 deletions packages/inference/src/providers/centml.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/**
* CentML provider implementation for serverless inference.
* This provider supports chat completions and text generation through CentML's serverless endpoints.
*/
import type { ChatCompletionOutput, TextGenerationOutput } from "@huggingface/tasks";
import { InferenceOutputError } from "../lib/InferenceOutputError";
import type { BodyParams } from "../types";
import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper";

const CENTML_API_BASE_URL = "https://api.centml.com";

/**
 * Chat-completions task for CentML's OpenAI-compatible serverless endpoint.
 *
 * Routes requests to `openai/v1/chat/completions` on {@link CENTML_API_BASE_URL}.
 * Authentication is carried in the `Authorization` header by the base task —
 * the access token must never be serialized into the request body.
 */
export class CentMLConversationalTask extends BaseConversationalTask {
	constructor() {
		super("centml", CENTML_API_BASE_URL);
	}

	override makeRoute(): string {
		return "openai/v1/chat/completions";
	}

	override preparePayload(params: BodyParams): Record<string, unknown> {
		const { args, model } = params;
		// NOTE(review): the previous revision added `api_key: args.accessToken`
		// here. The credential belongs in the Authorization header (handled by
		// the base class); putting it in the JSON body would leak it into
		// request logs, and `accessToken` is not part of the body args anyway.
		return {
			...args,
			model,
		};
	}

	override async getResponse(response: ChatCompletionOutput): Promise<ChatCompletionOutput> {
		// Structural sanity check on the provider response before handing it
		// back to the caller as a ChatCompletionOutput.
		if (
			typeof response === "object" &&
			Array.isArray(response?.choices) &&
			typeof response?.created === "number" &&
			typeof response?.id === "string" &&
			typeof response?.model === "string" &&
			typeof response?.usage === "object"
		) {
			return response;
		}

		throw new InferenceOutputError("Expected ChatCompletionOutput");
	}
}

/**
 * Text-generation task for CentML's OpenAI-compatible serverless endpoint.
 *
 * Routes requests to `openai/v1/completions` on {@link CENTML_API_BASE_URL}.
 * That endpoint answers in the OpenAI completions shape
 * (`{ choices: [{ text: string }] }`), so {@link getResponse} maps it to the
 * `TextGenerationOutput` contract (`{ generated_text: string }`) expected by
 * callers of this package.
 */
export class CentMLTextGenerationTask extends BaseTextGenerationTask {
	constructor() {
		super("centml", CENTML_API_BASE_URL);
	}

	override makeRoute(): string {
		return "openai/v1/completions";
	}

	override preparePayload(params: BodyParams): Record<string, unknown> {
		const { args, model } = params;
		// NOTE(review): the previous revision added `api_key: args.accessToken`
		// here. The credential belongs in the Authorization header (handled by
		// the base class); putting it in the JSON body would leak it into
		// request logs, and `accessToken` is not part of the body args anyway.
		return {
			...args,
			model,
		};
	}

	override async getResponse(response: TextGenerationOutput): Promise<TextGenerationOutput> {
		if (typeof response === "object" && response !== null) {
			// Narrow the raw provider payload: it may already match
			// TextGenerationOutput, or it may be an OpenAI-style completion.
			const raw = response as unknown as {
				generated_text?: unknown;
				choices?: Array<{ text?: unknown }>;
			};
			if (typeof raw.generated_text === "string") {
				return response;
			}
			// OpenAI-compatible completions shape: take the first choice's text.
			if (Array.isArray(raw.choices) && typeof raw.choices[0]?.text === "string") {
				return { generated_text: raw.choices[0].text };
			}
		}

		throw new InferenceOutputError("Expected TextGenerationOutput");
	}
}
8 changes: 8 additions & 0 deletions packages/inference/src/providers/consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
*/
"black-forest-labs": {},
cerebras: {},
centml: {
"meta-llama/Llama-3.2-3B-Instruct": {
hfModelId: "meta-llama/Llama-3.2-3B-Instruct",
providerId: "meta-llama/Llama-3.2-3B-Instruct", // CentML expects same id
status: "live", // "staging" would emit a non-production warning at call time
task: "conversational" // WidgetType from @huggingface/tasks
}
},
cohere: {},
"fal-ai": {},
"featherless-ai": {},
Expand Down
1 change: 1 addition & 0 deletions packages/inference/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ export type InferenceTask = Exclude<PipelineType, "other"> | "conversational";
export const INFERENCE_PROVIDERS = [
"black-forest-labs",
"cerebras",
"centml",
"cohere",
"fal-ai",
"featherless-ai",
Expand Down