From d60b0e93849e6c7cb6b7035131874f11dfae5d0f Mon Sep 17 00:00:00 2001
From: Celina Hanouti
Date: Mon, 30 Jun 2025 16:31:11 +0200
Subject: [PATCH 1/4] add image-to-image support for fal-ai

---
 docs/source/en/guides/inference.md             |   2 +-
 src/huggingface_hub/inference/_client.py       |   1 +
 .../inference/_generated/_async_client.py      |   1 +
 .../inference/_providers/__init__.py           |   2 +
 .../inference/_providers/fal_ai.py             | 134 ++++++++++++------
 5 files changed, 98 insertions(+), 42 deletions(-)

diff --git a/docs/source/en/guides/inference.md b/docs/source/en/guides/inference.md
index 36a93f049e..e9975a71b6 100644
--- a/docs/source/en/guides/inference.md
+++ b/docs/source/en/guides/inference.md
@@ -218,7 +218,7 @@ For more details, refer to the [Inference Providers pricing documentation](https
 | [`~InferenceClient.fill_mask`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | [`~InferenceClient.image_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | [`~InferenceClient.image_segmentation`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
-| [`~InferenceClient.image_to_image`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| [`~InferenceClient.image_to_image`] | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | [`~InferenceClient.image_to_text`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | [`~InferenceClient.object_detection`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | [`~InferenceClient.question_answering`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py
index 3439dafd89..01360d6f59 100644
--- a/src/huggingface_hub/inference/_client.py
+++ b/src/huggingface_hub/inference/_client.py
@@ -1338,6 +1338,7 @@ def image_to_image(
             api_key=self.token,
         )
         response = self._inner_post(request_parameters)
+        response = provider_helper.get_response(response, request_parameters)
         return _bytes_to_image(response)
 
     def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py
index 162d89369f..2ca5632069 100644
--- a/src/huggingface_hub/inference/_generated/_async_client.py
+++ b/src/huggingface_hub/inference/_generated/_async_client.py
@@ -1384,6 +1384,7 @@ async def image_to_image(
             api_key=self.token,
         )
         response = await self._inner_post(request_parameters)
+        response = provider_helper.get_response(response, request_parameters)
         return _bytes_to_image(response)
 
     async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
diff --git a/src/huggingface_hub/inference/_providers/__init__.py b/src/huggingface_hub/inference/_providers/__init__.py
index 8d73b837fc..f453d857c2 100644
--- a/src/huggingface_hub/inference/_providers/__init__.py
+++ b/src/huggingface_hub/inference/_providers/__init__.py
@@ -12,6 +12,7 @@
 from .cohere import CohereConversationalTask
 from .fal_ai import (
     FalAIAutomaticSpeechRecognitionTask,
+    FalAIImageToImageTask,
     FalAITextToImageTask,
     FalAITextToSpeechTask,
     FalAITextToVideoTask,
@@ -78,6 +79,7 @@
         "text-to-image": FalAITextToImageTask(),
         "text-to-speech": FalAITextToSpeechTask(),
         "text-to-video": FalAITextToVideoTask(),
+        "image-to-image": FalAIImageToImageTask(),
     },
    "featherless-ai": {
         "conversational": FeatherlessConversationalTask(),
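The client-side change is deliberately tiny: both the sync and async `image_to_image` methods now pass the raw HTTP response through `provider_helper.get_response()`, which gives queue-based providers such as fal-ai a hook to poll for and download the final image. A minimal sketch of the call this enables (the model ID is borrowed from the tests added later in this series and is purely illustrative):

```python
from huggingface_hub import InferenceClient

client = InferenceClient(provider="fal-ai")  # token picked up from HF_TOKEN / cached login
image = client.image_to_image(
    "cat.png",  # local file, raw bytes, or an http(s) URL
    prompt="Turn the cat into a tiger",
    model="stabilityai/sdxl-refiner-1.0",  # illustrative model id
)
image.save("tiger.png")  # the client returns a PIL image
```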
diff --git a/src/huggingface_hub/inference/_providers/fal_ai.py b/src/huggingface_hub/inference/_providers/fal_ai.py
index 8dd463b6b1..bc97dfe465 100644
--- a/src/huggingface_hub/inference/_providers/fal_ai.py
+++ b/src/huggingface_hub/inference/_providers/fal_ai.py
@@ -32,6 +32,60 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str:
         return f"/{mapped_model}"
 
 
+class FalAIQueueTask(TaskProviderHelper, ABC):
+    def __init__(self, task: str):
+        super().__init__(provider="fal-ai", base_url="https://queue.fal.run", task=task)
+
+    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
+        headers = super()._prepare_headers(headers, api_key)
+        if not api_key.startswith("hf_"):
+            headers["authorization"] = f"Key {api_key}"
+        return headers
+
+    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
+        if api_key.startswith("hf_"):
+            # Use the queue subdomain for HF routing
+            return f"/{mapped_model}?_subdomain=queue"
+        return f"/{mapped_model}"
+
+    def get_response(
+        self,
+        response: Union[bytes, Dict],
+        request_params: Optional[RequestParameters] = None,
+    ) -> Any:
+        response_dict = _as_dict(response)
+
+        request_id = response_dict.get("request_id")
+        if not request_id:
+            raise ValueError("No request ID found in the response")
+        if request_params is None:
+            raise ValueError(
+                f"A `RequestParameters` object should be provided to get {self.task} responses with Fal AI."
+            )
+
+        # extract the base url and query params
+        parsed_url = urlparse(request_params.url)
+        # a bit hacky way to concatenate the provider name without parsing `parsed_url.path`
+        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{'/fal-ai' if parsed_url.netloc == 'router.huggingface.co' else ''}"
+        query_param = f"?{parsed_url.query}" if parsed_url.query else ""
+
+        # extracting the provider model id for status and result urls
+        # from the response as it might be different from the mapped model in `request_params.url`
+        model_id = urlparse(response_dict.get("response_url")).path
+        status_url = f"{base_url}{str(model_id)}/status{query_param}"
+        result_url = f"{base_url}{str(model_id)}{query_param}"
+
+        status = response_dict.get("status")
+        logger.info("Generating the output... this can take several minutes.")
+        while status != "COMPLETED":
+            time.sleep(_POLLING_INTERVAL)
+            status_response = get_session().get(status_url, headers=request_params.headers)
+            hf_raise_for_status(status_response)
+            status = status_response.json().get("status")
+
+        return get_session().get(result_url, headers=request_params.headers).json()
+
+
 class FalAIAutomaticSpeechRecognitionTask(FalAITask):
     def __init__(self):
         super().__init__("automatic-speech-recognition")
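Note how `get_response` rebuilds the status and result URLs from the queue response rather than from the original request, because the provider-side model ID may differ from the mapped model in `request_params.url`. A worked example of that derivation, runnable standalone with illustrative IDs:

```python
from urllib.parse import urlparse

# Assumed example values, mirroring the hunk above:
request_url = "https://router.huggingface.co/fal-ai/fal-ai/flux-dev?_subdomain=queue"
response_url = "https://queue.fal.run/fal-ai/flux-dev/requests/abc123"

parsed = urlparse(request_url)
base_url = f"{parsed.scheme}://{parsed.netloc}" + (
    "/fal-ai" if parsed.netloc == "router.huggingface.co" else ""
)
query = f"?{parsed.query}" if parsed.query else ""
model_id = urlparse(response_url).path  # "/fal-ai/flux-dev/requests/abc123"

status_url = f"{base_url}{model_id}/status{query}"
result_url = f"{base_url}{model_id}{query}"
assert status_url == (
    "https://router.huggingface.co/fal-ai/fal-ai/flux-dev/requests/abc123/status?_subdomain=queue"
)
```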
@@ -110,23 +164,10 @@ def get_response(self, response: Union[bytes, Dict], request_params: Optional[Re
         return get_session().get(url).content
 
 
-class FalAITextToVideoTask(FalAITask):
+class FalAITextToVideoTask(FalAIQueueTask):
     def __init__(self):
         super().__init__("text-to-video")
 
-    def _prepare_base_url(self, api_key: str) -> str:
-        if api_key.startswith("hf_"):
-            return super()._prepare_base_url(api_key)
-        else:
-            logger.info(f"Calling '{self.provider}' provider directly.")
-            return "https://queue.fal.run"
-
-    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
-        if api_key.startswith("hf_"):
-            # Use the queue subdomain for HF routing
-            return f"/{mapped_model}?_subdomain=queue"
-        return f"/{mapped_model}"
-
     def _prepare_payload_as_dict(
         self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
     ) -> Optional[Dict]:
@@ -137,36 +178,47 @@ def get_response(
         response: Union[bytes, Dict],
         request_params: Optional[RequestParameters] = None,
     ) -> Any:
-        response_dict = _as_dict(response)
+        output = super().get_response(response, request_params)
+        url = _as_dict(output)["video"]["url"]
+        return get_session().get(url).content
 
-        request_id = response_dict.get("request_id")
-        if not request_id:
-            raise ValueError("No request ID found in the response")
-        if request_params is None:
-            raise ValueError(
-                "A `RequestParameters` object should be provided to get text-to-video responses with Fal AI."
-            )
-
-        # extract the base url and query params
-        parsed_url = urlparse(request_params.url)
-        # a bit hacky way to concatenate the provider name without parsing `parsed_url.path`
-        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{'/fal-ai' if parsed_url.netloc == 'router.huggingface.co' else ''}"
-        query_param = f"?{parsed_url.query}" if parsed_url.query else ""
+
+class FalAIImageToImageTask(FalAIQueueTask):
+    def __init__(self):
+        super().__init__("image-to-image")
 
-        # extracting the provider model id for status and result urls
-        # from the response as it might be different from the mapped model in `request_params.url`
-        model_id = urlparse(response_dict.get("response_url")).path
-        status_url = f"{base_url}{str(model_id)}/status{query_param}"
-        result_url = f"{base_url}{str(model_id)}{query_param}"
+    def _prepare_payload_as_dict(
+        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[Dict]:
+        if isinstance(inputs, str) and inputs.startswith(("http://", "https://")):
+            image_url = inputs
+        else:
+            if isinstance(inputs, str):
+                with open(inputs, "rb") as f:
+                    inputs = f.read()
 
-        status = response_dict.get("status")
-        logger.info("Generating the video.. this can take several minutes.")
-        while status != "COMPLETED":
-            time.sleep(_POLLING_INTERVAL)
-            status_response = get_session().get(status_url, headers=request_params.headers)
-            hf_raise_for_status(status_response)
-            status = status_response.json().get("status")
+            image_b64 = base64.b64encode(inputs).decode()
+            content_type = "image/png"
+            image_url = f"data:{content_type};base64,{image_b64}"
+        payload: Dict[str, Any] = {
+            "image_url": image_url,
+            **filter_none(parameters),
+        }
+        if provider_mapping_info.adapter_weights_path is not None:
+            lora_path = constants.HUGGINGFACE_CO_URL_TEMPLATE.format(
+                repo_id=provider_mapping_info.hf_model_id,
+                revision="main",
+                filename=provider_mapping_info.adapter_weights_path,
+            )
+            payload["loras"] = [{"path": lora_path, "scale": 1}]
 
-        response = get_session().get(result_url, headers=request_params.headers).json()
-        url = _as_dict(response)["video"]["url"]
+        return payload
+
+    def get_response(
+        self,
+        response: Union[bytes, Dict],
+        request_params: Optional[RequestParameters] = None,
+    ) -> Any:
+        output = super().get_response(response, request_params)
+        url = _as_dict(output)["images"][0]["url"]
         return get_session().get(url).content
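With the polling loop hoisted into `FalAIQueueTask`, the two media tasks now differ only in how they unwrap the final result JSON. The shapes, as used by the code above and mocked in the tests that follow (the hosts here are made up):

```python
# Illustrative queue results, per the code above:
video_result = {"video": {"url": "https://fal.example/output.mp4"}}
image_result = {"images": [{"url": "https://fal.example/output.png"}]}

video_url = video_result["video"]["url"]      # FalAITextToVideoTask
image_url = image_result["images"][0]["url"]  # FalAIImageToImageTask
```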
From b2202c7df4c355b31dc10689d20a2b1a0993cfb0 Mon Sep 17 00:00:00 2001
From: Celina Hanouti
Date: Mon, 30 Jun 2025 16:32:09 +0200
Subject: [PATCH 2/4] add tests

---
 tests/test_inference_providers.py | 68 +++++++++++++++++++++++++++++++
 1 file changed, 68 insertions(+)

diff --git a/tests/test_inference_providers.py b/tests/test_inference_providers.py
index 3b3a8f671c..b9582d5107 100644
--- a/tests/test_inference_providers.py
+++ b/tests/test_inference_providers.py
@@ -21,6 +21,7 @@
 from huggingface_hub.inference._providers.fal_ai import (
     _POLLING_INTERVAL,
     FalAIAutomaticSpeechRecognitionTask,
+    FalAIImageToImageTask,
     FalAITextToImageTask,
     FalAITextToSpeechTask,
     FalAITextToVideoTask,
@@ -404,6 +405,73 @@ def test_text_to_video_response(self, mocker):
         mock_sleep.assert_called_once_with(_POLLING_INTERVAL)
         assert response == b"video_content"
 
+    def test_image_to_image_payload(self):
+        helper = FalAIImageToImageTask()
+        mapping_info = InferenceProviderMapping(
+            provider="fal-ai",
+            hf_model_id="stabilityai/sdxl-refiner-1.0",
+            providerId="fal-ai/sdxl-refiner",
+            task="image-to-image",
+            status="live",
+        )
+        payload = helper._prepare_payload_as_dict("https://example.com/image.png", {"prompt": "a cat"}, mapping_info)
+        assert payload == {"image_url": "https://example.com/image.png", "prompt": "a cat"}
+
+        payload = helper._prepare_payload_as_dict(
+            b"dummy_image_data", {"prompt": "replace the cat with a dog"}, mapping_info
+        )
+        assert payload == {
+            "image_url": f"data:image/png;base64,{base64.b64encode(b'dummy_image_data').decode()}",
+            "prompt": "replace the cat with a dog",
+        }
+
+    def test_image_to_image_response(self, mocker):
+        helper = FalAIImageToImageTask()
+        mock_session = mocker.patch("huggingface_hub.inference._providers.fal_ai.get_session")
+        mock_sleep = mocker.patch("huggingface_hub.inference._providers.fal_ai.time.sleep")
+        mock_session.return_value.get.side_effect = [
+            # First call: status
+            mocker.Mock(json=lambda: {"status": "COMPLETED"}, headers={"Content-Type": "application/json"}),
+            # Second call: get result
+            mocker.Mock(json=lambda: {"images": [{"url": "image_url"}]}, headers={"Content-Type": "application/json"}),
+            # Third call: get image content
+            mocker.Mock(content=b"image_content"),
+        ]
+        api_key = helper._prepare_api_key("hf_token")
+        headers = helper._prepare_headers({}, api_key)
+        url = helper._prepare_url(api_key, "username/repo_name")
+
+        request_params = RequestParameters(
+            url=url,
+            headers=headers,
+            task="image-to-image",
+            model="username/repo_name",
+            data=None,
+            json=None,
+        )
+        response = helper.get_response(
+            b'{"request_id": "test_request_id", "status": "PROCESSING", "response_url": "https://queue.fal.run/username_provider/repo_name_provider/requests/test_request_id", "status_url": "https://queue.fal.run/username_provider/repo_name_provider/requests/test_request_id/status"}',
+            request_params,
+        )
+
+        # Verify the correct URLs were called
+        assert mock_session.return_value.get.call_count == 3
+        mock_session.return_value.get.assert_has_calls(
+            [
+                mocker.call(
+                    "https://router.huggingface.co/fal-ai/username_provider/repo_name_provider/requests/test_request_id/status?_subdomain=queue",
+                    headers=request_params.headers,
+                ),
+                mocker.call(
+                    "https://router.huggingface.co/fal-ai/username_provider/repo_name_provider/requests/test_request_id?_subdomain=queue",
+                    headers=request_params.headers,
+                ),
+                mocker.call("image_url"),
+            ]
+        )
+        mock_sleep.assert_called_once_with(_POLLING_INTERVAL)
+        assert response == b"image_content"
+
 
 class TestFeatherlessAIProvider:
     def test_prepare_route_chat_completionurl(self):
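The second assertion in `test_image_to_image_payload` pins down the bytes fallback: anything that is not an http(s) URL gets inlined as a base64 data URL, with the MIME type hard-coded to `image/png` at this point in the series (patch 4 switches the default to `image/jpeg`). Spelled out:

```python
import base64

raw = b"dummy_image_data"
data_url = f"data:image/png;base64,{base64.b64encode(raw).decode()}"
print(data_url)  # data:image/png;base64,ZHVtbXlfaW1hZ2VfZGF0YQ==
```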
From 3a06297fa4516043dff84e5058f38308d4e041c2 Mon Sep 17 00:00:00 2001
From: Celina Hanouti
Date: Mon, 30 Jun 2025 17:12:25 +0200
Subject: [PATCH 3/4] fix

---
 .../inference/_providers/fal_ai.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/src/huggingface_hub/inference/_providers/fal_ai.py b/src/huggingface_hub/inference/_providers/fal_ai.py
index bc97dfe465..5b2eaaf2d5 100644
--- a/src/huggingface_hub/inference/_providers/fal_ai.py
+++ b/src/huggingface_hub/inference/_providers/fal_ai.py
@@ -1,7 +1,8 @@
 import base64
 import time
 from abc import ABC
-from typing import Any, Dict, Optional, Union
+from pathlib import Path
+from typing import Any, BinaryIO, Dict, Optional, Union
 from urllib.parse import urlparse
 
 from huggingface_hub import constants
@@ -193,11 +194,18 @@ def _prepare_payload_as_dict(
         self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
     ) -> Optional[Dict]:
         if isinstance(inputs, str) and inputs.startswith(("http://", "https://")):
             image_url = inputs
         else:
-            if isinstance(inputs, str):
+            image_bytes: bytes
+            if isinstance(inputs, (str, Path)):
                 with open(inputs, "rb") as f:
-                    inputs = f.read()
-
-            image_b64 = base64.b64encode(inputs).decode()
+                    image_bytes = f.read()
+            elif isinstance(inputs, bytes):
+                image_bytes = inputs
+            elif isinstance(inputs, BinaryIO):
+                image_bytes = inputs.read()
+            else:
+                raise TypeError(f"Unsupported input type for image: {type(inputs)}")
+
+            image_b64 = base64.b64encode(image_bytes).decode()
             content_type = "image/png"
             image_url = f"data:{content_type};base64,{image_b64}"
         payload: Dict[str, Any] = {
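One caveat with the new `elif` branch: `typing.BinaryIO` is an annotation-time type, and real file objects such as `io.BufferedReader` do not inherit from it, so the `isinstance(inputs, BinaryIO)` check is unreliable at runtime. A duck-typed check is the usual workaround (sketch below); the next patch sidesteps the issue entirely by delegating to a shared helper.

```python
import io

def looks_like_binary_io(obj) -> bool:
    # duck-typed stand-in for isinstance(obj, typing.BinaryIO)
    return isinstance(obj, io.IOBase) or hasattr(obj, "read")

with open(__file__, "rb") as f:
    assert looks_like_binary_io(f)
```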
From 8379532982b34ff958e5676f685208fe28509dfb Mon Sep 17 00:00:00 2001
From: Celina Hanouti
Date: Wed, 2 Jul 2025 23:43:47 +0200
Subject: [PATCH 4/4] use helper

---
 .../inference/_providers/fal_ai.py | 23 +++----------------
 tests/test_inference_providers.py  |  2 +-
 2 files changed, 4 insertions(+), 21 deletions(-)

diff --git a/src/huggingface_hub/inference/_providers/fal_ai.py b/src/huggingface_hub/inference/_providers/fal_ai.py
index 5b2eaaf2d5..a37728be78 100644
--- a/src/huggingface_hub/inference/_providers/fal_ai.py
+++ b/src/huggingface_hub/inference/_providers/fal_ai.py
@@ -1,13 +1,12 @@
 import base64
 import time
 from abc import ABC
-from pathlib import Path
-from typing import Any, BinaryIO, Dict, Optional, Union
+from typing import Any, Dict, Optional, Union
 from urllib.parse import urlparse
 
 from huggingface_hub import constants
 from huggingface_hub.hf_api import InferenceProviderMapping
-from huggingface_hub.inference._common import RequestParameters, _as_dict
+from huggingface_hub.inference._common import RequestParameters, _as_dict, _as_url
 from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
 from huggingface_hub.utils import get_session, hf_raise_for_status
 from huggingface_hub.utils.logging import get_logger
@@ -191,23 +190,7 @@ def __init__(self):
     def _prepare_payload_as_dict(
         self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
     ) -> Optional[Dict]:
-        if isinstance(inputs, str) and inputs.startswith(("http://", "https://")):
-            image_url = inputs
-        else:
-            image_bytes: bytes
-            if isinstance(inputs, (str, Path)):
-                with open(inputs, "rb") as f:
-                    image_bytes = f.read()
-            elif isinstance(inputs, bytes):
-                image_bytes = inputs
-            elif isinstance(inputs, BinaryIO):
-                image_bytes = inputs.read()
-            else:
-                raise TypeError(f"Unsupported input type for image: {type(inputs)}")
-
-            image_b64 = base64.b64encode(image_bytes).decode()
-            content_type = "image/png"
-            image_url = f"data:{content_type};base64,{image_b64}"
+        image_url = _as_url(inputs, default_mime_type="image/jpeg")
         payload: Dict[str, Any] = {
             "image_url": image_url,
             **filter_none(parameters),
diff --git a/tests/test_inference_providers.py b/tests/test_inference_providers.py
index cb9ce2d00a..eea23ed5f4 100644
--- a/tests/test_inference_providers.py
+++ b/tests/test_inference_providers.py
@@ -425,7 +425,7 @@ def test_image_to_image_payload(self):
             b"dummy_image_data", {"prompt": "replace the cat with a dog"}, mapping_info
         )
         assert payload == {
-            "image_url": f"data:image/png;base64,{base64.b64encode(b'dummy_image_data').decode()}",
+            "image_url": f"data:image/jpeg;base64,{base64.b64encode(b'dummy_image_data').decode()}",
             "prompt": "replace the cat with a dog",
         }
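`_as_url` lives in `inference/_common.py` and its body is not part of this diff. Going by the call site and the updated test, its behavior is roughly the sketch below (an assumption, not the actual helper): pass URLs and data URLs through untouched, otherwise read the content and inline it as a base64 data URL with the given default MIME type.

```python
import base64
from pathlib import Path
from typing import Union

def as_url_sketch(content: Union[str, Path, bytes], default_mime_type: str) -> str:
    # Hypothetical approximation of huggingface_hub's private _as_url helper.
    if isinstance(content, str) and content.startswith(("http://", "https://", "data:")):
        return content
    data = Path(content).read_bytes() if isinstance(content, (str, Path)) else content
    return f"data:{default_mime_type};base64,{base64.b64encode(data).decode()}"

assert as_url_sketch(b"dummy_image_data", "image/jpeg") == (
    "data:image/jpeg;base64,ZHVtbXlfaW1hZ2VfZGF0YQ=="
)
```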