diff --git a/packages/inference/README.md b/packages/inference/README.md index 0ea60b2be7..ed43e0644f 100644 --- a/packages/inference/README.md +++ b/packages/inference/README.md @@ -58,6 +58,7 @@ Currently, we support the following providers: - [OVHcloud](https://endpoints.ai.cloud.ovh.net/) - [Replicate](https://replicate.com) - [Sambanova](https://sambanova.ai) +- [Scaleway](https://www.scaleway.com/en/generative-apis/) - [Together](https://together.xyz) - [Blackforestlabs](https://blackforestlabs.ai) - [Cohere](https://cohere.com) @@ -92,6 +93,7 @@ Only a subset of models are supported when requesting third-party providers. You - [OVHcloud supported models](https://huggingface.co/api/partners/ovhcloud/models) - [Replicate supported models](https://huggingface.co/api/partners/replicate/models) - [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models) +- [Scaleway supported models](https://huggingface.co/api/partners/scaleway/models) - [Together supported models](https://huggingface.co/api/partners/together/models) - [Cohere supported models](https://huggingface.co/api/partners/cohere/models) - [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models) diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts index 72cc35bf62..5836fb9f73 100644 --- a/packages/inference/src/lib/getProviderHelper.ts +++ b/packages/inference/src/lib/getProviderHelper.ts @@ -47,6 +47,7 @@ import type { } from "../providers/providerHelper.js"; import * as Replicate from "../providers/replicate.js"; import * as Sambanova from "../providers/sambanova.js"; +import * as Scaleway from "../providers/scaleway.js"; import * as Together from "../providers/together.js"; import type { InferenceProvider, InferenceProviderOrPolicy, InferenceTask } from "../types.js"; import { InferenceClientInputError } from "../errors.js"; @@ -148,6 +149,11 @@ export const PROVIDERS: Record Scaleway model ID 
here: + * + * https://huggingface.co/api/partners/scaleway/models + * + * This is a publicly available mapping. + * + * If you want to try to run inference for a new model locally before it's registered on huggingface.co, + * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes. + * + * - If you work at Scaleway and want to update this mapping, please use the model mapping API we provide on huggingface.co + * - If you're a community member and want to add a new supported HF model to Scaleway, please open an issue on the present repo + * and we will tag Scaleway team members. + * + * Thanks! + */ +import type { FeatureExtractionOutput, TextGenerationOutput } from "@huggingface/tasks"; +import type { BodyParams } from "../types.js"; +import { InferenceClientProviderOutputError } from "../errors.js"; + +import type { FeatureExtractionTaskHelper } from "./providerHelper.js"; +import { BaseConversationalTask, TaskProviderHelper, BaseTextGenerationTask } from "./providerHelper.js"; + +const SCALEWAY_API_BASE_URL = "https://api.scaleway.ai"; + +interface ScalewayEmbeddingsResponse { + data: Array<{ + embedding: number[]; + }>; +} + +export class ScalewayConversationalTask extends BaseConversationalTask { + constructor() { + super("scaleway", SCALEWAY_API_BASE_URL); + } +} + +export class ScalewayTextGenerationTask extends BaseTextGenerationTask { + constructor() { + super("scaleway", SCALEWAY_API_BASE_URL); + } + + override preparePayload(params: BodyParams): Record<string, unknown> { + return { + model: params.model, + ...params.args, + prompt: params.args.inputs, + }; + } + + override async getResponse(response: unknown): Promise<TextGenerationOutput> { + if ( + typeof response === "object" && + response !== null && + "choices" in response && + Array.isArray(response.choices) && + response.choices.length > 0 + ) { + const completion: unknown = response.choices[0]; + if ( + typeof completion === "object" && + !!completion && + "text" in completion && + completion.text 
&& + typeof completion.text === "string" + ) { + return { + generated_text: completion.text, + }; + } + } + throw new InferenceClientProviderOutputError("Received malformed response from Scaleway text generation API"); + } +} + +export class ScalewayFeatureExtractionTask extends TaskProviderHelper implements FeatureExtractionTaskHelper { + constructor() { + super("scaleway", SCALEWAY_API_BASE_URL); + } + + preparePayload(params: BodyParams): Record<string, unknown> { + return { + input: params.args.inputs, + model: params.model, + }; + } + + makeRoute(): string { + return "v1/embeddings"; + } + + async getResponse(response: ScalewayEmbeddingsResponse): Promise<FeatureExtractionOutput> { + return response.data.map((item) => item.embedding); + } +} diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts index 5d6be233d8..b31843b99b 100644 --- a/packages/inference/src/types.ts +++ b/packages/inference/src/types.ts @@ -61,6 +61,7 @@ export const INFERENCE_PROVIDERS = [ "ovhcloud", "replicate", "sambanova", + "scaleway", "together", ] as const; diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts index d5cefcc60e..1cc60a43a9 100644 --- a/packages/inference/test/InferenceClient.spec.ts +++ b/packages/inference/test/InferenceClient.spec.ts @@ -1516,6 +1516,116 @@ describe.skip("InferenceClient", () => { TIMEOUT ); + describe.concurrent( + "Scaleway", + () => { + const client = new InferenceClient(env.HF_SCALEWAY_KEY ?? 
"dummy"); + + HARDCODED_MODEL_INFERENCE_MAPPING.scaleway = { + "meta-llama/Llama-3.1-8B-Instruct": { + provider: "scaleway", + hfModelId: "meta-llama/Llama-3.1-8B-Instruct", + providerId: "llama-3.1-8b-instruct", + status: "live", + task: "conversational", + }, + "BAAI/bge-multilingual-gemma2": { + provider: "scaleway", + hfModelId: "BAAI/bge-multilingual-gemma2", + providerId: "bge-multilingual-gemma2", + task: "feature-extraction", + status: "live", + }, + "google/gemma-3-27b-it": { + provider: "scaleway", + hfModelId: "google/gemma-3-27b-it", + providerId: "gemma-3-27b-it", + task: "conversational", + status: "live", + }, + }; + + it("chatCompletion", async () => { + const res = await client.chatCompletion({ + model: "meta-llama/Llama-3.1-8B-Instruct", + provider: "scaleway", + messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], + tool_choice: "none", + }); + if (res.choices && res.choices.length > 0) { + const completion = res.choices[0].message?.content; + expect(completion).toMatch(/(to )?(two|2)/i); + } + }); + + it("chatCompletion stream", async () => { + const stream = client.chatCompletionStream({ + model: "meta-llama/Llama-3.1-8B-Instruct", + provider: "scaleway", + messages: [{ role: "system", content: "Complete the equation 1 + 1 = , just the answer" }], + }) as AsyncGenerator<ChatCompletionStreamOutput>; + let out = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + out += chunk.choices[0].delta.content; + } + } + expect(out).toMatch(/(two|2)/i); + }); + + it("chatCompletion multimodal", async () => { + const res = await client.chatCompletion({ + model: "google/gemma-3-27b-it", + provider: "scaleway", + messages: [ + { + role: "user", + content: [ + { + type: "image_url", + image_url: { + url: "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg", + }, + }, + { type: "text", text: "What is this?" 
}, + ], + }, + ], + }); + expect(res.choices).toBeDefined(); + expect(res.choices?.length).toBeGreaterThan(0); + expect(res.choices?.[0].message?.content).toContain("Statue of Liberty"); + }); + + it("textGeneration", async () => { + const res = await client.textGeneration({ + model: "meta-llama/Llama-3.1-8B-Instruct", + provider: "scaleway", + inputs: "Once upon a time,", + temperature: 0, + max_tokens: 19, + }); + + expect(res).toMatchObject({ + generated_text: + " in a small village nestled in the rolling hills of the countryside, there lived a young girl named", + }); + }); + + it("featureExtraction", async () => { + const res = await client.featureExtraction({ + model: "BAAI/bge-multilingual-gemma2", + provider: "scaleway", + inputs: "That is a happy person", + }); + + expect(res).toBeInstanceOf(Array); + expect(res[0]).toEqual(expect.arrayContaining([expect.any(Number)])); + }); + }, + TIMEOUT + ); + describe.concurrent("3rd party providers", () => { it("chatCompletion - fails with unsupported model", async () => { expect(