From 39615a6d5f5cd23437f46783ec9c8cb99398b6e9 Mon Sep 17 00:00:00 2001
From: Fabien Ric
Date: Fri, 21 Mar 2025 11:42:02 +0100
Subject: [PATCH 1/4] add ovhcloud inference provider

---
 packages/inference/README.md                 |  2 +
 .../inference/src/lib/makeRequestOptions.ts  |  2 +
 packages/inference/src/providers/consts.ts   |  1 +
 packages/inference/src/providers/ovhcloud.ts | 41 ++++++++++++++
 packages/inference/src/types.ts              |  1 +
 .../inference/test/InferenceClient.spec.ts   | 56 +++++++++++++++++++
 packages/tasks/src/inference-providers.ts    |  1 +
 7 files changed, 104 insertions(+)
 create mode 100644 packages/inference/src/providers/ovhcloud.ts

diff --git a/packages/inference/README.md b/packages/inference/README.md
index c4a2ea4b1d..1ce051cff5 100644
--- a/packages/inference/README.md
+++ b/packages/inference/README.md
@@ -58,6 +58,7 @@ Currently, we support the following providers:
 - [Blackforestlabs](https://blackforestlabs.ai)
 - [Cohere](https://cohere.com)
 - [Cerebras](https://cerebras.ai/)
+- [OVHcloud](https://endpoints.ai.cloud.ovh.net/)
 
 To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
 ```ts
@@ -84,6 +85,7 @@ Only a subset of models are supported when requesting third-party providers. You
 - [Together supported models](https://huggingface.co/api/partners/together/models)
 - [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
 - [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
+- [OVHcloud supported models](https://huggingface.co/api/partners/ovhcloud/models)
 - [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)
 
 ❗**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.
diff --git a/packages/inference/src/lib/makeRequestOptions.ts b/packages/inference/src/lib/makeRequestOptions.ts
index 2da705c677..d747e51795 100644
--- a/packages/inference/src/lib/makeRequestOptions.ts
+++ b/packages/inference/src/lib/makeRequestOptions.ts
@@ -12,6 +12,7 @@ import { REPLICATE_CONFIG } from "../providers/replicate";
 import { SAMBANOVA_CONFIG } from "../providers/sambanova";
 import { TOGETHER_CONFIG } from "../providers/together";
 import { OPENAI_CONFIG } from "../providers/openai";
+import { OVHCLOUD_CONFIG } from "../providers/ovhcloud";
 import type { InferenceProvider, InferenceTask, Options, ProviderConfig, RequestArgs } from "../types";
 import { isUrl } from "./isUrl";
 import { version as packageVersion, name as packageName } from "../../package.json";
@@ -42,6 +43,7 @@ const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
 	replicate: REPLICATE_CONFIG,
 	sambanova: SAMBANOVA_CONFIG,
 	together: TOGETHER_CONFIG,
+	ovhcloud: OVHCLOUD_CONFIG,
 };
 
 /**
diff --git a/packages/inference/src/providers/consts.ts b/packages/inference/src/providers/consts.ts
index 184a5d2425..73ed2f6406 100644
--- a/packages/inference/src/providers/consts.ts
+++ b/packages/inference/src/providers/consts.ts
@@ -26,6 +26,7 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelId, string>> = {
 	nebius: {},
 	novita: {},
 	openai: {},
+	ovhcloud: {},
 	replicate: {},
 	sambanova: {},
 	together: {},
diff --git a/packages/inference/src/providers/ovhcloud.ts b/packages/inference/src/providers/ovhcloud.ts
new file mode 100644
index 0000000000..1e62fea950
--- /dev/null
+++ b/packages/inference/src/providers/ovhcloud.ts
@@ -0,0 +1,41 @@
+/**
+ * See the registered mapping of HF model ID => OVHcloud model ID here:
+ *
+ * https://huggingface.co/api/partners/ovhcloud/models
+ *
+ * This is a publicly available mapping.
+ *
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
+ *
+ * - If you work at OVHcloud and want to update this mapping, please use the model mapping API we provide on huggingface.co
+ * - If you're a community member and want to add a new supported HF model to OVHcloud, please open an issue on the present repo
+ *   and we will tag OVHcloud team members.
+ *
+ * Thanks!
+ */
+import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
+
+const OVHCLOUD_API_BASE_URL = "https://oai.endpoints.kepler.ai.cloud.ovh.net";
+
+const makeBody = (params: BodyParams): Record<string, unknown> => {
+	return {
+		...params.args,
+		model: params.model,
+	};
+};
+
+const makeHeaders = (params: HeaderParams): Record<string, string> => {
+	return { Authorization: `Bearer ${params.accessToken}` };
+};
+
+const makeUrl = (params: UrlParams): string => {
+	return `${params.baseUrl}/v1/chat/completions`;
+};
+
+export const OVHCLOUD_CONFIG: ProviderConfig = {
+	baseUrl: OVHCLOUD_API_BASE_URL,
+	makeBody,
+	makeHeaders,
+	makeUrl,
+};
diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts
index 07f597d93f..be5db8ac08 100644
--- a/packages/inference/src/types.ts
+++ b/packages/inference/src/types.ts
@@ -39,6 +39,7 @@ export const INFERENCE_PROVIDERS = [
 	"nebius",
 	"novita",
 	"openai",
+	"ovhcloud",
 	"replicate",
 	"sambanova",
 	"together",
diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts
index 0d16f9d1d8..9a4786afcf 100644
--- a/packages/inference/test/InferenceClient.spec.ts
+++ b/packages/inference/test/InferenceClient.spec.ts
@@ -1466,4 +1466,60 @@ describe.concurrent("InferenceClient", () => {
 		},
 		TIMEOUT
 	);
+	describe.concurrent(
+		"OVHcloud",
+		() => {
+			const client = new HfInference(env.HF_OVHCLOUD_KEY ?? "dummy");
+
+			HARDCODED_MODEL_ID_MAPPING["ovhcloud"] = {
+				"meta-llama/llama-3.1-8b-instruct": "Meta-Llama-3-8B-Instruct",
+			};
+
+			it("chatCompletion", async () => {
+				const res = await client.chatCompletion({
+					model: "meta-llama/llama-3.1-8b-instruct",
+					provider: "ovhcloud",
+					messages: [{ role: "user", content: "Complete the sequence: A, B, " }],
+					parameters: {
+						temperature: 0,
+						top_p: 0.01,
+						max_new_tokens: 1,
+					},
+				});
+				if (res.choices && res.choices.length > 0) {
+					const completion = res.choices[0].message?.content;
+					expect(completion).toContain("C");
+				}
+			});
+
+			it("chatCompletion stream", async () => {
+				const stream = client.chatCompletionStream({
+					model: "meta-llama/llama-3.1-8b-instruct",
+					provider: "ovhcloud",
+					messages: [{ role: "user", content: "Complete the sequence: A, B, " }],
+					stream: true,
+					parameters: {
+						temperature: 0,
+						top_p: 0.01,
+						max_new_tokens: 1,
+					},
+				}) as AsyncGenerator<ChatCompletionStreamOutput>;
+
+				let fullResponse = "";
+				for await (const chunk of stream) {
+					if (chunk.choices && chunk.choices.length > 0) {
+						const content = chunk.choices[0].delta?.content;
+						if (content) {
+							fullResponse += content;
+						}
+					}
+				}
+
+				// Verify we got a meaningful response
+				expect(fullResponse).toBeTruthy();
+				expect(fullResponse).toContain("C");
+			});
+		},
+		TIMEOUT
+	);
 });
diff --git a/packages/tasks/src/inference-providers.ts b/packages/tasks/src/inference-providers.ts
index 0a5ae18099..ee08d12943 100644
--- a/packages/tasks/src/inference-providers.ts
+++ b/packages/tasks/src/inference-providers.ts
@@ -7,6 +7,7 @@ const INFERENCE_PROVIDERS = [
 	"fireworks-ai",
 	"hf-inference",
 	"hyperbolic",
+	"ovhcloud",
 	"replicate",
 	"sambanova",
 	"together",

From 85441ee661e2e672e0cf3622974773389e09d635 Mon Sep 17 00:00:00 2001
From: Fabien Ric
Date: Tue, 22 Apr 2025 16:09:35 +0200
Subject: [PATCH 2/4] - merge main into branch - fix tests

---
 packages/inference/test/InferenceClient.spec.ts | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts
index 186e36af7f..0f6224d005 100644
--- a/packages/inference/test/InferenceClient.spec.ts
+++ b/packages/inference/test/InferenceClient.spec.ts
@@ -1,6 +1,6 @@
 import { assert, describe, expect, it } from "vitest";
 
-import type { ChatCompletionStreamOutput, TextGenerationStreamOutput } from "@huggingface/tasks";
+import type { ChatCompletionStreamOutput } from "@huggingface/tasks";
 import type { TextToImageArgs } from "../src";
 
 import {
@@ -1696,8 +1696,13 @@ describe.concurrent("InferenceClient", () => {
 		() => {
 			const client = new HfInference(env.HF_OVHCLOUD_KEY ?? "dummy");
 
-			HARDCODED_MODEL_ID_MAPPING["ovhcloud"] = {
-				"meta-llama/llama-3.1-8b-instruct": "Llama-3.1-8B-Instruct",
+			HARDCODED_MODEL_INFERENCE_MAPPING["ovhcloud"] = {
+				"meta-llama/llama-3.1-8b-instruct": {
+					hfModelId: "meta-llama/llama-3.1-8b-instruct",
+					providerId: "Llama-3.1-8B-Instruct",
+					status: "live",
+					task: "conversational",
+				},
 			};
 
 			it("chatCompletion", async () => {

From 1212dc8f6c9fc341cc82f2a40dd68fae355c6fc9 Mon Sep 17 00:00:00 2001
From: Fabien Ric
Date: Tue, 22 Apr 2025 16:16:02 +0200
Subject: [PATCH 3/4] fix unused import

---
 packages/inference/src/providers/ovhcloud.ts | 1 -
 1 file changed, 1 deletion(-)

diff --git a/packages/inference/src/providers/ovhcloud.ts b/packages/inference/src/providers/ovhcloud.ts
index 1e62fea950..c6a5644f12 100644
--- a/packages/inference/src/providers/ovhcloud.ts
+++ b/packages/inference/src/providers/ovhcloud.ts
@@ -18,7 +18,6 @@
 import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper";
 import type {
 	ChatCompletionOutput,
-	ChatCompletionStreamOutput,
 	TextGenerationOutput,
 	TextGenerationOutputFinishReason,
 } from "@huggingface/tasks";

From 31c666fcd7161cc4212b3210d0ee81ae39c0685f Mon Sep 17 00:00:00 2001
From: Fabien Ric
Date: Mon, 28 Apr 2025 10:44:02 +0200
Subject: [PATCH 4/4] fix chatcompletion payload

---
 packages/inference/src/providers/ovhcloud.ts    | 35 +++++++------------
 .../inference/test/InferenceClient.spec.ts      | 20 +++++------
 2 files changed, 21 insertions(+), 34 deletions(-)

diff --git a/packages/inference/src/providers/ovhcloud.ts b/packages/inference/src/providers/ovhcloud.ts
index c6a5644f12..5d886e1010 100644
--- a/packages/inference/src/providers/ovhcloud.ts
+++ b/packages/inference/src/providers/ovhcloud.ts
@@ -24,23 +24,10 @@ import type {
 import { InferenceOutputError } from "../lib/InferenceOutputError";
 import type { BodyParams } from "../types";
 import { omit } from "../utils/omit";
+import type { TextGenerationInput } from "@huggingface/tasks";
 
 const OVHCLOUD_API_BASE_URL = "https://oai.endpoints.kepler.ai.cloud.ovh.net";
 
-function prepareBaseOvhCloudPayload(params: BodyParams): Record<string, unknown> {
-	return {
-		model: params.model,
-		...omit(params.args, ["inputs", "parameters"]),
-		...(params.args.parameters ?
-			{
-				max_tokens: (params.args.parameters as Record<string, unknown>).max_new_tokens,
-				...omit(params.args.parameters as Record<string, unknown>, "max_new_tokens"),
-			}
-			: undefined),
-		prompt: params.args.inputs,
-	};
-}
-
 interface OvhCloudTextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
 	choices: Array<{
 		text: string;
 		finish_reason: TextGenerationOutputFinishReason;
 	}>;
 }
@@ -54,10 +41,6 @@ export class OvhCloudConversationalTask extends BaseConversationalTask {
 	constructor() {
 		super("ovhcloud", OVHCLOUD_API_BASE_URL);
 	}
-
-	override preparePayload(params: BodyParams): Record<string, unknown> {
-		return prepareBaseOvhCloudPayload(params);
-	}
 }
 
 export class OvhCloudTextGenerationTask extends BaseTextGenerationTask {
@@ -65,10 +48,18 @@ export class OvhCloudTextGenerationTask extends BaseTextGenerationTask {
 		super("ovhcloud", OVHCLOUD_API_BASE_URL);
 	}
 
-	override preparePayload(params: BodyParams): Record<string, unknown> {
-		const payload = prepareBaseOvhCloudPayload(params);
-		payload.prompt = params.args.inputs;
-		return payload;
-	}
+	override preparePayload(params: BodyParams<TextGenerationInput>): Record<string, unknown> {
+		return {
+			model: params.model,
+			...omit(params.args, ["inputs", "parameters"]),
+			...(params.args.parameters
+				? {
+						max_tokens: (params.args.parameters as Record<string, unknown>).max_new_tokens,
+						...omit(params.args.parameters as Record<string, unknown>, "max_new_tokens"),
+				  }
+				: undefined),
+			prompt: params.args.inputs,
+		};
+	}
 
 	override async getResponse(response: OvhCloudTextCompletionOutput): Promise<TextGenerationOutput> {
diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts
index 0f6224d005..25d0979ab4 100644
--- a/packages/inference/test/InferenceClient.spec.ts
+++ b/packages/inference/test/InferenceClient.spec.ts
@@ -1710,12 +1710,10 @@ describe.concurrent("InferenceClient", () => {
 					model: "meta-llama/llama-3.1-8b-instruct",
 					provider: "ovhcloud",
 					messages: [{ role: "user", content: "A, B, C, " }],
-					parameters: {
-						seed: 42,
-						temperature: 0,
-						top_p: 0.01,
-						max_new_tokens: 1,
-					},
+					seed: 42,
+					temperature: 0,
+					top_p: 0.01,
+					max_tokens: 1,
 				});
 				expect(res.choices && res.choices.length > 0);
 				const completion = res.choices[0].message?.content;
@@ -1728,12 +1726,10 @@ describe.concurrent("InferenceClient", () => {
 					provider: "ovhcloud",
 					messages: [{ role: "user", content: "A, B, C, " }],
 					stream: true,
-					parameters: {
-						seed: 42,
-						temperature: 0,
-						top_p: 0.01,
-						max_new_tokens: 1,
-					},
+					seed: 42,
+					temperature: 0,
+					top_p: 0.01,
+					max_tokens: 1,
 				}) as AsyncGenerator<ChatCompletionStreamOutput>;
 
 				let fullResponse = "";