Skip to content

Commit 73b8e36

Browse files
[Inference Providers] add image-to-image support for fal.ai provider (#3187)
* add image-to-image support for fal-ai
* add tests
* fix
* use helper
1 parent 338a46b commit 73b8e36

File tree

4 files changed

+156
-43
lines changed

4 files changed

+156
-43
lines changed

docs/source/en/guides/inference.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ For more details, refer to the [Inference Providers pricing documentation](https
218218
| [`~InferenceClient.fill_mask`] |||||||||||||||
219219
| [`~InferenceClient.image_classification`] |||||||||||||||
220220
| [`~InferenceClient.image_segmentation`] |||||||||||||||
221-
| [`~InferenceClient.image_to_image`] |||| |||||||||||
221+
| [`~InferenceClient.image_to_image`] |||| |||||||||||
222222
| [`~InferenceClient.image_to_text`] |||||||||||||||
223223
| [`~InferenceClient.object_detection`] ||||||||||||||| ||
224224
| [`~InferenceClient.question_answering`] |||||||||||||||

src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .cohere import CohereConversationalTask
1313
from .fal_ai import (
1414
FalAIAutomaticSpeechRecognitionTask,
15+
FalAIImageToImageTask,
1516
FalAITextToImageTask,
1617
FalAITextToSpeechTask,
1718
FalAITextToVideoTask,
@@ -78,6 +79,7 @@
7879
"text-to-image": FalAITextToImageTask(),
7980
"text-to-speech": FalAITextToSpeechTask(),
8081
"text-to-video": FalAITextToVideoTask(),
82+
"image-to-image": FalAIImageToImageTask(),
8183
},
8284
"featherless-ai": {
8385
"conversational": FeatherlessConversationalTask(),

src/huggingface_hub/inference/_providers/fal_ai.py

Lines changed: 85 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from huggingface_hub import constants
88
from huggingface_hub.hf_api import InferenceProviderMapping
9-
from huggingface_hub.inference._common import RequestParameters, _as_dict
9+
from huggingface_hub.inference._common import RequestParameters, _as_dict, _as_url
1010
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
1111
from huggingface_hub.utils import get_session, hf_raise_for_status
1212
from huggingface_hub.utils.logging import get_logger
@@ -32,6 +32,60 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str:
3232
return f"/{mapped_model}"
3333

3434

35+
class FalAIQueueTask(TaskProviderHelper, ABC):
36+
def __init__(self, task: str):
37+
super().__init__(provider="fal-ai", base_url="https://queue.fal.run", task=task)
38+
39+
def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
40+
headers = super()._prepare_headers(headers, api_key)
41+
if not api_key.startswith("hf_"):
42+
headers["authorization"] = f"Key {api_key}"
43+
return headers
44+
45+
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
46+
if api_key.startswith("hf_"):
47+
# Use the queue subdomain for HF routing
48+
return f"/{mapped_model}?_subdomain=queue"
49+
return f"/{mapped_model}"
50+
51+
def get_response(
52+
self,
53+
response: Union[bytes, Dict],
54+
request_params: Optional[RequestParameters] = None,
55+
) -> Any:
56+
response_dict = _as_dict(response)
57+
58+
request_id = response_dict.get("request_id")
59+
if not request_id:
60+
raise ValueError("No request ID found in the response")
61+
if request_params is None:
62+
raise ValueError(
63+
f"A `RequestParameters` object should be provided to get {self.task} responses with Fal AI."
64+
)
65+
66+
# extract the base url and query params
67+
parsed_url = urlparse(request_params.url)
68+
# a bit hacky way to concatenate the provider name without parsing `parsed_url.path`
69+
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{'/fal-ai' if parsed_url.netloc == 'router.huggingface.co' else ''}"
70+
query_param = f"?{parsed_url.query}" if parsed_url.query else ""
71+
72+
# extracting the provider model id for status and result urls
73+
# from the response as it might be different from the mapped model in `request_params.url`
74+
model_id = urlparse(response_dict.get("response_url")).path
75+
status_url = f"{base_url}{str(model_id)}/status{query_param}"
76+
result_url = f"{base_url}{str(model_id)}{query_param}"
77+
78+
status = response_dict.get("status")
79+
logger.info("Generating the output.. this can take several minutes.")
80+
while status != "COMPLETED":
81+
time.sleep(_POLLING_INTERVAL)
82+
status_response = get_session().get(status_url, headers=request_params.headers)
83+
hf_raise_for_status(status_response)
84+
status = status_response.json().get("status")
85+
86+
return get_session().get(result_url, headers=request_params.headers).json()
87+
88+
3589
class FalAIAutomaticSpeechRecognitionTask(FalAITask):
3690
def __init__(self):
3791
super().__init__("automatic-speech-recognition")
@@ -110,23 +164,10 @@ def get_response(self, response: Union[bytes, Dict], request_params: Optional[Re
110164
return get_session().get(url).content
111165

112166

113-
class FalAITextToVideoTask(FalAITask):
167+
class FalAITextToVideoTask(FalAIQueueTask):
114168
def __init__(self):
115169
super().__init__("text-to-video")
116170

117-
def _prepare_base_url(self, api_key: str) -> str:
118-
if api_key.startswith("hf_"):
119-
return super()._prepare_base_url(api_key)
120-
else:
121-
logger.info(f"Calling '{self.provider}' provider directly.")
122-
return "https://queue.fal.run"
123-
124-
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
125-
if api_key.startswith("hf_"):
126-
# Use the queue subdomain for HF routing
127-
return f"/{mapped_model}?_subdomain=queue"
128-
return f"/{mapped_model}"
129-
130171
def _prepare_payload_as_dict(
131172
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
132173
) -> Optional[Dict]:
@@ -137,36 +178,38 @@ def get_response(
137178
response: Union[bytes, Dict],
138179
request_params: Optional[RequestParameters] = None,
139180
) -> Any:
140-
response_dict = _as_dict(response)
181+
output = super().get_response(response, request_params)
182+
url = _as_dict(output)["video"]["url"]
183+
return get_session().get(url).content
141184

142-
request_id = response_dict.get("request_id")
143-
if not request_id:
144-
raise ValueError("No request ID found in the response")
145-
if request_params is None:
146-
raise ValueError(
147-
"A `RequestParameters` object should be provided to get text-to-video responses with Fal AI."
148-
)
149185

150-
# extract the base url and query params
151-
parsed_url = urlparse(request_params.url)
152-
# a bit hacky way to concatenate the provider name without parsing `parsed_url.path`
153-
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{'/fal-ai' if parsed_url.netloc == 'router.huggingface.co' else ''}"
154-
query_param = f"?{parsed_url.query}" if parsed_url.query else ""
186+
class FalAIImageToImageTask(FalAIQueueTask):
187+
def __init__(self):
188+
super().__init__("image-to-image")
155189

156-
# extracting the provider model id for status and result urls
157-
# from the response as it might be different from the mapped model in `request_params.url`
158-
model_id = urlparse(response_dict.get("response_url")).path
159-
status_url = f"{base_url}{str(model_id)}/status{query_param}"
160-
result_url = f"{base_url}{str(model_id)}{query_param}"
190+
def _prepare_payload_as_dict(
191+
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
192+
) -> Optional[Dict]:
193+
image_url = _as_url(inputs, default_mime_type="image/jpeg")
194+
payload: Dict[str, Any] = {
195+
"image_url": image_url,
196+
**filter_none(parameters),
197+
}
198+
if provider_mapping_info.adapter_weights_path is not None:
199+
lora_path = constants.HUGGINGFACE_CO_URL_TEMPLATE.format(
200+
repo_id=provider_mapping_info.hf_model_id,
201+
revision="main",
202+
filename=provider_mapping_info.adapter_weights_path,
203+
)
204+
payload["loras"] = [{"path": lora_path, "scale": 1}]
161205

162-
status = response_dict.get("status")
163-
logger.info("Generating the video.. this can take several minutes.")
164-
while status != "COMPLETED":
165-
time.sleep(_POLLING_INTERVAL)
166-
status_response = get_session().get(status_url, headers=request_params.headers)
167-
hf_raise_for_status(status_response)
168-
status = status_response.json().get("status")
206+
return payload
169207

170-
response = get_session().get(result_url, headers=request_params.headers).json()
171-
url = _as_dict(response)["video"]["url"]
208+
def get_response(
209+
self,
210+
response: Union[bytes, Dict],
211+
request_params: Optional[RequestParameters] = None,
212+
) -> Any:
213+
output = super().get_response(response, request_params)
214+
url = _as_dict(output)["images"][0]["url"]
172215
return get_session().get(url).content

tests/test_inference_providers.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from huggingface_hub.inference._providers.fal_ai import (
2222
_POLLING_INTERVAL,
2323
FalAIAutomaticSpeechRecognitionTask,
24+
FalAIImageToImageTask,
2425
FalAITextToImageTask,
2526
FalAITextToSpeechTask,
2627
FalAITextToVideoTask,
@@ -408,6 +409,73 @@ def test_text_to_video_response(self, mocker):
408409
mock_sleep.assert_called_once_with(_POLLING_INTERVAL)
409410
assert response == b"video_content"
410411

412+
def test_image_to_image_payload(self):
413+
helper = FalAIImageToImageTask()
414+
mapping_info = InferenceProviderMapping(
415+
provider="fal-ai",
416+
hf_model_id="stabilityai/sdxl-refiner-1.0",
417+
providerId="fal-ai/sdxl-refiner",
418+
task="image-to-image",
419+
status="live",
420+
)
421+
payload = helper._prepare_payload_as_dict("https://example.com/image.png", {"prompt": "a cat"}, mapping_info)
422+
assert payload == {"image_url": "https://example.com/image.png", "prompt": "a cat"}
423+
424+
payload = helper._prepare_payload_as_dict(
425+
b"dummy_image_data", {"prompt": "replace the cat with a dog"}, mapping_info
426+
)
427+
assert payload == {
428+
"image_url": f"data:image/jpeg;base64,{base64.b64encode(b'dummy_image_data').decode()}",
429+
"prompt": "replace the cat with a dog",
430+
}
431+
432+
def test_image_to_image_response(self, mocker):
433+
helper = FalAIImageToImageTask()
434+
mock_session = mocker.patch("huggingface_hub.inference._providers.fal_ai.get_session")
435+
mock_sleep = mocker.patch("huggingface_hub.inference._providers.fal_ai.time.sleep")
436+
mock_session.return_value.get.side_effect = [
437+
# First call: status
438+
mocker.Mock(json=lambda: {"status": "COMPLETED"}, headers={"Content-Type": "application/json"}),
439+
# Second call: get result
440+
mocker.Mock(json=lambda: {"images": [{"url": "image_url"}]}, headers={"Content-Type": "application/json"}),
441+
# Third call: get image content
442+
mocker.Mock(content=b"image_content"),
443+
]
444+
api_key = helper._prepare_api_key("hf_token")
445+
headers = helper._prepare_headers({}, api_key)
446+
url = helper._prepare_url(api_key, "username/repo_name")
447+
448+
request_params = RequestParameters(
449+
url=url,
450+
headers=headers,
451+
task="image-to-image",
452+
model="username/repo_name",
453+
data=None,
454+
json=None,
455+
)
456+
response = helper.get_response(
457+
b'{"request_id": "test_request_id", "status": "PROCESSING", "response_url": "https://queue.fal.run/username_provider/repo_name_provider/requests/test_request_id", "status_url": "https://queue.fal.run/username_provider/repo_name_provider/requests/test_request_id/status"}',
458+
request_params,
459+
)
460+
461+
# Verify the correct URLs were called
462+
assert mock_session.return_value.get.call_count == 3
463+
mock_session.return_value.get.assert_has_calls(
464+
[
465+
mocker.call(
466+
"https://router.huggingface.co/fal-ai/username_provider/repo_name_provider/requests/test_request_id/status?_subdomain=queue",
467+
headers=request_params.headers,
468+
),
469+
mocker.call(
470+
"https://router.huggingface.co/fal-ai/username_provider/repo_name_provider/requests/test_request_id?_subdomain=queue",
471+
headers=request_params.headers,
472+
),
473+
mocker.call("image_url"),
474+
]
475+
)
476+
mock_sleep.assert_called_once_with(_POLLING_INTERVAL)
477+
assert response == b"image_content"
478+
411479

412480
class TestFeatherlessAIProvider:
413481
def test_prepare_route_chat_completionurl(self):

0 commit comments

Comments (0)