diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts
index 72cc35bf62..bf439aca6d 100644
--- a/packages/inference/src/lib/getProviderHelper.ts
+++ b/packages/inference/src/lib/getProviderHelper.ts
@@ -130,6 +130,7 @@ export const PROVIDERS: Record {
+	choices: Array<{
+		text: string;
+		finish_reason: TextGenerationOutputFinishReason;
+	}>;
+}
+
 interface NscaleCloudBase64ImageGeneration {
 	data: Array<{
 		b64_json: string;
@@ -34,6 +47,55 @@ export class NscaleConversationalTask extends BaseConversationalTask {
 	}
 }
 
+/**
+ * Text-generation helper for Nscale, targeting the `v1/completions` route.
+ */
+export class NscaleTextGenerationTask extends BaseTextGenerationTask {
+	constructor() {
+		super("nscale", NSCALE_API_BASE_URL);
+	}
+
+	override makeRoute(): string {
+		return "v1/completions";
+	}
+
+	/**
+	 * Builds the request body: the model id, `inputs` sent as `prompt`, then any
+	 * caller-supplied `parameters` spread last.
+	 */
+	override preparePayload(params: BodyParams): Record {
+		return {
+			model: params.model,
+			prompt: params.args.inputs,
+			...(params.args.parameters || {}),
+		};
+	}
+
+	/**
+	 * Returns the first choice's text as `generated_text`.
+	 *
+	 * @throws InferenceClientProviderOutputError when the response is not an object
+	 * with a non-empty `choices` array whose first entry has a string `text`.
+	 */
+	override async getResponse(response: NscaleTextGenerationOutput): Promise {
+		if (
+			typeof response === "object" &&
+			// `typeof null` is "object" and `"choices" in null` throws a TypeError,
+			// so rule null out before using `in`.
+			response !== null &&
+			"choices" in response &&
+			Array.isArray(response?.choices) &&
+			response.choices.length > 0 &&
+			typeof response.choices[0]?.text === "string"
+		) {
+			return {
+				generated_text: response.choices[0].text,
+			};
+		}
+		throw new InferenceClientProviderOutputError("Received malformed response from Nscale text generation API");
+	}
+}
+
 export class NscaleTextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper {
 	constructor() {
 		super("nscale", NSCALE_API_BASE_URL);
diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts
index d5cefcc60e..407bfe176e 100644
--- a/packages/inference/test/InferenceClient.spec.ts
+++ b/packages/inference/test/InferenceClient.spec.ts
@@ -1918,14 +1918,21 @@ describe.skip("InferenceClient", () => {
 				"meta-llama/Llama-3.1-8B-Instruct": {
 					provider: "nscale",
 					hfModelId: "meta-llama/Llama-3.1-8B-Instruct",
-					providerId: "nscale",
+					providerId: "meta-llama/Llama-3.1-8B-Instruct",
 					status: "live",
 					task: "conversational",
 				},
+				"mistralai/Devstral-Small-2505": {
+					provider: "nscale",
+					hfModelId: "mistralai/Devstral-Small-2505",
+					providerId: "mistralai/Devstral-Small-2505",
+					status: "staging",
+					task: "text-generation",
+				},
 				"black-forest-labs/FLUX.1-schnell": {
 					provider: "nscale",
 					hfModelId: "black-forest-labs/FLUX.1-schnell",
-					providerId: "flux-schnell",
+					providerId: "black-forest-labs/FLUX.1-schnell",
 					status: "live",
 					task: "text-to-image",
 				},
@@ -1969,6 +1976,21 @@ describe.skip("InferenceClient", () => {
 				});
 				expect(res).toBeInstanceOf(Blob);
 			});
+
+			it("textGeneration", async () => {
+				const res = await client.textGeneration({
+					model: "mistralai/Devstral-Small-2505",
+					provider: "nscale",
+					inputs: "1+1=",
+					parameters: {
+						temperature: 0,
+						max_tokens: 1,
+					},
+				});
+
+				expect(res.generated_text.length).toBeGreaterThan(0);
+				expect(res.generated_text).toContain("2");
+			});
 		},
 		TIMEOUT
 	);