Skip to content

Commit 3452916

Browse files
committed
add centml
1 parent 2475d6d commit 3452916

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/**
2+
* CentML provider implementation for serverless inference.
3+
* This provider supports chat completions and text generation through CentML's serverless endpoints.
4+
*/
5+
import type { ChatCompletionOutput, TextGenerationOutput } from "@huggingface/tasks";
6+
import { InferenceOutputError } from "../lib/InferenceOutputError";
7+
import type { BodyParams } from "../types";
8+
import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper";
9+
// Base URL for all CentML serverless inference requests; passed to the
// task helpers, which build the full route from it.
const CENTML_API_BASE_URL = "https://api.centml.ai";
12+
export class CentMLConversationalTask extends BaseConversationalTask {
13+
constructor() {
14+
super("centml", CENTML_API_BASE_URL);
15+
}
16+
17+
override preparePayload(params: BodyParams): Record<string, unknown> {
18+
const { args, model } = params;
19+
return {
20+
...args,
21+
model,
22+
api_key: args.accessToken, // Use the accessToken from args
23+
};
24+
}
25+
26+
override async getResponse(response: ChatCompletionOutput): Promise<ChatCompletionOutput> {
27+
if (
28+
typeof response === "object" &&
29+
Array.isArray(response?.choices) &&
30+
typeof response?.created === "number" &&
31+
typeof response?.id === "string" &&
32+
typeof response?.model === "string" &&
33+
typeof response?.usage === "object"
34+
) {
35+
return response;
36+
}
37+
38+
throw new InferenceOutputError("Expected ChatCompletionOutput");
39+
}
40+
}
41+
42+
export class CentMLTextGenerationTask extends BaseTextGenerationTask {
43+
constructor() {
44+
super("centml", CENTML_API_BASE_URL);
45+
}
46+
47+
override preparePayload(params: BodyParams): Record<string, unknown> {
48+
const { args, model } = params;
49+
return {
50+
...args,
51+
model,
52+
api_key: args.accessToken, // Use the accessToken from args
53+
};
54+
}
55+
56+
override async getResponse(response: TextGenerationOutput): Promise<TextGenerationOutput> {
57+
if (
58+
typeof response === "object" &&
59+
typeof response?.generated_text === "string"
60+
) {
61+
return response;
62+
}
63+
64+
throw new InferenceOutputError("Expected TextGenerationOutput");
65+
}
66+
}

packages/inference/src/types.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ export type InferenceTask = Exclude<PipelineType, "other"> | "conversational";
4040
export const INFERENCE_PROVIDERS = [
4141
"black-forest-labs",
4242
"cerebras",
43+
"centml",
4344
"cohere",
4445
"fal-ai",
4546
"featherless-ai",

0 commit comments

Comments
 (0)