diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e1e0636651..e178603763 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -42,6 +42,7 @@ jobs: env: HF_TOKEN: ${{ secrets.HF_TOKEN }} HF_FAL_KEY: dummy + HF_NEBIUS_KEY: dummy HF_REPLICATE_KEY: dummy HF_SAMBANOVA_KEY: dummy HF_TOGETHER_KEY: dummy @@ -83,6 +84,7 @@ jobs: env: HF_TOKEN: ${{ secrets.HF_TOKEN }} HF_FAL_KEY: dummy + HF_NEBIUS_KEY: dummy HF_REPLICATE_KEY: dummy HF_SAMBANOVA_KEY: dummy HF_TOGETHER_KEY: dummy @@ -151,6 +153,7 @@ jobs: NPM_CONFIG_REGISTRY: http://localhost:4874/ HF_TOKEN: ${{ secrets.HF_TOKEN }} HF_FAL_KEY: dummy + HF_NEBIUS_KEY: dummy HF_REPLICATE_KEY: dummy HF_SAMBANOVA_KEY: dummy HF_TOGETHER_KEY: dummy diff --git a/packages/inference/README.md b/packages/inference/README.md index d1d79416f6..304920cea9 100644 --- a/packages/inference/README.md +++ b/packages/inference/README.md @@ -49,6 +49,7 @@ You can send inference requests to third-party providers with the inference clie Currently, we support the following providers: - [Fal.ai](https://fal.ai) - [Fireworks AI](https://fireworks.ai) +- [Nebius](https://studio.nebius.ai) - [Replicate](https://replicate.com) - [Sambanova](https://sambanova.ai) - [Together](https://together.xyz) @@ -71,12 +72,13 @@ When authenticated with a third-party provider key, the request is made directly Only a subset of models are supported when requesting third-party providers. You can check the list of supported models per pipeline tasks here: - [Fal.ai supported models](https://huggingface.co/api/partners/fal-ai/models) - [Fireworks AI supported models](https://huggingface.co/api/partners/fireworks-ai/models) +- [Nebius supported models](https://huggingface.co/api/partners/nebius/models) - [Replicate supported models](https://huggingface.co/api/partners/replicate/models) - [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models) - [Together supported models](https://huggingface.co/api/partners/together/models) - [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending) -❗**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type. +❗**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type. This is not an issue for LLMs as everyone converged on the OpenAI API anyways, but can be more tricky for other tasks like "text-to-image" or "automatic-speech-recognition" where there exists no standard API. Let us know if any help is needed or if we can make things easier for you! 👋**Want to add another provider?** Get in touch if you'd like to add support for another Inference provider, and/or request it on https://huggingface.co/spaces/huggingface/HuggingDiscussions/discussions/49 @@ -463,7 +465,7 @@ await hf.zeroShotImageClassification({ model: 'openai/clip-vit-large-patch14-336', inputs: { image: await (await fetch('https://placekitten.com/300/300')).blob() - }, + }, parameters: { candidate_labels: ['cat', 'dog'] } diff --git a/packages/inference/src/lib/makeRequestOptions.ts b/packages/inference/src/lib/makeRequestOptions.ts index b68dfa09fd..f31134ce11 100644 --- a/packages/inference/src/lib/makeRequestOptions.ts +++ b/packages/inference/src/lib/makeRequestOptions.ts @@ -1,5 +1,6 @@ import { HF_HUB_URL, HF_ROUTER_URL } from "../config"; import { FAL_AI_API_BASE_URL } from "../providers/fal-ai"; +import { NEBIUS_API_BASE_URL } from "../providers/nebius"; import { REPLICATE_API_BASE_URL } from "../providers/replicate"; import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova"; import { TOGETHER_API_BASE_URL } from "../providers/together"; @@ -143,7 +144,7 @@ export async function makeRequestOptions( ? args.data : JSON.stringify({ ...otherArgs, - ...(chatCompletion || provider === "together" ? { model } : undefined), + ...(chatCompletion || provider === "together" || provider === "nebius" ? { model } : undefined), }), ...(credentials ? { credentials } : undefined), signal: options?.signal, @@ -172,6 +173,22 @@ function makeUrl(params: { : FAL_AI_API_BASE_URL; return `${baseUrl}/${params.model}`; } + case "nebius": { + const baseUrl = shouldProxy + ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) + : NEBIUS_API_BASE_URL; + + if (params.taskHint === "text-to-image") { + return `${baseUrl}/v1/images/generations`; + } + if (params.taskHint === "text-generation") { + if (params.chatCompletion) { + return `${baseUrl}/v1/chat/completions`; + } + return `${baseUrl}/v1/completions`; + } + return baseUrl; + } case "replicate": { const baseUrl = shouldProxy ? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider) diff --git a/packages/inference/src/providers/consts.ts b/packages/inference/src/providers/consts.ts index 6f0112366a..e26e57c843 100644 --- a/packages/inference/src/providers/consts.ts +++ b/packages/inference/src/providers/consts.ts @@ -19,6 +19,7 @@ export const HARDCODED_MODEL_ID_MAPPING: Record Nebius model ID here: + * + * https://huggingface.co/api/partners/nebius/models + * + * This is a publicly available mapping. + * + * If you want to try to run inference for a new model locally before it's registered on huggingface.co, + * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes. + * + * - If you work at Nebius and want to update this mapping, please use the model mapping API we provide on huggingface.co + * - If you're a community member and want to add a new supported HF model to Nebius, please open an issue on the present repo + * and we will tag Nebius team members. + * + * Thanks! + */ diff --git a/packages/inference/src/tasks/cv/textToImage.ts b/packages/inference/src/tasks/cv/textToImage.ts index c8e2d3cbf9..b678b0ee20 100644 --- a/packages/inference/src/tasks/cv/textToImage.ts +++ b/packages/inference/src/tasks/cv/textToImage.ts @@ -21,11 +21,15 @@ interface OutputUrlImageGeneration { */ export async function textToImage(args: TextToImageArgs, options?: Options): Promise { const payload = - args.provider === "together" || args.provider === "fal-ai" || args.provider === "replicate" + args.provider === "together" || + args.provider === "fal-ai" || + args.provider === "replicate" || + args.provider === "nebius" ? { ...omit(args, ["inputs", "parameters"]), ...args.parameters, ...(args.provider !== "replicate" ? { response_format: "base64" } : undefined), + ...(args.provider === "nebius" ? { response_format: "b64_json" } : undefined), prompt: args.inputs, } : args; diff --git a/packages/inference/src/tasks/nlp/chatCompletion.ts b/packages/inference/src/tasks/nlp/chatCompletion.ts index 740362095f..8d73b40075 100644 --- a/packages/inference/src/tasks/nlp/chatCompletion.ts +++ b/packages/inference/src/tasks/nlp/chatCompletion.ts @@ -15,14 +15,17 @@ export async function chatCompletion( taskHint: "text-generation", chatCompletion: true, }); + const isValidOutput = typeof res === "object" && Array.isArray(res?.choices) && typeof res?.created === "number" && typeof res?.id === "string" && typeof res?.model === "string" && - /// Together.ai does not output a system_fingerprint - (res.system_fingerprint === undefined || typeof res.system_fingerprint === "string") && + /// Together.ai and Nebius do not output a system_fingerprint + (res.system_fingerprint === undefined || + res.system_fingerprint === null || + typeof res.system_fingerprint === "string") && typeof res?.usage === "object"; if (!isValidOutput) { diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts index 52401c9ae2..d9497874c7 100644 --- a/packages/inference/src/types.ts +++ b/packages/inference/src/types.ts @@ -47,6 +47,7 @@ export type InferenceTask = Exclude; export const INFERENCE_PROVIDERS = [ "fal-ai", "fireworks-ai", + "nebius", "hf-inference", "replicate", "sambanova", diff --git a/packages/inference/test/HfInference.spec.ts b/packages/inference/test/HfInference.spec.ts index d23128143b..4dbc390d6c 100644 --- a/packages/inference/test/HfInference.spec.ts +++ b/packages/inference/test/HfInference.spec.ts @@ -1064,6 +1064,56 @@ describe.concurrent("HfInference", () => { TIMEOUT ); + describe.concurrent( + "Nebius", + () => { + const client = new HfInference(env.HF_NEBIUS_KEY); + + HARDCODED_MODEL_ID_MAPPING.nebius = { + "meta-llama/Llama-3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "meta-llama/Llama-3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct", + "black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell", + }; + + it("chatCompletion", async () => { + const res = await client.chatCompletion({ + model: "meta-llama/Llama-3.1-8B-Instruct", + provider: "nebius", + messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], + }); + if (res.choices && res.choices.length > 0) { + const completion = res.choices[0].message?.content; + expect(completion).toMatch(/(two|2)/i); + } + }); + + it("chatCompletion stream", async () => { + const stream = client.chatCompletionStream({ + model: "meta-llama/Llama-3.1-70B-Instruct", + provider: "nebius", + messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }], + }) as AsyncGenerator; + let out = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + out += chunk.choices[0].delta.content; + } + } + expect(out).toMatch(/(two|2)/i); + }); + + it("textToImage", async () => { + const res = await client.textToImage({ + model: "black-forest-labs/FLUX.1-schnell", + provider: "nebius", + inputs: "award winning high resolution photo of a giant tortoise", + }); + expect(res).toBeInstanceOf(Blob); + }); + }, + TIMEOUT + ); + describe.concurrent("3rd party providers", () => { it("chatCompletion - fails with unsupported model", async () => { expect( diff --git a/packages/inference/test/tapes.json b/packages/inference/test/tapes.json index 169059e291..22d9e06f55 100644 --- a/packages/inference/test/tapes.json +++ b/packages/inference/test/tapes.json @@ -6920,5 +6920,78 @@ "vary": "Accept-Encoding" } } + }, + "90dc791157e9ec8ed109eaf07946d878e9208ed6eee79af8dd52a56ef7d40371": { + "url": "https://api.studio.nebius.ai/v1/chat/completions", + "init": { + "headers": { + "Content-Type": "application/json" + }, + "method": "POST", + "body": "{\"messages\":[{\"role\":\"user\",\"content\":\"Complete this sentence with words, one plus one is equal \"}],\"model\":\"meta-llama/Meta-Llama-3.1-8B-Instruct\"}" + }, + "response": { + "body": "{\"id\":\"chatcmpl-89392f51529b4d1c82c3d58b210735c5\",\"choices\":[{\"finish_reason\":\"stop\",\"index\":0,\"logprobs\":null,\"message\":{\"content\":\"Two\",\"refusal\":null,\"role\":\"assistant\",\"audio\":null,\"function_call\":null,\"tool_calls\":[]},\"stop_reason\":null}],\"created\":1738968734,\"model\":\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\"object\":\"chat.completion\",\"service_tier\":null,\"system_fingerprint\":null,\"usage\":{\"completion_tokens\":2,\"prompt_tokens\":21,\"total_tokens\":23,\"completion_tokens_details\":null,\"prompt_tokens_details\":null},\"prompt_logprobs\":null}", + "status": 200, + "statusText": "OK", + "headers": { + "connection": "keep-alive", + "content-type": "application/json", + "strict-transport-security": "max-age=15724800; includeSubDomains" + } + } + }, + "2b75bf387ea5775a8172608df8a1bf7d652b1c5e10f0263e39456ec56e20eedf": { + "url": "https://api.studio.nebius.ai/v1/chat/completions", + "init": { + "headers": { + "Content-Type": "application/json" + }, + "method": "POST", + "body": "{\"messages\":[{\"role\":\"user\",\"content\":\"Complete the equation 1 + 1 = , just the answer\"}],\"stream\":true,\"model\":\"meta-llama/Meta-Llama-3.1-70B-Instruct\"}" + }, + "response": { + "body": "data: {\"id\":\"chatcmpl-2ef2941399e74965a2705410b42b4a3f\",\"choices\":[{\"delta\":{\"content\":\"\",\"role\":\"assistant\"},\"finish_reason\":null,\"index\":0,\"logprobs\":null}],\"created\":1738968734,\"model\":\"meta-llama/Meta-Llama-3.1-70B-Instruct\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-2ef2941399e74965a2705410b42b4a3f\",\"choices\":[{\"delta\":{\"content\":\"2\"},\"finish_reason\":null,\"index\":0,\"logprobs\":null}],\"created\":1738968734,\"model\":\"meta-llama/Meta-Llama-3.1-70B-Instruct\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-2ef2941399e74965a2705410b42b4a3f\",\"choices\":[{\"delta\":{\"content\":\".\"},\"finish_reason\":null,\"index\":0,\"logprobs\":null}],\"created\":1738968734,\"model\":\"meta-llama/Meta-Llama-3.1-70B-Instruct\",\"object\":\"chat.completion.chunk\"}\n\ndata: {\"id\":\"chatcmpl-2ef2941399e74965a2705410b42b4a3f\",\"choices\":[{\"delta\":{\"content\":\"\"},\"finish_reason\":\"stop\",\"index\":0,\"logprobs\":null,\"stop_reason\":null}],\"created\":1738968734,\"model\":\"meta-llama/Meta-Llama-3.1-70B-Instruct\",\"object\":\"chat.completion.chunk\"}\n\ndata: [DONE]\n\n", + "status": 200, + "statusText": "OK", + "headers": { + "connection": "keep-alive", + "content-type": "text/event-stream; charset=utf-8", + "strict-transport-security": "max-age=15724800; includeSubDomains", + "transfer-encoding": "chunked" + } + } + }, + "b4345ef6e7eb30328b2bf84508cdfc172ecb697d8506c51cc7e447adc7323658": { + "url": "https://api.studio.nebius.ai/v1/images/generations", + "init": { + "headers": { + "Content-Type": "application/json" + }, + "method": "POST", + "body": "{\"response_format\":\"b64_json\",\"prompt\":\"award winning high resolution photo of a giant tortoise\",\"model\":\"black-forest-labs/flux-schnell\"}" + }, + "response": { + "body": "{\"data\":[{\"b64_json\":\"UklGRujgAQBXRUJQVlA4INzgAQAQegOdASoAAgACPhkIg0EhBgtzgAQAYSli6wABWWhZqn6p9Jvqf8h+z35S/JFxn1Heo/uH+X/4X+F/cz5b/+7/Ofkd2G9bf+n/PejR0Z/4v8l/r/26+YH+c/9H+m/2nwh/RH/d/y/79/QP+qv/G/un+g/b36Df8n9z/eR/dv+P+VPwW/rX+t/+X+y/4P//+Yj/tfub7uv8F/0/3K+An+1f77//f8TtPf309hD97/V6/8/7wf9P5YP7R/0P3K/7fyNf1L/R//b/Xf8H4AP//7fPMjwf+OH5r8lPM38Z+Yft39u/yf+W/t//2/2f2FfcX+P/ivDf6j/N/7n/J/5/2D/jv3C/Ff3b/Lf8X+9fu98l/6j/FfuT/jfUX8t/af8//i/8//x/8Z+7n2F/i38o/t39q/yv+j/uP7f/RB8j/tf8//qP/d/rfQt1r/D/6T/If4//j/v/9CPrj88/vv9w/yH+2/u/7n+3v/if5b/Qf+X3s/Uv8V/uf81+Rn2Bfyv+gf6D+7/5H/r/4n///+D7F/1H/U/1P+o/7nqWfav8z/1v87/q/2k+wH+Yf1L/Uf3z/L/+n/J////+/it/Mf8b/I/6H/xf5j///+j4s/nv9z/3H+K/zv/i/yv///9f6BfyD+d/4/+4/5H/n/4H///+P7pv+l+fPzu/Zf/ofnr9F/6r/7P86f3//9ifCC7lwK/mrcqzl7KswpK6fOxHeXWMQ66NvfdMdbfPjryRxBoK7Cl7y7JAFFu+4OPy3T0UhWdetGanczWNDZ6hP3af/YD7nH0cZyq5Rr2pdXTLcRWW0tONLzyK/1DKAvE1aKmARq2bjActrCWJBr1gAZAvPvwGN4dsjAz3pJhhhc2fNHYEWcUa+pJs9szFbKPCJXckqc+4Kf8bf84MsgozZo6FC7tto7W2DY7Nk6+pZzpRA+1qY81hRqJMTXVSduE+HlovvQL0CDxW2x2qSkGNulpY\"}],\"id\":\"text2img-b743e941-3756-4fb8-aeca-2883ab029516\"}", + "status": 200, + "statusText": "OK", + "headers": { + "connection": "keep-alive", + "content-type": "application/json", + "strict-transport-security": "max-age=15724800; includeSubDomains" + } + } + }, + "76e27a0a58b167b19f3a059ab499955365e64ca8b816440ec321764b0f14fd98": { + "url": "", + "init": {}, + "response": { + "body": "", + "status": 200, + "statusText": "OK", + "headers": { + "content-type": "image/jpeg" + } + } } } \ No newline at end of file