Commit a365b7d
Add Fireworks AI provider + instructions for new provider (#2848)
* first draft of dynamic mapping
* fix imports and typing
* add back recommended models fetching for hf-inference
* fix
* avoid circular imports
* small clean up
* add default supported model list
* remove unnecessary arg
* nit
* rename function
* another nit
* fix
* fix conversational
* fix hf-inference
* add warning when status=staging
* update warning and use model_info
* update import
* fix ExpandModelProperty_T
* refactor
* fix Python 3.8
* fix test
* remove newlines
* Base class for inference providers
* revert
* refactor hf-inference and fal-ai tests
* replicate tests
* sambanova and together tests
* reorder
* unfinished business
* some docstrings
* fix some tests
* fix HfInference does not require token
* fix inference client tests
* fix hf-inference _prepare_api_key
* fal-ai get_response tests
* test get_response for together + replicate
* fix prepare_url
* Add Fireworks AI provider + instructions for new provider
* Update src/huggingface_hub/inference/_providers/new_provider.md

  Co-authored-by: Célina <[email protected]>

* add Fireworks AI to supported providers table

---------

Co-authored-by: Celina Hanouti <[email protected]>
1 parent 0467c1c · commit a365b7d

10 files changed: +446 -40 lines changed

docs/source/en/guides/inference.md

Lines changed: 30 additions & 30 deletions

@@ -248,36 +248,36 @@

[`InferenceClient`]'s goal is to provide the easiest interface to run inference on Hugging Face models, on any provider. It has a simple API that supports the most common tasks. Here is a table showing which providers support which tasks:

The provider support table (covering every task from [`~InferenceClient.audio_classification`] through [`~InferenceClient.tabular_regression`]) gains a **Fireworks AI** column between fal-ai and Sambanova; the existing provider columns are carried over unchanged. The updated header row:

| Domain | Task | HF Inference | Replicate | fal-ai | Fireworks AI | Sambanova | Together |
| ------------------- | --------------------------------------------------- | ------------ | --------- | ------ | ------------ | --------- | -------- |

Fireworks AI is marked as supported only for the conversational task ([`~InferenceClient.chat_completion`]), matching the single task registered for it in `_providers/__init__.py` below.

src/huggingface_hub/inference/_client.py

Lines changed: 1 addition & 1 deletion

@@ -133,7 +133,7 @@ class InferenceClient:

In the `provider` parameter docstring, the list of accepted values is updated:

Before: Name of the provider to use for inference. Can be `"replicate"`, `"together"`, `"fal-ai"`, `"sambanova"` or `"hf-inference"`.
After: Name of the provider to use for inference. Can be `"fal-ai"`, `"fireworks-ai"`, `"replicate"`, `"sambanova"`, `"together"`, or `"hf-inference"`.

The surrounding lines are unchanged: the value defaults to `"hf-inference"` (Hugging Face Serverless Inference API), and if `model` is a URL or `base_url` is passed, `provider` is not used.

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 1 addition & 1 deletion

@@ -121,7 +121,7 @@ class AsyncInferenceClient:

The same one-line docstring update is applied to the async client:

Before: Name of the provider to use for inference. Can be `"replicate"`, `"together"`, `"fal-ai"`, `"sambanova"` or `"hf-inference"`.
After: Name of the provider to use for inference. Can be `"fal-ai"`, `"fireworks-ai"`, `"replicate"`, `"sambanova"`, `"together"`, or `"hf-inference"`.
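For context, a minimal usage sketch of the new provider value (the API key and model id are placeholders; per this commit, `chat_completion` is the only task the Fireworks AI provider supports):

```py
from huggingface_hub import InferenceClient

# Route a chat completion request through Fireworks AI.
client = InferenceClient(provider="fireworks-ai", api_key="fw_...")  # placeholder key
response = client.chat_completion(
    messages=[{"role": "user", "content": "Hello!"}],
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder HF model id
)
print(response.choices[0].message.content)
```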

src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -7,6 +7,7 @@
     FalAITextToSpeechTask,
     FalAITextToVideoTask,
 )
+from .fireworks_ai import FireworksAIConversationalTask
 from .hf_inference import HFInferenceBinaryInputTask, HFInferenceConversational, HFInferenceTask
 from .replicate import ReplicateTask, ReplicateTextToSpeechTask
 from .sambanova import SambanovaConversationalTask

@@ -15,6 +16,7 @@
 PROVIDER_T = Literal[
     "fal-ai",
+    "fireworks-ai",
     "hf-inference",
     "replicate",
     "sambanova",

@@ -28,6 +30,9 @@
         "text-to-speech": FalAITextToSpeechTask(),
         "text-to-video": FalAITextToVideoTask(),
     },
+    "fireworks-ai": {
+        "conversational": FireworksAIConversationalTask(),
+    },
     "hf-inference": {
         "text-to-image": HFInferenceTask("text-to-image"),
         "conversational": HFInferenceConversational(),
src/huggingface_hub/inference/_providers/fireworks_ai.py (new file)

Lines changed: 14 additions & 0 deletions

```py
from typing import Any, Dict, Optional

from ._common import TaskProviderHelper, filter_none


class FireworksAIConversationalTask(TaskProviderHelper):
    def __init__(self):
        super().__init__(provider="fireworks-ai", base_url="https://api.fireworks.ai/inference", task="conversational")

    def _prepare_route(self, mapped_model: str) -> str:
        # Chat completions are served on an OpenAI-style route.
        return "/v1/chat/completions"

    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        # Drop unset (None) parameters and send the messages plus the provider-side model id.
        return {"messages": inputs, **filter_none(parameters), "model": mapped_model}
```
src/huggingface_hub/inference/_providers/new_provider.md (new file)

Lines changed: 79 additions & 0 deletions

## How to add a new provider?

Before adding a new provider to the `huggingface_hub` library, make sure it has already been added to `huggingface.js` and is working on the Hub. Support in the Python library comes as a second step; this guide assumes that first step is complete.

### 1. Implement the provider helper

Create a new file under `src/huggingface_hub/inference/_providers/{provider_name}.py` and copy-paste the following snippet.

Implement only the methods that require custom handling; check the base implementation for the default behavior. If you don't need to override a method, remove it. At least one of `_prepare_payload` or `_prepare_body` must be overridden.

If the provider supports multiple tasks that require different implementations, create dedicated subclasses for each task, following the pattern shown in `fal_ai.py`.

```py
from typing import Any, Dict, Optional, Union

from ._common import TaskProviderHelper


class MyNewProviderTaskProviderHelper(TaskProviderHelper):
    def __init__(self):
        """Define high-level parameters."""
        super().__init__(provider=..., base_url=..., task=...)

    def get_response(self, response: Union[bytes, Dict]) -> Any:
        """Return the response in the expected format.

        Override this method in subclasses for customized response handling."""
        return super().get_response(response)

    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
        """Return the headers to use for the request.

        Override this method in subclasses for customized headers.
        """
        return super()._prepare_headers(headers, api_key)

    def _prepare_route(self, mapped_model: str) -> str:
        """Return the route to use for the request.

        Override this method in subclasses for customized routes.
        """
        return super()._prepare_route(mapped_model)

    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        """Return the payload to use for the request, as a dict.

        Override this method in subclasses for customized payloads.
        Only one of `_prepare_payload` and `_prepare_body` should return a value.
        """
        return super()._prepare_payload(inputs, parameters, mapped_model)

    def _prepare_body(
        self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict]
    ) -> Optional[bytes]:
        """Return the body to use for the request, as bytes.

        Override this method in subclasses for customized body data.
        Only one of `_prepare_payload` and `_prepare_body` should return a value.
        """
        return super()._prepare_body(inputs, parameters, mapped_model, extra_payload)
```
### 2. Register the provider helper in `__init__.py`

Go to `src/huggingface_hub/inference/_providers/__init__.py` and add your provider to `PROVIDER_T` and `PROVIDERS`. Please respect alphabetical order; see the sketch below.
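A hypothetical registration could look like this; `my-new-provider`, its module, and its helper class are placeholders, and the annotation on `PROVIDERS` is assumed:

```py
from typing import Dict, Literal

from ._common import TaskProviderHelper
from .my_new_provider import MyNewProviderTaskProviderHelper  # hypothetical module

PROVIDER_T = Literal[
    "fal-ai",
    "fireworks-ai",
    "hf-inference",
    "my-new-provider",  # inserted in alphabetical order
    "replicate",
    "sambanova",
    "together",
]

PROVIDERS: Dict[str, Dict[str, TaskProviderHelper]] = {
    # ... existing providers kept as-is ...
    "my-new-provider": {
        "conversational": MyNewProviderTaskProviderHelper(),
    },
}
```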
### 3. Update the docstring in `InferenceClient.__init__` to document your provider
### 4. Add static tests in `tests/test_inference_providers.py`

You only have to add tests for the methods you override.
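For example, a static test can instantiate the helper and assert an overridden method directly; the expected route below is hypothetical:

```py
# tests/test_inference_providers.py (illustrative)
from huggingface_hub.inference._providers.my_new_provider import MyNewProviderTaskProviderHelper


def test_prepare_route():
    # Assumes the provider overrides _prepare_route with a fixed chat route.
    helper = MyNewProviderTaskProviderHelper()
    assert helper._prepare_route(mapped_model="some-model") == "/v1/chat/completions"
```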
### 5. Add VCR tests in `tests/test_inference_client.py`

- Add an entry in `_RECOMMENDED_MODELS_FOR_VCR` at the top of the test module. It maps each task to a test model; the model id must be the HF model id (see the sketch after this list).
- Add an entry in `API_KEY_ENV_VARIABLES` to define which environment variable should be used.
- Run the tests locally with `pytest tests/test_inference_client.py -k <provider>` and commit the VCR cassettes.
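A sketch of the two entries; the structure is assumed from the descriptions above, and the provider name, model id, and environment variable are placeholders:

```py
# tests/test_inference_client.py (illustrative)
_RECOMMENDED_MODELS_FOR_VCR = {
    "my-new-provider": {
        "conversational": "meta-llama/Llama-3.1-8B-Instruct",  # must be an HF model id
    },
}

API_KEY_ENV_VARIABLES = {
    "my-new-provider": "MY_NEW_PROVIDER_API_KEY",
}
```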
