Skip to content

Commit 6b1e2a4

Browse files
committed
add Nebius to the list of inference providers
1 parent f69c26c commit 6b1e2a4

File tree

8 files changed

+103
-7
lines changed

8 files changed

+103
-7
lines changed

.github/workflows/test.yml

Lines changed: 3 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -42,6 +42,7 @@ jobs:
4242
env:
4343
HF_TOKEN: ${{ secrets.HF_TOKEN }}
4444
HF_FAL_KEY: dummy
45+
HF_NEBIUS_KEY: dummy
4546
HF_REPLICATE_KEY: dummy
4647
HF_SAMBANOVA_KEY: dummy
4748
HF_TOGETHER_KEY: dummy
@@ -82,6 +83,7 @@ jobs:
8283
env:
8384
HF_TOKEN: ${{ secrets.HF_TOKEN }}
8485
HF_FAL_KEY: dummy
86+
HF_NEBIUS_KEY: dummy
8587
HF_REPLICATE_KEY: dummy
8688
HF_SAMBANOVA_KEY: dummy
8789
HF_TOGETHER_KEY: dummy
@@ -149,6 +151,7 @@ jobs:
149151
NPM_CONFIG_REGISTRY: http://localhost:4874/
150152
HF_TOKEN: ${{ secrets.HF_TOKEN }}
151153
HF_FAL_KEY: dummy
154+
HF_NEBIUS_KEY: dummy
152155
HF_REPLICATE_KEY: dummy
153156
HF_SAMBANOVA_KEY: dummy
154157
HF_TOGETHER_KEY: dummy

packages/inference/README.md

Lines changed: 9 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -46,7 +46,12 @@ Your access token should be kept private. If you need to protect it in front-end
4646

4747
You can send inference requests to third-party providers with the inference client.
4848

