Commit ace6668

Authored by Aktsvigun (Akim Tsvigun), hanouticelina, and Wauplin

Nebius AI Studio provider added (#2866)

* Nebius provider added
* nebius occurrence sorted alphabetically
* nebius text-to-image task fixed; tests for nebius provider added
* upload cassettes and update docs
* maintain alphabetical order
* fix merging
* height and width are not required
* Update docs/source/en/guides/inference.md

Co-authored-by: Akim Tsvigun <[email protected]>
Co-authored-by: Celina Hanouti <[email protected]>
Co-authored-by: Lucain <[email protected]>
1 parent 604b9ca commit ace6668

11 files changed: +533 −32 lines

docs/source/en/guides/inference.md

Lines changed: 30 additions & 30 deletions
```diff
@@ -248,36 +248,36 @@ You might wonder why using [`InferenceClient`] instead of OpenAI's client? There

 [`InferenceClient`]'s goal is to provide the easiest interface to run inference on Hugging Face models, on any provider. It has a simple API that supports the most common tasks. Here is a table showing which providers support which tasks:

-| Domain | Task | HF Inference | fal-ai | Fireworks AI | Hyperbolic | Novita AI | Replicate | Sambanova | Together |
+| Domain | Task | HF Inference | fal-ai | Fireworks AI | Hyperbolic | Nebius AI Studio | Novita AI | Replicate | Sambanova | Together |

 <Tip>
```

Every task row of the table (**Audio**: `audio_classification`, `audio_to_audio`, `automatic_speech_recognition`, `text_to_speech`; **Computer Vision**: `image_classification`, `image_segmentation`, `image_to_image`, `image_to_text`, `object_detection`, `text_to_image`, `text_to_video`, `zero_shot_image_classification`; **Multimodal**: `document_question_answering`, `visual_question_answering`; **NLP**: `chat_completion`, `feature_extraction`, `fill_mask`, `question_answering`, `sentence_similarity`, `summarization`, `table_question_answering`, `text_classification`, `text_generation`, `token_classification`, `translation`, `zero_shot_classification`; **Tabular**: `tabular_classification`, `tabular_regression`) is re-emitted with an added cell under the new Nebius AI Studio column; the per-provider support markers are not reproduced here.

src/huggingface_hub/inference/_client.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -132,7 +132,7 @@ class InferenceClient:
             path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
             documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
         provider (`str`, *optional*):
-            Name of the provider to use for inference. Can be `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"novita"`, `"replicate"`, `"sambanova"` or `"together"`.
+            Name of the provider to use for inference. Can be `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, `"sambanova"` or `"together"`.
             Defaults to `"hf-inference"` (Hugging Face Serverless Inference API).
             If `model` is a URL or `base_url` is passed, then `provider` is not used.
         token (`str` or `bool`, *optional*):
```
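The selection rules described in this docstring (the new `"nebius"` entry, the `"hf-inference"` default, and `provider` being ignored when `model` is a URL or `base_url` is set) can be sketched in isolation. This is a hypothetical stand-in, not the library's actual implementation; `resolve_provider` and `SUPPORTED_PROVIDERS` are illustrative names:

```python
# Hypothetical sketch of the provider-resolution rules from the docstring
# above; names here are illustrative, not huggingface_hub APIs.
SUPPORTED_PROVIDERS = {
    "fal-ai", "fireworks-ai", "hf-inference", "hyperbolic",
    "nebius", "novita", "replicate", "sambanova", "together",
}

def resolve_provider(provider=None, model=None, base_url=None):
    # A URL passed as `model`, or an explicit `base_url`, bypasses
    # provider selection entirely.
    if base_url is not None or (model or "").startswith(("http://", "https://")):
        return None
    provider = provider or "hf-inference"  # docstring default
    if provider not in SUPPORTED_PROVIDERS:
        raise ValueError(f"Unknown provider: {provider!r}")
    return provider
```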

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -120,7 +120,7 @@ class AsyncInferenceClient:
             path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
             documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
         provider (`str`, *optional*):
-            Name of the provider to use for inference. Can be `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"novita"`, `"replicate"`, `"sambanova"` or `"together"`.
+            Name of the provider to use for inference. Can be `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, `"sambanova"` or `"together"`.
             Defaults to `"hf-inference"` (Hugging Face Serverless Inference API).
             If `model` is a URL or `base_url` is passed, then `provider` is not used.
         token (`str` or `bool`, *optional*):
```

src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -10,6 +10,7 @@
 from .fireworks_ai import FireworksAIConversationalTask
 from .hf_inference import HFInferenceBinaryInputTask, HFInferenceConversational, HFInferenceTask
 from .hyperbolic import HyperbolicTextGenerationTask, HyperbolicTextToImageTask
+from .nebius import NebiusConversationalTask, NebiusTextGenerationTask, NebiusTextToImageTask
 from .novita import NovitaConversationalTask, NovitaTextGenerationTask
 from .replicate import ReplicateTask, ReplicateTextToSpeechTask
 from .sambanova import SambanovaConversationalTask
@@ -21,6 +22,7 @@
     "fireworks-ai",
     "hf-inference",
     "hyperbolic",
+    "nebius",
     "novita",
     "replicate",
     "sambanova",
@@ -70,6 +72,11 @@
         "conversational": HyperbolicTextGenerationTask("conversational"),
         "text-generation": HyperbolicTextGenerationTask("text-generation"),
     },
+    "nebius": {
+        "text-to-image": NebiusTextToImageTask(),
+        "conversational": NebiusConversationalTask(),
+        "text-generation": NebiusTextGenerationTask(),
+    },
     "novita": {
         "text-generation": NovitaTextGenerationTask(),
         "conversational": NovitaConversationalTask(),
```
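The registry touched above is a two-level mapping from provider name to per-task helper objects. A minimal stand-in (strings in place of the real task classes; `get_helper` is an illustrative name, not the library's) shows the lookup shape:

```python
# Minimal stand-in for the PROVIDERS registry: provider -> task -> helper.
# Strings replace the real task objects; get_helper is illustrative only.
PROVIDERS = {
    "nebius": {
        "text-to-image": "NebiusTextToImageTask",
        "conversational": "NebiusConversationalTask",
        "text-generation": "NebiusTextGenerationTask",
    },
    "novita": {
        "text-generation": "NovitaTextGenerationTask",
        "conversational": "NovitaConversationalTask",
    },
}

def get_helper(provider: str, task: str) -> str:
    tasks = PROVIDERS.get(provider)
    if tasks is None:
        raise ValueError(f"Provider {provider!r} not supported")
    helper = tasks.get(task)
    if helper is None:
        raise ValueError(f"Task {task!r} not supported by {provider!r}")
    return helper
```

Keeping the outer dict alphabetical is what the "nebius occurrence sorted alphabetically" commit message refers to.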

src/huggingface_hub/inference/_providers/_common.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -21,6 +21,7 @@
     "fireworks-ai": {},
     "hf-inference": {},
     "hyperbolic": {},
+    "nebius": {},
     "replicate": {},
     "sambanova": {},
     "together": {},
```
src/huggingface_hub/inference/_providers/nebius.py

Lines changed: 41 additions & 0 deletions

```diff
@@ -0,0 +1,41 @@
+import base64
+from typing import Any, Dict, Optional, Union
+
+from huggingface_hub.inference._common import _as_dict
+from huggingface_hub.inference._providers._common import (
+    BaseConversationalTask,
+    BaseTextGenerationTask,
+    TaskProviderHelper,
+    filter_none,
+)
+
+
+class NebiusTextGenerationTask(BaseTextGenerationTask):
+    def __init__(self):
+        super().__init__(provider="nebius", base_url="https://api.studio.nebius.ai")
+
+
+class NebiusConversationalTask(BaseConversationalTask):
+    def __init__(self):
+        super().__init__(provider="nebius", base_url="https://api.studio.nebius.ai")
+
+
+class NebiusTextToImageTask(TaskProviderHelper):
+    def __init__(self):
+        super().__init__(task="text-to-image", provider="nebius", base_url="https://api.studio.nebius.ai")
+
+    def _prepare_route(self, mapped_model: str) -> str:
+        return "/v1/images/generations"
+
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+        parameters = filter_none(parameters)
+        if "guidance_scale" in parameters:
+            parameters.pop("guidance_scale")
+        if parameters.get("response_format") not in ("b64_json", "url"):
+            parameters["response_format"] = "b64_json"
+
+        return {"prompt": inputs, **parameters, "model": mapped_model}
+
+    def get_response(self, response: Union[bytes, Dict]) -> Any:
+        response_dict = _as_dict(response)
+        return base64.b64decode(response_dict["data"][0]["b64_json"])
```
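The payload logic in `NebiusTextToImageTask` can be exercised standalone. In the sketch below, `filter_none` is re-implemented inline and the function names are illustrative; only the transformation rules (drop `None`-valued parameters, strip `guidance_scale`, force `response_format` to `b64_json` unless already valid, base64-decode the first returned image) come from the diff:

```python
import base64
from typing import Any, Dict

def filter_none(d: Dict[str, Any]) -> Dict[str, Any]:
    # Inline stand-in for huggingface_hub's filter_none helper.
    return {k: v for k, v in d.items() if v is not None}

def prepare_payload(prompt: str, parameters: Dict[str, Any], model: str) -> Dict[str, Any]:
    # Mirrors _prepare_payload_as_dict from the diff above.
    parameters = filter_none(parameters)
    parameters.pop("guidance_scale", None)  # not accepted by the Nebius endpoint
    if parameters.get("response_format") not in ("b64_json", "url"):
        parameters["response_format"] = "b64_json"
    return {"prompt": prompt, **parameters, "model": model}

def decode_image(response_dict: Dict[str, Any]) -> bytes:
    # Mirrors get_response: decode the first base64-encoded image.
    return base64.b64decode(response_dict["data"][0]["b64_json"])
```

Dropping `None` values before sending is what makes `height` and `width` optional, per the "height and width are not required" commit message.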
