Commit f629886: Add Hyperbolic provider (#2863)

* add hyperbolic provider
* refactor
* update supported providers table
* nit
* add comment
* add tests
* update new provider doc
* handle text-generation payload
* use two classes for text-generation and conversational
* update provider doc

1 parent (cd85541) · commit f629886

15 files changed: +742 −66 lines

docs/source/en/guides/inference.md

Lines changed: 30 additions & 30 deletions

@@ -248,36 +248,36 @@ You might wonder why using [`InferenceClient`] instead of OpenAI's client? There
 
 [`InferenceClient`]'s goal is to provide the easiest interface to run inference on Hugging Face models, on any provider. It has a simple API that supports the most common tasks. Here is a table showing which providers support which tasks:
 

251-
| Domain | Task | HF Inference | Replicate | fal-ai | Fireworks AI | Sambanova | Together |
252-
| ------------------- | --------------------------------------------------- | ------------ | --------- | ------ | ------------ | --------- | -------- |
253-
| **Audio** | [`~InferenceClient.audio_classification`] || | ||||
254-
| | [`~InferenceClient.audio_to_audio`] || | ||||
255-
| | [`~InferenceClient.automatic_speech_recognition`] || | ||||
256-
| | [`~InferenceClient.text_to_speech`] || | ||||
257-
| **Computer Vision** | [`~InferenceClient.image_classification`] || | ||||
258-
| | [`~InferenceClient.image_segmentation`] || | ||||
259-
| | [`~InferenceClient.image_to_image`] || | ||||
260-
| | [`~InferenceClient.image_to_text`] || | ||||
261-
| | [`~InferenceClient.object_detection`] || | ||||
262-
| | [`~InferenceClient.text_to_image`] || || |||
263-
| | [`~InferenceClient.text_to_video`] || | ||||
264-
| | [`~InferenceClient.zero_shot_image_classification`] || | ||||
265-
| **Multimodal** | [`~InferenceClient.document_question_answering`] || | ||||
266-
| | [`~InferenceClient.visual_question_answering`] || | ||||
267-
| **NLP** | [`~InferenceClient.chat_completion`] || | ||||
268-
| | [`~InferenceClient.feature_extraction`] || | ||||
269-
| | [`~InferenceClient.fill_mask`] || | ||||
270-
| | [`~InferenceClient.question_answering`] || | ||||
271-
| | [`~InferenceClient.sentence_similarity`] || | ||||
272-
| | [`~InferenceClient.summarization`] || | ||||
273-
| | [`~InferenceClient.table_question_answering`] || | ||||
274-
| | [`~InferenceClient.text_classification`] || | ||||
275-
| | [`~InferenceClient.text_generation`] || | | |||
276-
| | [`~InferenceClient.token_classification`] || | ||||
277-
| | [`~InferenceClient.translation`] || | ||||
278-
| | [`~InferenceClient.zero_shot_classification`] || | ||||
279-
| **Tabular** | [`~InferenceClient.tabular_classification`] || | ||||
280-
| | [`~InferenceClient.tabular_regression`] || | ||||
251+
| Domain | Task | HF Inference | fal-ai | Fireworks AI | Hyperbolic | Replicate | Sambanova | Together |
252+
| ------------------- | --------------------------------------------------- | ------------ | ------ | ------------ | ---------- | --------- | --------- | -------- |
253+
| **Audio** | [`~InferenceClient.audio_classification`] ||| | | |||
254+
| | [`~InferenceClient.audio_to_audio`] ||| | | |||
255+
| | [`~InferenceClient.automatic_speech_recognition`] || | | | |||
256+
| | [`~InferenceClient.text_to_speech`] || | | | |||
257+
| **Computer Vision** | [`~InferenceClient.image_classification`] ||| | | |||
258+
| | [`~InferenceClient.image_segmentation`] ||| | | |||
259+
| | [`~InferenceClient.image_to_image`] ||| | | |||
260+
| | [`~InferenceClient.image_to_text`] ||| | | |||
261+
| | [`~InferenceClient.object_detection`] ||| | | |||
262+
| | [`~InferenceClient.text_to_image`] |||| | |||
263+
| | [`~InferenceClient.text_to_video`] ||| | | |||
264+
| | [`~InferenceClient.zero_shot_image_classification`] ||| | | |||
265+
| **Multimodal** | [`~InferenceClient.document_question_answering`] ||| | | |||
266+
| | [`~InferenceClient.visual_question_answering`] ||| | | |||
267+
| **NLP** | [`~InferenceClient.chat_completion`] ||| | | |||
268+
| | [`~InferenceClient.feature_extraction`] ||| | | |||
269+
| | [`~InferenceClient.fill_mask`] ||| | | |||
270+
| | [`~InferenceClient.question_answering`] ||| | | |||
271+
| | [`~InferenceClient.sentence_similarity`] ||| | | |||
272+
| | [`~InferenceClient.summarization`] ||| | | |||
273+
| | [`~InferenceClient.table_question_answering`] ||| | | |||
274+
| | [`~InferenceClient.text_classification`] ||| | | |||
275+
| | [`~InferenceClient.text_generation`] ||| | | |||
276+
| | [`~InferenceClient.token_classification`] ||| | | |||
277+
| | [`~InferenceClient.translation`] ||| | | |||
278+
| | [`~InferenceClient.zero_shot_classification`] ||| | | |||
279+
| **Tabular** | [`~InferenceClient.tabular_classification`] ||| | | |||
280+
| | [`~InferenceClient.tabular_regression`] ||| | | |||
 
 <Tip>

src/huggingface_hub/inference/_client.py

Lines changed: 1 addition & 1 deletion

@@ -132,7 +132,7 @@ class InferenceClient:
             path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
             documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
         provider (`str`, *optional*):
-            Name of the provider to use for inference. Can be `"fal-ai"`, `"fireworks-ai"`, `"replicate"`, `"sambanova"`, `"together"`, or `"hf-inference"`.
+            Name of the provider to use for inference. Can be `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"replicate"`, `"sambanova"` or `"together"`.
             defaults to hf-inference (Hugging Face Serverless Inference API).
             If model is a URL or `base_url` is passed, then `provider` is not used.
         token (`str` or `bool`, *optional*):
src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 1 addition & 1 deletion

@@ -120,7 +120,7 @@ class AsyncInferenceClient:
             path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
             documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
         provider (`str`, *optional*):
-            Name of the provider to use for inference. Can be `"fal-ai"`, `"fireworks-ai"`, `"replicate"`, `"sambanova"`, `"together"`, or `"hf-inference"`.
+            Name of the provider to use for inference. Can be `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"replicate"`, `"sambanova"` or `"together"`.
             defaults to hf-inference (Hugging Face Serverless Inference API).
             If model is a URL or `base_url` is passed, then `provider` is not used.
         token (`str` or `bool`, *optional*):

src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 10 additions & 3 deletions

@@ -9,15 +9,17 @@
 )
 from .fireworks_ai import FireworksAIConversationalTask
 from .hf_inference import HFInferenceBinaryInputTask, HFInferenceConversational, HFInferenceTask
+from .hyperbolic import HyperbolicTextGenerationTask, HyperbolicTextToImageTask
 from .replicate import ReplicateTask, ReplicateTextToSpeechTask
 from .sambanova import SambanovaConversationalTask
-from .together import TogetherTextGenerationTask, TogetherTextToImageTask
+from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask
 
 
 PROVIDER_T = Literal[
     "fal-ai",
     "fireworks-ai",
     "hf-inference",
+    "hyperbolic",
     "replicate",
     "sambanova",
     "together",
@@ -61,6 +63,11 @@
         "summarization": HFInferenceTask("summarization"),
         "visual-question-answering": HFInferenceBinaryInputTask("visual-question-answering"),
     },
+    "hyperbolic": {
+        "text-to-image": HyperbolicTextToImageTask(),
+        "conversational": HyperbolicTextGenerationTask("conversational"),
+        "text-generation": HyperbolicTextGenerationTask("text-generation"),
+    },
     "replicate": {
         "text-to-image": ReplicateTask("text-to-image"),
         "text-to-speech": ReplicateTextToSpeechTask(),
@@ -71,8 +78,8 @@
     },
     "together": {
         "text-to-image": TogetherTextToImageTask(),
-        "conversational": TogetherTextGenerationTask("conversational"),
-        "text-generation": TogetherTextGenerationTask("text-generation"),
+        "conversational": TogetherConversationalTask(),
+        "text-generation": TogetherTextGenerationTask(),
     },
 }
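The `PROVIDERS` registry above is a two-level mapping: provider name, then task name, then a task-helper instance. A minimal standalone sketch of that lookup pattern (stand-in strings replace the real helper objects here, and `get_provider_helper` is a simplified stand-in for the library's actual resolution logic, not a copy of it):

```python
from typing import Dict

# Stand-in registry mirroring the shape of PROVIDERS in __init__.py;
# the real entries are task-helper instances, strings stand in for them here.
PROVIDERS: Dict[str, Dict[str, str]] = {
    "hyperbolic": {
        "text-to-image": "HyperbolicTextToImageTask()",
        "conversational": "HyperbolicTextGenerationTask('conversational')",
        "text-generation": "HyperbolicTextGenerationTask('text-generation')",
    },
    "together": {
        "conversational": "TogetherConversationalTask()",
        "text-generation": "TogetherTextGenerationTask()",
    },
}


def get_provider_helper(provider: str, task: str) -> str:
    """Look up the helper for a provider/task pair, with explicit errors."""
    if provider not in PROVIDERS:
        raise ValueError(f"Provider '{provider}' not supported. Available: {sorted(PROVIDERS)}")
    helpers = PROVIDERS[provider]
    if task not in helpers:
        raise ValueError(f"Task '{task}' not supported for provider '{provider}'. Available: {sorted(helpers)}")
    return helpers[task]
```

This two-level shape is why adding a provider is additive: a new top-level key with its task dict, plus the new literal in `PROVIDER_T`, and no existing entry changes.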
src/huggingface_hub/inference/_providers/_common.py

Lines changed: 33 additions & 0 deletions

@@ -20,6 +20,7 @@
     "fal-ai": {},
     "fireworks-ai": {},
     "hf-inference": {},
+    "hyperbolic": {},
     "replicate": {},
     "sambanova": {},
     "together": {},
@@ -179,6 +180,38 @@ def _prepare_payload_as_bytes(
         return None
 
 
+class BaseConversationalTask(TaskProviderHelper):
+    """
+    Base class for conversational (chat completion) tasks.
+    The schema follows the OpenAI API format defined here: https://platform.openai.com/docs/api-reference/chat
+    """
+
+    def __init__(self, provider: str, base_url: str):
+        super().__init__(provider=provider, base_url=base_url, task="conversational")
+
+    def _prepare_route(self, mapped_model: str) -> str:
+        return "/v1/chat/completions"
+
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+        return {"messages": inputs, **filter_none(parameters), "model": mapped_model}
+
+
+class BaseTextGenerationTask(TaskProviderHelper):
+    """
+    Base class for text-generation (completion) tasks.
+    The schema follows the OpenAI API format defined here: https://platform.openai.com/docs/api-reference/completions
+    """
+
+    def __init__(self, provider: str, base_url: str):
+        super().__init__(provider=provider, base_url=base_url, task="text-generation")
+
+    def _prepare_route(self, mapped_model: str) -> str:
+        return "/v1/completions"
+
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+        return {"prompt": inputs, **filter_none(parameters), "model": mapped_model}
+
+
 @lru_cache(maxsize=None)
 def _fetch_inference_provider_mapping(model: str) -> Dict:
     """
src/huggingface_hub/inference/_providers/fireworks_ai.py

Lines changed: 3 additions & 11 deletions

@@ -1,14 +1,6 @@
-from typing import Any, Dict, Optional
+from ._common import BaseConversationalTask
 
-from ._common import TaskProviderHelper, filter_none
 
-
-class FireworksAIConversationalTask(TaskProviderHelper):
+class FireworksAIConversationalTask(BaseConversationalTask):
     def __init__(self):
-        super().__init__(provider="fireworks-ai", base_url="https://api.fireworks.ai/inference", task="conversational")
-
-    def _prepare_route(self, mapped_model: str) -> str:
-        return "/v1/chat/completions"
-
-    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
-        return {"messages": inputs, **filter_none(parameters), "model": mapped_model}
+        super().__init__(provider="fireworks-ai", base_url="https://api.fireworks.ai/inference")
src/huggingface_hub/inference/_providers/hyperbolic.py

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+import base64
+from typing import Any, Dict, Optional, Union
+
+from huggingface_hub.inference._common import _as_dict
+from huggingface_hub.inference._providers._common import BaseConversationalTask, TaskProviderHelper, filter_none
+
+
+class HyperbolicTextToImageTask(TaskProviderHelper):
+    def __init__(self):
+        super().__init__(provider="hyperbolic", base_url="https://api.hyperbolic.xyz", task="text-to-image")
+
+    def _prepare_route(self, mapped_model: str) -> str:
+        return "/v1/images/generations"
+
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+        parameters = filter_none(parameters)
+        if "num_inference_steps" in parameters:
+            parameters["steps"] = parameters.pop("num_inference_steps")
+        if "guidance_scale" in parameters:
+            parameters["cfg_scale"] = parameters.pop("guidance_scale")
+        # For Hyperbolic, the width and height are required parameters
+        if "width" not in parameters:
+            parameters["width"] = 512
+        if "height" not in parameters:
+            parameters["height"] = 512
+        return {"prompt": inputs, "model_name": mapped_model, **parameters}
+
+    def get_response(self, response: Union[bytes, Dict]) -> Any:
+        response_dict = _as_dict(response)
+        return base64.b64decode(response_dict["images"][0]["image"])
+
+
+class HyperbolicTextGenerationTask(BaseConversationalTask):
+    """
+    Special case for Hyperbolic, where text-generation task is handled as a conversational task.
+    """
+
+    def __init__(self, task: str):
+        super().__init__(
+            provider="hyperbolic",
+            base_url="https://api.hyperbolic.xyz",
+        )
+        self.task = task
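`HyperbolicTextToImageTask._prepare_payload_as_dict` renames diffusers-style parameter names to the field names Hyperbolic's image endpoint expects, and fills in the required `width`/`height`. That mapping step, extracted into a standalone sketch (the key names and 512 defaults come from the diff above):

```python
from typing import Any, Dict


def map_text_to_image_params(parameters: Dict[str, Any]) -> Dict[str, Any]:
    # Drop unset values, mirroring filter_none() in _common.py.
    params = {k: v for k, v in parameters.items() if v is not None}
    # Rename diffusers-style keys to Hyperbolic's field names.
    if "num_inference_steps" in params:
        params["steps"] = params.pop("num_inference_steps")
    if "guidance_scale" in params:
        params["cfg_scale"] = params.pop("guidance_scale")
    # width and height are required by the Hyperbolic API; default to 512.
    params.setdefault("width", 512)
    params.setdefault("height", 512)
    return params
```

The full payload then nests these alongside `prompt` and `model_name`; note the provider takes `model_name` rather than the `model` key used by the chat endpoints.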

src/huggingface_hub/inference/_providers/new_provider.md

Lines changed: 2 additions & 0 deletions

@@ -10,6 +10,8 @@ Implement the methods that require custom handling. Check out the base implement
 
 If the provider supports multiple tasks that require different implementations, create dedicated subclasses for each task, following the pattern shown in `fal_ai.py`.
 
+For `text-generation` and `conversational` tasks, one can just inherit from `BaseTextGenerationTask` and `BaseConversationalTask` respectively (defined in `_common.py`) and override the methods if needed. Examples can be found in `fireworks_ai.py` and `together.py`.
+
 ```py
 from typing import Any, Dict, Optional, Union
 
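Following that guidance, a conversational provider reduces to a single constructor call. A self-contained sketch of the pattern (the base class here is a minimal local stand-in for `BaseConversationalTask` from `_common.py`, and `ExampleAIConversationalTask` with its base URL is hypothetical, not a real provider):

```python
from typing import Any, Dict, List, Optional


class BaseConversationalTask:
    """Minimal stand-in for huggingface_hub's BaseConversationalTask (see _common.py)."""

    def __init__(self, provider: str, base_url: str):
        self.provider = provider
        self.base_url = base_url
        self.task = "conversational"

    def _prepare_route(self, mapped_model: str) -> str:
        return "/v1/chat/completions"

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        params = {k: v for k, v in parameters.items() if v is not None}
        return {"messages": inputs, **params, "model": mapped_model}


class ExampleAIConversationalTask(BaseConversationalTask):
    """Hypothetical new provider: everything is inherited from the base class."""

    def __init__(self):
        super().__init__(provider="example-ai", base_url="https://api.example-ai.invalid")
```

A provider whose API deviates from the OpenAI schema would override `_prepare_route` or `_prepare_payload_as_dict`, as `hyperbolic.py` does for text-to-image.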
src/huggingface_hub/inference/_providers/sambanova.py

Lines changed: 3 additions & 11 deletions

@@ -1,14 +1,6 @@
-from typing import Any, Dict, Optional
+from huggingface_hub.inference._providers._common import BaseConversationalTask
 
-from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
 
-
-class SambanovaConversationalTask(TaskProviderHelper):
+class SambanovaConversationalTask(BaseConversationalTask):
     def __init__(self):
-        super().__init__(provider="sambanova", base_url="https://api.sambanova.ai", task="conversational")
-
-    def _prepare_route(self, mapped_model: str) -> str:
-        return "/v1/chat/completions"
-
-    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
-        return {"messages": inputs, **filter_none(parameters), "model": mapped_model}
+        super().__init__(provider="sambanova", base_url="https://api.sambanova.ai")