49-
Currently, we support the following providers: [Fal.ai](https://fal.ai), [Replicate](https://replicate.com), [Together](https://together.xyz) and [Sambanova](https://sambanova.ai).
49+
Currently, we support the following providers:
50+
- [Fal.ai](https://fal.ai)
51+
- [Nebius](https://studio.nebius.ai)
52+
- [Replicate](https://replicate.com)
53+
- [Sambanova](https://sambanova.ai)
54+
- [Together](https://together.xyz)
5055

5156
To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
5257
```ts
@@ -65,12 +70,13 @@ When authenticated with a third-party provider key, the request is made directly
6570

6671
Only a subset of models are supported when requesting third-party providers. You can check the list of supported models per pipeline tasks here:
6772
- [Fal.ai supported models](./src/providers/fal-ai.ts)
73+
- [Nebius supported models](./src/providers/nebius.ts)
6874
- [Replicate supported models](./src/providers/replicate.ts)
6975
- [Sambanova supported models](./src/providers/sambanova.ts)
7076
- [Together supported models](./src/providers/together.ts)
7177
- [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)
7278

73-
**Important note:** To be compatible, the third-party API must adhere to the "standard" API shape we expect on HF model pages for each pipeline task type. 
79+
**Important note:** To be compatible, the third-party API must adhere to the "standard" API shape we expect on HF model pages for each pipeline task type.
7480
This is not an issue for LLMs as everyone converged on the OpenAI API anyways, but can be more tricky for other tasks like "text-to-image" or "automatic-speech-recognition" where there exists no standard API. Let us know if any help is needed or if we can make things easier for you!
7581

7682
👋**Want to add another provider?** Get in touch if you'd like to add support for another Inference provider, and/or request it on https://huggingface.co/spaces/huggingface/HuggingDiscussions/discussions/49
@@ -457,7 +463,7 @@ await hf.zeroShotImageClassification({
457463
model: 'openai/clip-vit-large-patch14-336',
458464
inputs: {
459465
image: await (await fetch('https://placekitten.com/300/300')).blob()
460-
},
466+
},
461467
parameters: {
462468
candidate_labels: ['cat', 'dog']
463469
}

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 18 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -1,5 +1,6 @@
11
import { HF_HUB_URL, HF_ROUTER_URL } from "../config";
22
import { FAL_AI_API_BASE_URL } from "../providers/fal-ai";
3+
import { NEBIUS_API_BASE_URL } from "../providers/nebius";
34
import { REPLICATE_API_BASE_URL } from "../providers/replicate";
45
import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
56
import { TOGETHER_API_BASE_URL } from "../providers/together";
@@ -143,6 +144,7 @@ export async function makeRequestOptions(
143144
: JSON.stringify({
144145
...otherArgs,
145146
...(chatCompletion || provider === "together" ? { model } : undefined),
147+
...(provider === "nebius" ? { model } : {}),
146148
}),
147149
...(credentials ? { credentials } : undefined),
148150
signal: options?.signal,
@@ -171,6 +173,22 @@ function makeUrl(params: {
171173
: FAL_AI_API_BASE_URL;
172174
return `${baseUrl}/${params.model}`;
173175
}
176+
case "nebius": {
177+
const baseUrl = shouldProxy
178+
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
179+
: NEBIUS_API_BASE_URL;
180+
181+
if (params.taskHint === "text-to-image") {
182+
return `${baseUrl}/v1/images/generations`;
183+
}
184+
if (params.taskHint === "text-generation") {
185+
if (params.chatCompletion) {
186+
return `${baseUrl}/v1/chat/completions`;
187+
}
188+
return `${baseUrl}/v1/completions`;
189+
}
190+
return baseUrl;
191+
}
174192
case "replicate": {
175193
const baseUrl = shouldProxy
176194
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
Lines changed: 18 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,18 @@
1+
export const NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
2+
3+
/**
4+
* See the registered mapping of HF model ID => Nebius model ID here:
5+
*
6+
* https://huggingface.co/api/partners/nebius/models
7+
*
8+
* This is a publicly available mapping.
9+
*
10+
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
11+
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
12+
*
13+
* - If you work at Nebius and want to update this mapping, please use the model mapping API we provide on huggingface.co
14+
* - If you're a community member and want to add a new supported HF model to Nebius, please open an issue on the present repo
15+
* and we will tag Nebius team members.
16+
*
17+
* Thanks!
18+
*/

packages/inference/src/tasks/cv/textToImage.ts

Lines changed: 5 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -21,11 +21,15 @@ interface OutputUrlImageGeneration {
2121
*/
2222
export async function textToImage(args: TextToImageArgs, options?: Options): Promise<Blob> {
2323
const payload =
24-
args.provider === "together" || args.provider === "fal-ai" || args.provider === "replicate"
24+
args.provider === "together" ||
25+
args.provider === "fal-ai" ||
26+
args.provider === "replicate" ||
27+
args.provider === "nebius"
2528
? {
2629
...omit(args, ["inputs", "parameters"]),
2730
...args.parameters,
2831
...(args.provider !== "replicate" ? { response_format: "base64" } : undefined),
32+
...(args.provider === "nebius" ? { response_format: "b64_json" } : {}),
2933
prompt: args.inputs,
3034
}
3135
: args;

packages/inference/src/tasks/nlp/chatCompletion.ts

Lines changed: 5 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -15,14 +15,17 @@ export async function chatCompletion(
1515
taskHint: "text-generation",
1616
chatCompletion: true,
1717
});
18+
1819
const isValidOutput =
1920
typeof res === "object" &&
2021
Array.isArray(res?.choices) &&
2122
typeof res?.created === "number" &&
2223
typeof res?.id === "string" &&
2324
typeof res?.model === "string" &&
24-
/// Together.ai does not output a system_fingerprint
25-
(res.system_fingerprint === undefined || typeof res.system_fingerprint === "string") &&
25+
/// Together.ai and Nebius do not output a system_fingerprint
26+
(res.system_fingerprint === undefined ||
27+
res.system_fingerprint === null ||
28+
typeof res.system_fingerprint === "string") &&
2629
typeof res?.usage === "object";
2730

2831
if (!isValidOutput) {

packages/inference/src/types.ts

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -44,7 +44,7 @@ export interface Options {
4444

4545
export type InferenceTask = Exclude<PipelineType, "other">;
4646

47-
export const INFERENCE_PROVIDERS = ["fal-ai", "replicate", "sambanova", "together", "hf-inference"] as const;
47+
export const INFERENCE_PROVIDERS = ["fal-ai", "nebius", "replicate", "sambanova", "together", "hf-inference"] as const;
4848
export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
4949

5050
export interface BaseArgs {

packages/inference/test/HfInference.spec.ts

Lines changed: 44 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -1063,6 +1063,50 @@ describe.concurrent("HfInference", () => {
10631063
TIMEOUT
10641064
);
10651065

1066+
describe.concurrent(
1067+
"Nebius",
1068+
() => {
1069+
const client = new HfInference(env.NEBIUS_API_KEY);
1070+
1071+
it("chatCompletion", async () => {
1072+
const res = await client.chatCompletion({
1073+
model: "meta-llama/Meta-Llama-3.1-8B-Instruct",
1074+
provider: "nebius",
1075+
messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
1076+
});
1077+
if (res.choices && res.choices.length > 0) {
1078+
const completion = res.choices[0].message?.content;
1079+
expect(completion).toMatch(/(two|2)/i);
1080+
}
1081+
});
1082+
1083+
it("chatCompletion stream", async () => {
1084+
const stream = client.chatCompletionStream({
1085+
model: "meta-llama/Meta-Llama-3.1-70B-Instruct",
1086+
provider: "nebius",
1087+
messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }],
1088+
}) as AsyncGenerator<ChatCompletionStreamOutput>;
1089+
let out = "";
1090+
for await (const chunk of stream) {
1091+
if (chunk.choices && chunk.choices.length > 0) {
1092+
out += chunk.choices[0].delta.content;
1093+
}
1094+
}
1095+
expect(out).toMatch(/(two|2)/i);
1096+
});
1097+
1098+
it("textToImage", async () => {
1099+
const res = await client.textToImage({
1100+
model: "black-forest-labs/flux-schnell",
1101+
provider: "nebius",
1102+
inputs: "award winning high resolution photo of a giant tortoise",
1103+
});
1104+
expect(res).toBeInstanceOf(Blob);
1105+
});
1106+
},
1107+
TIMEOUT
1108+
);
1109+
10661110
describe.concurrent("3rd party providers", () => {
10671111
it("chatCompletion - fails with unsupported model", async () => {
10681112
expect(

0 commit comments

Comments (0)