1+ /**
2+ * CentML provider implementation for serverless inference.
3+ * This provider supports chat completions and text generation through CentML's serverless endpoints.
4+ */
5+ import type { ChatCompletionOutput , TextGenerationOutput } from "@huggingface/tasks" ;
6+ import { InferenceOutputError } from "../lib/InferenceOutputError" ;
7+ import type { BodyParams } from "../types" ;
8+ import { BaseConversationalTask , BaseTextGenerationTask } from "./providerHelper" ;
9+
/** Base URL for CentML's serverless inference API; route paths are appended by the task helpers. */
const CENTML_API_BASE_URL = "https://api.centml.ai";
11+
12+ export class CentMLConversationalTask extends BaseConversationalTask {
13+ constructor ( ) {
14+ super ( "centml" , CENTML_API_BASE_URL ) ;
15+ }
16+
17+ override preparePayload ( params : BodyParams ) : Record < string , unknown > {
18+ const { args, model } = params ;
19+ return {
20+ ...args ,
21+ model,
22+ api_key : args . accessToken , // Use the accessToken from args
23+ } ;
24+ }
25+
26+ override async getResponse ( response : ChatCompletionOutput ) : Promise < ChatCompletionOutput > {
27+ if (
28+ typeof response === "object" &&
29+ Array . isArray ( response ?. choices ) &&
30+ typeof response ?. created === "number" &&
31+ typeof response ?. id === "string" &&
32+ typeof response ?. model === "string" &&
33+ typeof response ?. usage === "object"
34+ ) {
35+ return response ;
36+ }
37+
38+ throw new InferenceOutputError ( "Expected ChatCompletionOutput" ) ;
39+ }
40+ }
41+
42+ export class CentMLTextGenerationTask extends BaseTextGenerationTask {
43+ constructor ( ) {
44+ super ( "centml" , CENTML_API_BASE_URL ) ;
45+ }
46+
47+ override preparePayload ( params : BodyParams ) : Record < string , unknown > {
48+ const { args, model } = params ;
49+ return {
50+ ...args ,
51+ model,
52+ api_key : args . accessToken , // Use the accessToken from args
53+ } ;
54+ }
55+
56+ override async getResponse ( response : TextGenerationOutput ) : Promise < TextGenerationOutput > {
57+ if (
58+ typeof response === "object" &&
59+ typeof response ?. generated_text === "string"
60+ ) {
61+ return response ;
62+ }
63+
64+ throw new InferenceOutputError ( "Expected TextGenerationOutput" ) ;
65+ }
66+ }