diff --git a/packages/inference/README.md b/packages/inference/README.md
index ad5c9fc1a6..1d2013d519 100644
--- a/packages/inference/README.md
+++ b/packages/inference/README.md
@@ -54,6 +54,7 @@ Currently, we support the following providers:
 - [Nebius](https://studio.nebius.ai)
 - [Novita](https://novita.ai/?utm_source=github_huggingface&utm_medium=github_readme&utm_campaign=link)
 - [Nscale](https://nscale.com)
+- [OVHcloud](https://endpoints.ai.cloud.ovh.net/)
 - [Replicate](https://replicate.com)
 - [Sambanova](https://sambanova.ai)
 - [Together](https://together.xyz)
@@ -84,6 +85,7 @@ Only a subset of models are supported when requesting third-party providers. You
 - [Hyperbolic supported models](https://huggingface.co/api/partners/hyperbolic/models)
 - [Nebius supported models](https://huggingface.co/api/partners/nebius/models)
 - [Nscale supported models](https://huggingface.co/api/partners/nscale/models)
+- [OVHcloud supported models](https://huggingface.co/api/partners/ovhcloud/models)
 - [Replicate supported models](https://huggingface.co/api/partners/replicate/models)
 - [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
 - [Together supported models](https://huggingface.co/api/partners/together/models)
diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts
index 14bd941987..2a1ce00fe1 100644
--- a/packages/inference/src/lib/getProviderHelper.ts
+++ b/packages/inference/src/lib/getProviderHelper.ts
@@ -11,6 +11,7 @@ import * as Nebius from "../providers/nebius";
 import * as Novita from "../providers/novita";
 import * as Nscale from "../providers/nscale";
 import * as OpenAI from "../providers/openai";
+import * as OvhCloud from "../providers/ovhcloud";
 import type {
 	AudioClassificationTaskHelper,
 	AudioToAudioTaskHelper,
@@ -126,6 +127,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask | "conversational", TaskProviderHelper>>> = {
 	openai: {
 		conversational: new OpenAI.OpenAIConversationalTask(),
 	},
+	ovhcloud: {
+		conversational: new OvhCloud.OvhCloudConversationalTask(),
+		"text-generation": new OvhCloud.OvhCloudTextGenerationTask(),
+	},
 	replicate: {
 		"text-to-image": new Replicate.ReplicateTextToImageTask(),
 		"text-to-speech": new Replicate.ReplicateTextToSpeechTask(),
diff --git a/packages/inference/src/providers/ovhcloud.ts b/packages/inference/src/providers/ovhcloud.ts
new file mode 100644
--- /dev/null
+++ b/packages/inference/src/providers/ovhcloud.ts
@@ -0,0 +1,79 @@
+/**
+ * See the registered mapping of HF model ID => OVHcloud model ID here:
+ *
+ * https://huggingface.co/api/partners/ovhcloud/models
+ *
+ * This is a publicly available mapping.
+ *
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
+ *
+ * - If you work at OVHcloud and want to update this mapping, please use the model mapping API we provide on huggingface.co
+ * - If you're a community member and want to add a new supported HF model to OVHcloud, please open an issue on the present repo
+ * and we will tag OVHcloud team members.
+ *
+ * Thanks!
+ */
+
+import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper";
+import type {
+	ChatCompletionOutput,
+	TextGenerationOutput,
+	TextGenerationOutputFinishReason,
+} from "@huggingface/tasks";
+import { InferenceOutputError } from "../lib/InferenceOutputError";
+import type { BodyParams } from "../types";
+import { omit } from "../utils/omit";
+import type { TextGenerationInput } from "@huggingface/tasks";
+
+const OVHCLOUD_API_BASE_URL = "https://oai.endpoints.kepler.ai.cloud.ovh.net";
+
+interface OvhCloudTextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
+	choices: Array<{
+		text: string;
+		finish_reason: TextGenerationOutputFinishReason;
+		logprobs: unknown;
+		index: number;
+	}>;
+}
+
+export class OvhCloudConversationalTask extends BaseConversationalTask {
+	constructor() {
+		super("ovhcloud", OVHCLOUD_API_BASE_URL);
+	}
+}
+
+export class OvhCloudTextGenerationTask extends BaseTextGenerationTask {
+	constructor() {
+		super("ovhcloud", OVHCLOUD_API_BASE_URL);
+	}
+
+	override preparePayload(params: BodyParams<TextGenerationInput>): Record<string, unknown> {
+		return {
+			model: params.model,
+			...omit(params.args, ["inputs", "parameters"]),
+			...(params.args.parameters
+				? {
+						max_tokens: (params.args.parameters as Record<string, unknown>).max_new_tokens,
+						...omit(params.args.parameters as Record<string, unknown>, "max_new_tokens"),
+				  }
+				: undefined),
+			prompt: params.args.inputs,
+		};
+	}
+
+	override async getResponse(response: OvhCloudTextCompletionOutput): Promise<TextGenerationOutput> {
+		if (
+			typeof response === "object" &&
+			"choices" in response &&
+			Array.isArray(response?.choices) &&
+			typeof response?.model === "string"
+		) {
+			const completion = response.choices[0];
+			return {
+				generated_text: completion.text,
+			};
+		}
+		throw new InferenceOutputError("Expected OVHcloud text generation response format");
+	}
+}
\ No newline at end of file
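The `preparePayload` override above reshapes an HF-style `textGeneration` call into OVHcloud's OpenAI-compatible completion body: `inputs` becomes `prompt`, `parameters` is flattened to the top level, and `max_new_tokens` is renamed to `max_tokens`. A minimal sketch of the resulting request body, assuming the HF model ID has already been resolved to the provider-side ID used in the tests further below:

// Illustrative sketch only, not part of the patch. Input:
//   textGeneration({ inputs: "A B C ", parameters: { max_new_tokens: 1, seed: 42, temperature: 0, top_p: 0.01 } })
// Approximate body sent to the OVHcloud endpoint:
const body = {
	model: "Llama-3.1-8B-Instruct", // provider-side ID, taken from the test mapping below
	max_tokens: 1, // renamed from parameters.max_new_tokens
	seed: 42,
	temperature: 0,
	top_p: 0.01,
	prompt: "A B C ",
};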
diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts
index e5870f6ef3..bab3eb81aa 100644
--- a/packages/inference/src/types.ts
+++ b/packages/inference/src/types.ts
@@ -51,6 +51,7 @@ export const INFERENCE_PROVIDERS = [
 	"novita",
 	"nscale",
 	"openai",
+	"ovhcloud",
 	"replicate",
 	"sambanova",
 	"together",
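With `"ovhcloud"` added to `INFERENCE_PROVIDERS`, callers can select the provider by name. A minimal usage sketch (illustrative, not part of the patch; the token env var is a placeholder and the model ID mirrors the tests below):

import { HfInference } from "@huggingface/inference";

// Usage sketch: route a chat completion through OVHcloud by provider name.
const client = new HfInference(process.env.HF_TOKEN);
const res = await client.chatCompletion({
	model: "meta-llama/llama-3.1-8b-instruct",
	provider: "ovhcloud", // now a valid provider name
	messages: [{ role: "user", content: "A, B, C, " }],
	max_tokens: 1,
});
console.log(res.choices[0].message?.content); // expected to continue the sequence, e.g. "D"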
"dummy"); + + HARDCODED_MODEL_INFERENCE_MAPPING["ovhcloud"] = { + "meta-llama/llama-3.1-8b-instruct": { + hfModelId: "meta-llama/llama-3.1-8b-instruct", + providerId: "Llama-3.1-8B-Instruct", + status: "live", + task: "conversational", + }, + }; + + it("chatCompletion", async () => { + const res = await client.chatCompletion({ + model: "meta-llama/llama-3.1-8b-instruct", + provider: "ovhcloud", + messages: [{ role: "user", content: "A, B, C, " }], + seed: 42, + temperature: 0, + top_p: 0.01, + max_tokens: 1, + }); + expect(res.choices && res.choices.length > 0); + const completion = res.choices[0].message?.content; + expect(completion).toContain("D"); + }); + + it("chatCompletion stream", async () => { + const stream = client.chatCompletionStream({ + model: "meta-llama/llama-3.1-8b-instruct", + provider: "ovhcloud", + messages: [{ role: "user", content: "A, B, C, " }], + stream: true, + seed: 42, + temperature: 0, + top_p: 0.01, + max_tokens: 1, + }) as AsyncGenerator; + + let fullResponse = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + const content = chunk.choices[0].delta?.content; + if (content) { + fullResponse += content; + } + } + } + + // Verify we got a meaningful response + expect(fullResponse).toBeTruthy(); + expect(fullResponse).toContain("D"); + }); + + it("textGeneration", async () => { + const res = await client.textGeneration({ + model: "meta-llama/llama-3.1-8b-instruct", + provider: "ovhcloud", + inputs: "A B C ", + parameters: { + seed: 42, + temperature: 0, + top_p: 0.01, + max_new_tokens: 1, + }, + }); + expect(res.generated_text.length > 0); + expect(res.generated_text).toContain("D"); + }); + + it("textGeneration stream", async () => { + const stream = client.textGenerationStream({ + model: "meta-llama/llama-3.1-8b-instruct", + provider: "ovhcloud", + inputs: "A B C ", + stream: true, + parameters: { + seed: 42, + temperature: 0, + top_p: 0.01, + max_new_tokens: 1, + }, + }) as AsyncGenerator; + + let fullResponse = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + const content = chunk.choices[0].text; + if (content) { + fullResponse += content; + } + } + } + + // Verify we got a meaningful response + expect(fullResponse).toBeTruthy(); + expect(fullResponse).toContain("D"); + }); + }, + TIMEOUT + ); }); diff --git a/packages/tasks/src/inference-providers.ts b/packages/tasks/src/inference-providers.ts index 0a5ae18099..ee08d12943 100644 --- a/packages/tasks/src/inference-providers.ts +++ b/packages/tasks/src/inference-providers.ts @@ -7,6 +7,7 @@ const INFERENCE_PROVIDERS = [ "fireworks-ai", "hf-inference", "hyperbolic", + "ovhcloud", "replicate", "sambanova", "together",