Skip to content

Commit 17e09ab

Browse files
[Inference Providers] Async calls for fal.ai (#2927)
* add async calls for fal-ai * fix * fix test * nit * pass request params to get response * fixes * nit * fix quality * nit * fixes post-review * fix Co-authored-by: Lucain <[email protected]> * remove unnecessary type: ignore Co-authored-by: Lucain <[email protected]> * fix Co-authored-by: Lucain <[email protected]> --------- Co-authored-by: Lucain <[email protected]>
1 parent 4ad0d9a commit 17e09ab

File tree

13 files changed

+147
-53
lines changed

13 files changed

+147
-53
lines changed

src/huggingface_hub/inference/_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2617,7 +2617,7 @@ def text_to_video(
26172617
api_key=self.token,
26182618
)
26192619
response = self._inner_post(request_parameters)
2620-
response = provider_helper.get_response(response)
2620+
response = provider_helper.get_response(response, request_parameters)
26212621
return response
26222622

26232623
def text_to_speech(

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2674,7 +2674,7 @@ async def text_to_video(
26742674
api_key=self.token,
26752675
)
26762676
response = await self._inner_post(request_parameters)
2677-
response = provider_helper.get_response(response)
2677+
response = provider_helper.get_response(response, request_parameters)
26782678
return response
26792679

26802680
async def text_to_speech(

src/huggingface_hub/inference/_providers/_common.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,11 @@ def prepare_request(
8484
raise ValueError("Either payload or data must be set in the request.")
8585
return RequestParameters(url=url, task=self.task, model=mapped_model, json=payload, data=data, headers=headers)
8686

87-
def get_response(self, response: Union[bytes, Dict]) -> Any:
87+
def get_response(
88+
self,
89+
response: Union[bytes, Dict],
90+
request_params: Optional[RequestParameters] = None,
91+
) -> Any:
8892
"""
8993
Return the response in the expected format.
9094
@@ -142,7 +146,7 @@ def _prepare_url(self, api_key: str, mapped_model: str) -> str:
142146
143147
Usually not overwritten in subclasses."""
144148
base_url = self._prepare_base_url(api_key)
145-
route = self._prepare_route(mapped_model)
149+
route = self._prepare_route(mapped_model, api_key)
146150
return f"{base_url.rstrip('/')}/{route.lstrip('/')}"
147151

148152
def _prepare_base_url(self, api_key: str) -> str:
@@ -157,7 +161,7 @@ def _prepare_base_url(self, api_key: str) -> str:
157161
logger.info(f"Calling '{self.provider}' provider directly.")
158162
return self.base_url
159163

160-
def _prepare_route(self, mapped_model: str) -> str:
164+
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
161165
"""Return the route to use for the request.
162166
163167
Override this method in subclasses for customized routes.
@@ -192,7 +196,7 @@ class BaseConversationalTask(TaskProviderHelper):
192196
def __init__(self, provider: str, base_url: str):
193197
super().__init__(provider=provider, base_url=base_url, task="conversational")
194198

195-
def _prepare_route(self, mapped_model: str) -> str:
199+
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
196200
return "/v1/chat/completions"
197201

198202
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
@@ -208,7 +212,7 @@ class BaseTextGenerationTask(TaskProviderHelper):
208212
def __init__(self, provider: str, base_url: str):
209213
super().__init__(provider=provider, base_url=base_url, task="text-generation")
210214

211-
def _prepare_route(self, mapped_model: str) -> str:
215+
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
212216
return "/v1/completions"
213217

214218
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:

src/huggingface_hub/inference/_providers/black_forest_labs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import time
22
from typing import Any, Dict, Optional, Union
33

4-
from huggingface_hub.inference._common import _as_dict
4+
from huggingface_hub.inference._common import RequestParameters, _as_dict
55
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
66
from huggingface_hub.utils import logging
77
from huggingface_hub.utils._http import get_session
@@ -24,7 +24,7 @@ def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
2424
headers["X-Key"] = api_key
2525
return headers
2626

27-
def _prepare_route(self, mapped_model: str) -> str:
27+
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
2828
return f"/v1/{mapped_model}"
2929

3030
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
@@ -36,7 +36,7 @@ def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model:
3636

3737
return {"prompt": inputs, **parameters}
3838

39-
def get_response(self, response: Union[bytes, Dict]) -> Any:
39+
def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
4040
"""
4141
Polling mechanism for Black Forest Labs since the API is asynchronous.
4242
"""

src/huggingface_hub/inference/_providers/cohere.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@ class CohereConversationalTask(BaseConversationalTask):
1111
def __init__(self):
1212
super().__init__(provider=_PROVIDER, base_url=_BASE_URL)
1313

14-
def _prepare_route(self, mapped_model: str) -> str:
14+
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
1515
return "/compatibility/v1/chat/completions"

src/huggingface_hub/inference/_providers/fal_ai.py

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
import base64
2+
import time
23
from abc import ABC
34
from typing import Any, Dict, Optional, Union
45

5-
from huggingface_hub.inference._common import _as_dict
6+
from huggingface_hub.inference._common import RequestParameters, _as_dict
67
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
7-
from huggingface_hub.utils import get_session
8+
from huggingface_hub.utils import get_session, hf_raise_for_status
9+
from huggingface_hub.utils.logging import get_logger
10+
11+
12+
logger = get_logger(__name__)
13+
14+
# Arbitrary polling interval
15+
_POLLING_INTERVAL = 2.0
816

917

1018
class FalAITask(TaskProviderHelper, ABC):
@@ -17,7 +25,7 @@ def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
1725
headers["authorization"] = f"Key {api_key}"
1826
return headers
1927

20-
def _prepare_route(self, mapped_model: str) -> str:
28+
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
2129
return f"/{mapped_model}"
2230

2331

@@ -41,7 +49,7 @@ def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model:
4149

4250
return {"audio_url": audio_url, **filter_none(parameters)}
4351

44-
def get_response(self, response: Union[bytes, Dict]) -> Any:
52+
def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
4553
text = _as_dict(response)["text"]
4654
if not isinstance(text, str):
4755
raise ValueError(f"Unexpected output format from FalAI API. Expected string, got {type(text)}.")
@@ -61,7 +69,7 @@ def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model:
6169
}
6270
return {"prompt": inputs, **parameters}
6371

64-
def get_response(self, response: Union[bytes, Dict]) -> Any:
72+
def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
6573
url = _as_dict(response)["images"][0]["url"]
6674
return get_session().get(url).content
6775

@@ -73,7 +81,7 @@ def __init__(self):
7381
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
7482
return {"lyrics": inputs, **filter_none(parameters)}
7583

76-
def get_response(self, response: Union[bytes, Dict]) -> Any:
84+
def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
7785
url = _as_dict(response)["audio"]["url"]
7886
return get_session().get(url).content
7987

@@ -82,9 +90,52 @@ class FalAITextToVideoTask(FalAITask):
8290
def __init__(self):
8391
super().__init__("text-to-video")
8492

93+
def _prepare_base_url(self, api_key: str) -> str:
94+
if api_key.startswith("hf_"):
95+
return super()._prepare_base_url(api_key)
96+
else:
97+
logger.info(f"Calling '{self.provider}' provider directly.")
98+
return "https://queue.fal.run"
99+
100+
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
101+
if api_key.startswith("hf_"):
102+
# Use the queue subdomain for HF routing
103+
return f"/{mapped_model}?_subdomain=queue"
104+
return f"/{mapped_model}"
105+
85106
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
86107
return {"prompt": inputs, **filter_none(parameters)}
87108

88-
def get_response(self, response: Union[bytes, Dict]) -> Any:
109+
def get_response(
110+
self,
111+
response: Union[bytes, Dict],
112+
request_params: Optional[RequestParameters] = None,
113+
) -> Any:
114+
response_dict = _as_dict(response)
115+
116+
request_id = response_dict.get("request_id")
117+
if not request_id:
118+
raise ValueError("No request ID found in the response")
119+
if request_params is None:
120+
raise ValueError(
121+
"A `RequestParameters` object should be provided to get text-to-video responses with Fal AI."
122+
)
123+
124+
# extract the base url and query params
125+
base_url = request_params.url.split("?")[0] # or parsed.scheme + "://" + parsed.netloc + parsed.path ?
126+
query = "?_subdomain=queue" if request_params.url.endswith("_subdomain=queue") else ""
127+
128+
status_url = f"{base_url}/requests/{request_id}/status{query}"
129+
result_url = f"{base_url}/requests/{request_id}{query}"
130+
131+
status = response_dict.get("status")
132+
logger.info("Generating the video.. this can take several minutes.")
133+
while status != "COMPLETED":
134+
time.sleep(_POLLING_INTERVAL)
135+
status_response = get_session().get(status_url, headers=request_params.headers)
136+
hf_raise_for_status(status_response)
137+
status = status_response.json().get("status")
138+
139+
response = get_session().get(result_url, headers=request_params.headers).json()
89140
url = _as_dict(response)["video"]["url"]
90141
return get_session().get(url).content

src/huggingface_hub/inference/_providers/fireworks_ai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ class FireworksAIConversationalTask(BaseConversationalTask):
55
def __init__(self):
66
super().__init__(provider="fireworks-ai", base_url="https://api.fireworks.ai")
77

8-
def _prepare_route(self, mapped_model: str) -> str:
8+
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
99
return "/inference/v1/chat/completions"

src/huggingface_hub/inference/_providers/hyperbolic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import base64
22
from typing import Any, Dict, Optional, Union
33

4-
from huggingface_hub.inference._common import _as_dict
4+
from huggingface_hub.inference._common import RequestParameters, _as_dict
55
from huggingface_hub.inference._providers._common import BaseConversationalTask, TaskProviderHelper, filter_none
66

77

88
class HyperbolicTextToImageTask(TaskProviderHelper):
99
def __init__(self):
1010
super().__init__(provider="hyperbolic", base_url="https://api.hyperbolic.xyz", task="text-to-image")
1111

12-
def _prepare_route(self, mapped_model: str) -> str:
12+
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
1313
return "/v1/images/generations"
1414

1515
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
@@ -25,7 +25,7 @@ def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model:
2525
parameters["height"] = 512
2626
return {"prompt": inputs, "model_name": mapped_model, **parameters}
2727

28-
def get_response(self, response: Union[bytes, Dict]) -> Any:
28+
def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
2929
response_dict = _as_dict(response)
3030
return base64.b64decode(response_dict["images"][0]["image"])
3131

src/huggingface_hub/inference/_providers/nebius.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import base64
22
from typing import Any, Dict, Optional, Union
33

4-
from huggingface_hub.inference._common import _as_dict
4+
from huggingface_hub.inference._common import RequestParameters, _as_dict
55
from huggingface_hub.inference._providers._common import (
66
BaseConversationalTask,
77
BaseTextGenerationTask,
@@ -24,7 +24,7 @@ class NebiusTextToImageTask(TaskProviderHelper):
2424
def __init__(self):
2525
super().__init__(task="text-to-image", provider="nebius", base_url="https://api.studio.nebius.ai")
2626

27-
def _prepare_route(self, mapped_model: str) -> str:
27+
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
2828
return "/v1/images/generations"
2929

3030
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
@@ -36,6 +36,6 @@ def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model:
3636

3737
return {"prompt": inputs, **parameters, "model": mapped_model}
3838

39-
def get_response(self, response: Union[bytes, Dict]) -> Any:
39+
def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
4040
response_dict = _as_dict(response)
4141
return base64.b64decode(response_dict["data"][0]["b64_json"])

src/huggingface_hub/inference/_providers/novita.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, Dict, Optional, Union
22

3-
from huggingface_hub.inference._common import _as_dict
3+
from huggingface_hub.inference._common import RequestParameters, _as_dict
44
from huggingface_hub.inference._providers._common import (
55
BaseConversationalTask,
66
BaseTextGenerationTask,
@@ -18,7 +18,7 @@ class NovitaTextGenerationTask(BaseTextGenerationTask):
1818
def __init__(self):
1919
super().__init__(provider=_PROVIDER, base_url=_BASE_URL)
2020

21-
def _prepare_route(self, mapped_model: str) -> str:
21+
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
2222
# there is no v1/ route for novita
2323
return "/v3/openai/completions"
2424

@@ -27,7 +27,7 @@ class NovitaConversationalTask(BaseConversationalTask):
2727
def __init__(self):
2828
super().__init__(provider=_PROVIDER, base_url=_BASE_URL)
2929

30-
def _prepare_route(self, mapped_model: str) -> str:
30+
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
3131
# there is no v1/ route for novita
3232
return "/v3/openai/chat/completions"
3333

@@ -36,13 +36,13 @@ class NovitaTextToVideoTask(TaskProviderHelper):
3636
def __init__(self):
3737
super().__init__(provider=_PROVIDER, base_url=_BASE_URL, task="text-to-video")
3838

39-
def _prepare_route(self, mapped_model: str) -> str:
39+
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
4040
return f"/v3/hf/{mapped_model}"
4141

4242
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
4343
return {"prompt": inputs, **filter_none(parameters)}
4444

45-
def get_response(self, response: Union[bytes, Dict]) -> Any:
45+
def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
4646
response_dict = _as_dict(response)
4747
if not (
4848
isinstance(response_dict, dict)

0 commit comments

Comments
 (0)