
Commit 604b9ca

Add Novita provider (#2865)
* add novita
* refactor tests
* quality
* regenerate cassettes
1 parent fb4f42e commit 604b9ca

11 files changed (+450 -68 lines)

docs/source/en/guides/inference.md

Lines changed: 30 additions & 30 deletions
@@ -248,36 +248,36 @@ You might wonder why using [`InferenceClient`] instead of OpenAI's client? There
 
 [`InferenceClient`]'s goal is to provide the easiest interface to run inference on Hugging Face models, on any provider. It has a simple API that supports the most common tasks. Here is a table showing which providers support which tasks:
 
-| Domain | Task | HF Inference | fal-ai | Fireworks AI | Hyperbolic | Replicate | Sambanova | Together |
-| ------------------- | --------------------------------------------------- | ------------ | ------ | ------------ | ---------- | --------- | --------- | -------- |
-| **Audio** | [`~InferenceClient.audio_classification`] ||||||||
-| | [`~InferenceClient.audio_to_audio`] ||||||||
-| | [`~InferenceClient.automatic_speech_recognition`] ||||||||
-| | [`~InferenceClient.text_to_speech`] ||||||||
-| **Computer Vision** | [`~InferenceClient.image_classification`] ||||||||
-| | [`~InferenceClient.image_segmentation`] ||||||||
-| | [`~InferenceClient.image_to_image`] ||||||||
-| | [`~InferenceClient.image_to_text`] ||||||||
-| | [`~InferenceClient.object_detection`] ||||||||
-| | [`~InferenceClient.text_to_image`] ||||||||
-| | [`~InferenceClient.text_to_video`] ||||||||
-| | [`~InferenceClient.zero_shot_image_classification`] ||||||||
-| **Multimodal** | [`~InferenceClient.document_question_answering`] ||||||||
-| | [`~InferenceClient.visual_question_answering`] ||||||||
-| **NLP** | [`~InferenceClient.chat_completion`] ||||||||
-| | [`~InferenceClient.feature_extraction`] ||||||||
-| | [`~InferenceClient.fill_mask`] ||||||||
-| | [`~InferenceClient.question_answering`] ||||||||
-| | [`~InferenceClient.sentence_similarity`] ||||||||
-| | [`~InferenceClient.summarization`] ||||||||
-| | [`~InferenceClient.table_question_answering`] ||||||||
-| | [`~InferenceClient.text_classification`] ||||||||
-| | [`~InferenceClient.text_generation`] ||||||||
-| | [`~InferenceClient.token_classification`] ||||||||
-| | [`~InferenceClient.translation`] ||||||||
-| | [`~InferenceClient.zero_shot_classification`] ||||||||
-| **Tabular** | [`~InferenceClient.tabular_classification`] ||||||||
-| | [`~InferenceClient.tabular_regression`] ||||||||
+| Domain | Task | HF Inference | fal-ai | Fireworks AI | Hyperbolic | Novita AI | Replicate | Sambanova | Together |
+| ------------------- | --------------------------------------------------- | ------------ | ------ | ------------ | ---------- | ------ | --------- | --------- | -------- |
+| **Audio** | [`~InferenceClient.audio_classification`] |||||| |||
+| | [`~InferenceClient.audio_to_audio`] |||||| |||
+| | [`~InferenceClient.automatic_speech_recognition`] |||||| |||
+| | [`~InferenceClient.text_to_speech`] ||||| | |||
+| **Computer Vision** | [`~InferenceClient.image_classification`] |||||| |||
+| | [`~InferenceClient.image_segmentation`] |||||| |||
+| | [`~InferenceClient.image_to_image`] |||||| |||
+| | [`~InferenceClient.image_to_text`] |||||| |||
+| | [`~InferenceClient.object_detection`] |||||| |||
+| | [`~InferenceClient.text_to_image`] ||||| | |||
+| | [`~InferenceClient.text_to_video`] ||||| | |||
+| | [`~InferenceClient.zero_shot_image_classification`] |||||| |||
+| **Multimodal** | [`~InferenceClient.document_question_answering`] |||||| |||
+| | [`~InferenceClient.visual_question_answering`] |||||| |||
+| **NLP** | [`~InferenceClient.chat_completion`] ||||| | |||
+| | [`~InferenceClient.feature_extraction`] |||||| |||
+| | [`~InferenceClient.fill_mask`] |||||| |||
+| | [`~InferenceClient.question_answering`] |||||| |||
+| | [`~InferenceClient.sentence_similarity`] |||||| |||
+| | [`~InferenceClient.summarization`] |||||| |||
+| | [`~InferenceClient.table_question_answering`] |||||| |||
+| | [`~InferenceClient.text_classification`] |||||| |||
+| | [`~InferenceClient.text_generation`] ||||| | |||
+| | [`~InferenceClient.token_classification`] |||||| |||
+| | [`~InferenceClient.translation`] |||||| |||
+| | [`~InferenceClient.zero_shot_classification`] |||||| |||
+| **Tabular** | [`~InferenceClient.tabular_classification`] |||||| |||
+| | [`~InferenceClient.tabular_regression`] |||||| |||
 
 <Tip>

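As context for the new column, a minimal sketch of what routing through Novita looks like with [`InferenceClient`]. The model ID and token below are placeholders, and the model's availability on Novita is assumed for illustration rather than taken from this commit.

from huggingface_hub import InferenceClient

client = InferenceClient(
    provider="novita",  # provider value added by this commit
    api_key="hf_***",   # placeholder Hugging Face token
)

# `chat_completion` is one of the two tasks registered for Novita (see _providers/__init__.py below).
response = client.chat_completion(
    model="meta-llama/Llama-3.1-8B-Instruct",  # example model ID, assumed to be served by Novita
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    max_tokens=64,
)
print(response.choices[0].message.content)
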
src/huggingface_hub/inference/_client.py

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ class InferenceClient:
     path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
     documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
 provider (`str`, *optional*):
-    Name of the provider to use for inference. Can be `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"replicate"`, `"sambanova"` or `"together"`.
+    Name of the provider to use for inference. Can be `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"novita"`, `"replicate"`, `"sambanova"` or `"together"`.
     Defaults to hf-inference (Hugging Face Serverless Inference API).
     If model is a URL or `base_url` is passed, then `provider` is not used.
 token (`str` or `bool`, *optional*):

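A short sketch of the `provider` / `model` interplay this docstring describes; the endpoint URL is a made-up placeholder.

from huggingface_hub import InferenceClient

# No provider given -> defaults to "hf-inference" (Hugging Face Serverless Inference API).
default_client = InferenceClient()

# Provider name accepted after this commit.
novita_client = InferenceClient(provider="novita")

# `model` given as a URL -> the client calls that endpoint directly and `provider` is not used.
endpoint_client = InferenceClient(model="https://my-endpoint.example.com")  # placeholder URL
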
src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 1 addition & 1 deletion
@@ -120,7 +120,7 @@ class AsyncInferenceClient:
     path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
     documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
 provider (`str`, *optional*):
-    Name of the provider to use for inference. Can be `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"replicate"`, `"sambanova"` or `"together"`.
+    Name of the provider to use for inference. Can be `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"novita"`, `"replicate"`, `"sambanova"` or `"together"`.
     Defaults to hf-inference (Hugging Face Serverless Inference API).
     If model is a URL or `base_url` is passed, then `provider` is not used.
 token (`str` or `bool`, *optional*):

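The async client takes the same `provider` argument; a minimal sketch exercising the other Novita task, text generation. The model ID and token are placeholders, and availability on Novita is assumed.

import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient(provider="novita", api_key="hf_***")  # placeholder token
    output = await client.text_generation(
        "The huggingface_hub library is ",
        model="meta-llama/Llama-3.1-8B-Instruct",  # example model ID
        max_new_tokens=20,
    )
    print(output)


asyncio.run(main())
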
src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -10,6 +10,7 @@
 from .fireworks_ai import FireworksAIConversationalTask
 from .hf_inference import HFInferenceBinaryInputTask, HFInferenceConversational, HFInferenceTask
 from .hyperbolic import HyperbolicTextGenerationTask, HyperbolicTextToImageTask
+from .novita import NovitaConversationalTask, NovitaTextGenerationTask
 from .replicate import ReplicateTask, ReplicateTextToSpeechTask
 from .sambanova import SambanovaConversationalTask
 from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask

@@ -20,6 +21,7 @@
     "fireworks-ai",
     "hf-inference",
     "hyperbolic",
+    "novita",
     "replicate",
     "sambanova",
     "together",

@@ -68,6 +70,10 @@
         "conversational": HyperbolicTextGenerationTask("conversational"),
         "text-generation": HyperbolicTextGenerationTask("text-generation"),
     },
+    "novita": {
+        "text-generation": NovitaTextGenerationTask(),
+        "conversational": NovitaConversationalTask(),
+    },
     "replicate": {
         "text-to-image": ReplicateTask("text-to-image"),
         "text-to-speech": ReplicateTextToSpeechTask(),

src/huggingface_hub/inference/_providers/novita.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+from huggingface_hub.inference._providers._common import (
+    BaseConversationalTask,
+    BaseTextGenerationTask,
+)
+
+
+_PROVIDER = "novita"
+_BASE_URL = "https://api.novita.ai/v3/openai"
+
+
+class NovitaTextGenerationTask(BaseTextGenerationTask):
+    def __init__(self):
+        super().__init__(provider=_PROVIDER, base_url=_BASE_URL)
+
+    def _prepare_route(self, mapped_model: str) -> str:
+        # there is no v1/ route for novita
+        return "/completions"
+
+
+class NovitaConversationalTask(BaseConversationalTask):
+    def __init__(self):
+        super().__init__(provider=_PROVIDER, base_url=_BASE_URL)
+
+    def _prepare_route(self, mapped_model: str) -> str:
+        # there is no v1/ route for novita
+        return "/chat/completions"

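Taken together, the registry entry in `_providers/__init__.py` and the two classes above determine where a Novita request goes: the helper is looked up by (provider, task), and its route is appended to the OpenAI-compatible base URL with no `/v1` segment. A self-contained sketch of that resolution; the dictionary and function below are illustrative stand-ins, not the library's internal API.

from typing import Dict

_BASE_URL = "https://api.novita.ai/v3/openai"

# Mirrors the routes returned by the _prepare_route() methods above.
_NOVITA_ROUTES: Dict[str, str] = {
    "text-generation": "/completions",
    "conversational": "/chat/completions",
}


def resolve_url(task: str) -> str:
    # Hypothetical helper composing the final endpoint the way the task classes would.
    return _BASE_URL + _NOVITA_ROUTES[task]


print(resolve_url("text-generation"))  # https://api.novita.ai/v3/openai/completions
print(resolve_url("conversational"))   # https://api.novita.ai/v3/openai/chat/completions
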
src/huggingface_hub/inference/_providers/replicate.py

Lines changed: 5 additions & 1 deletion
@@ -5,9 +5,13 @@
 from huggingface_hub.utils import get_session
 
 
+_PROVIDER = "replicate"
+_BASE_URL = "https://api.replicate.com"
+
+
 class ReplicateTask(TaskProviderHelper):
     def __init__(self, task: str):
-        super().__init__(provider="replicate", base_url="https://api.replicate.com", task=task)
+        super().__init__(provider=_PROVIDER, base_url=_BASE_URL, task=task)
 
     def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
         headers = super()._prepare_headers(headers, api_key)

src/huggingface_hub/inference/_providers/together.py

Lines changed: 7 additions & 3 deletions
@@ -11,11 +11,15 @@
 )
 
 
+_PROVIDER = "together"
+_BASE_URL = "https://api.together.xyz"
+
+
 class TogetherTask(TaskProviderHelper, ABC):
     """Base class for Together API tasks."""
 
     def __init__(self, task: str):
-        super().__init__(provider="together", base_url="https://api.together.xyz", task=task)
+        super().__init__(provider=_PROVIDER, base_url=_BASE_URL, task=task)
 
     def _prepare_route(self, mapped_model: str) -> str:
         if self.task == "text-to-image":

@@ -29,12 +33,12 @@ def _prepare_route(self, mapped_model: str) -> str:
 
 class TogetherTextGenerationTask(BaseTextGenerationTask):
     def __init__(self):
-        super().__init__(provider="together", base_url="https://api.together.xyz")
+        super().__init__(provider=_PROVIDER, base_url=_BASE_URL)
 
 
 class TogetherConversationalTask(BaseConversationalTask):
     def __init__(self):
-        super().__init__(provider="together", base_url="https://api.together.xyz")
+        super().__init__(provider=_PROVIDER, base_url=_BASE_URL)
 
 
 class TogetherTextToImageTask(TogetherTask):

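The replicate.py and together.py edits above are the same small refactor: the provider name and base URL move into module-level `_PROVIDER` / `_BASE_URL` constants so every task class in the module references a single definition instead of repeating string literals. A generic sketch of the pattern, with placeholder names not tied to either provider:

_PROVIDER = "example-provider"                  # placeholder, defined once per module
_BASE_URL = "https://api.example-provider.com"  # placeholder, defined once per module


class _BaseTask:
    def __init__(self, provider: str, base_url: str):
        self.provider = provider
        self.base_url = base_url


class ExampleChatTask(_BaseTask):
    def __init__(self):
        # Each task class reuses the constants, so changing the endpoint is a one-line edit.
        super().__init__(provider=_PROVIDER, base_url=_BASE_URL)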