
Commit ced4d89

Add Black Forest Labs provider (#2864)
* add bfl
* update table
* add logging
1 parent a7f3151 commit ced4d89

8 files changed: 6,325 additions and 32 deletions

docs/source/en/guides/inference.md

Lines changed: 30 additions & 30 deletions
@@ -248,36 +248,36 @@ You might wonder why using [`InferenceClient`] instead of OpenAI's client? There
[`InferenceClient`]'s goal is to provide the easiest interface to run inference on Hugging Face models, on any provider. It has a simple API that supports the most common tasks. Here is a table showing which providers support which tasks:

- | Domain | Task | HF Inference | fal-ai | Fireworks AI | Hyperbolic | Nebius AI Studio | Novita AI | Replicate | Sambanova | Together |
+ | Domain | Task | Black Forest Labs | HF Inference | fal-ai | Fireworks AI | Hyperbolic | Nebius AI Studio | Novita AI | Replicate | Sambanova | Together |

[Provider support table updated: one row per `InferenceClient` task method, from `audio_classification` through `tabular_regression`, grouped under Audio, Computer Vision, Multimodal, NLP and Tabular, each cell holding a per-provider support marker; the diff adds the "Black Forest Labs" column to every row.]

<Tip>

src/huggingface_hub/inference/_client.py

Lines changed: 1 addition & 1 deletion

@@ -132,7 +132,7 @@ class InferenceClient:
      path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
      documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
  provider (`str`, *optional*):
- Name of the provider to use for inference. Can be "fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, "sambanova"` or `"together"`.
+ Name of the provider to use for inference. Can be `"black-forest-labs"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, `"sambanova"` or `"together"`.
  defaults to hf-inference (Hugging Face Serverless Inference API).
  If model is a URL or `base_url` is passed, then `provider` is not used.
  token (`str` or `bool`, *optional*):
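For context, a minimal usage sketch of the new provider value with the synchronous client. The prompt, the `hf_...` token placeholder and the `black-forest-labs/FLUX.1-dev` model id are illustrative choices, not part of this diff:

```python
from huggingface_hub import InferenceClient

# Route text-to-image requests through Black Forest Labs instead of the default hf-inference.
client = InferenceClient(provider="black-forest-labs", api_key="hf_...")

# text_to_image returns a PIL image; the model id below is only an example.
image = client.text_to_image(
    "An astronaut riding a horse on the moon",
    model="black-forest-labs/FLUX.1-dev",
)
image.save("astronaut.png")
```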

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 1 addition & 1 deletion

@@ -120,7 +120,7 @@ class AsyncInferenceClient:
      path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
      documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
  provider (`str`, *optional*):
- Name of the provider to use for inference. Can be "fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, "sambanova"` or `"together"`.
+ Name of the provider to use for inference. Can be `"black-forest-labs"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"replicate"`, `"sambanova"` or `"together"`.
  defaults to hf-inference (Hugging Face Serverless Inference API).
  If model is a URL or `base_url` is passed, then `provider` is not used.
  token (`str` or `bool`, *optional*):
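The same provider value works with the async client. A short sketch under the same assumptions (illustrative prompt, token placeholder and model id):

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient(provider="black-forest-labs", api_key="hf_...")
    # Same API as the sync client, but awaitable.
    image = await client.text_to_image(
        "A watercolor painting of the Black Forest",
        model="black-forest-labs/FLUX.1-dev",
    )
    image.save("forest.png")


asyncio.run(main())
```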

src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -1,6 +1,7 @@
  from typing import Dict, Literal

  from ._common import TaskProviderHelper
+ from .black_forest_labs import BlackForestLabsTextToImageTask
  from .fal_ai import (
      FalAIAutomaticSpeechRecognitionTask,
      FalAITextToImageTask,
@@ -18,6 +19,7 @@


  PROVIDER_T = Literal[
+     "black-forest-labs",
      "fal-ai",
      "fireworks-ai",
      "hf-inference",
@@ -30,6 +32,9 @@
  ]

  PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
+     "black-forest-labs": {
+         "text-to-image": BlackForestLabsTextToImageTask(),
+     },
      "fal-ai": {
          "automatic-speech-recognition": FalAIAutomaticSpeechRecognitionTask(),
          "text-to-image": FalAITextToImageTask(),
src/huggingface_hub/inference/_providers/black_forest_labs.py

Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
import time
from typing import Any, Dict, Optional, Union

from huggingface_hub.inference._common import _as_dict
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
from huggingface_hub.utils import logging
from huggingface_hub.utils._http import get_session


logger = logging.get_logger(__name__)

MAX_POLLING_ATTEMPTS = 6
POLLING_INTERVAL = 1.0


class BlackForestLabsTextToImageTask(TaskProviderHelper):
    def __init__(self):
        super().__init__(provider="black-forest-labs", base_url="https://api.us1.bfl.ai/v1", task="text-to-image")

    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
        headers = super()._prepare_headers(headers, api_key)
        if not api_key.startswith("hf_"):
            _ = headers.pop("authorization")
            headers["X-Key"] = api_key
        return headers

    def _prepare_route(self, mapped_model: str) -> str:
        return mapped_model

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        parameters = filter_none(parameters)
        if "num_inference_steps" in parameters:
            parameters["steps"] = parameters.pop("num_inference_steps")
        if "guidance_scale" in parameters:
            parameters["guidance"] = parameters.pop("guidance_scale")

        return {"prompt": inputs, **parameters}

    def get_response(self, response: Union[bytes, Dict]) -> Any:
        """
        Polling mechanism for Black Forest Labs since the API is asynchronous.
        """
        url = _as_dict(response).get("polling_url")
        session = get_session()
        for _ in range(MAX_POLLING_ATTEMPTS):
            time.sleep(POLLING_INTERVAL)

            response = session.get(url, headers={"Content-Type": "application/json"})  # type: ignore
            response.raise_for_status()  # type: ignore
            response_json: Dict = response.json()  # type: ignore
            status = response_json.get("status")
            logger.info(
                f"Polling generation result from {url}. Current status: {status}. "
                f"Will retry after {POLLING_INTERVAL} seconds if not ready."
            )

            if (
                status == "Ready"
                and isinstance(response_json.get("result"), dict)
                and (sample_url := response_json["result"].get("sample"))
            ):
                image_resp = session.get(sample_url)
                image_resp.raise_for_status()
                return image_resp.content

        raise TimeoutError(f"Failed to get the image URL after {MAX_POLLING_ATTEMPTS} attempts.")
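To illustrate the parameter mapping above, a small sketch of what `_prepare_payload_as_dict` produces when given the generic text-to-image parameter names. The prompt, values and the `flux-dev` mapped model id are made up for the example (and the mapped model is not used by this method); it assumes the new module lives at `huggingface_hub.inference._providers.black_forest_labs`, as the import in `__init__.py` indicates:

```python
from huggingface_hub.inference._providers.black_forest_labs import BlackForestLabsTextToImageTask

task = BlackForestLabsTextToImageTask()
payload = task._prepare_payload_as_dict(
    inputs="A lighthouse at dusk",
    parameters={"num_inference_steps": 28, "guidance_scale": 3.5, "seed": None},
    mapped_model="flux-dev",  # hypothetical mapped model id, ignored by this method
)
# num_inference_steps/guidance_scale are renamed to the BFL field names and None values are dropped:
# {"prompt": "A lighthouse at dusk", "steps": 28, "guidance": 3.5}
print(payload)
```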

0 commit comments

Comments
 (0)