diff --git a/docs/source/en/guides/inference.md b/docs/source/en/guides/inference.md
index ff6e4349d9..708af3d1e4 100644
--- a/docs/source/en/guides/inference.md
+++ b/docs/source/en/guides/inference.md
@@ -218,7 +218,7 @@ For more details, refer to the [Inference Providers pricing documentation](https
 | [`~InferenceClient.fill_mask`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | [`~InferenceClient.image_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | [`~InferenceClient.image_segmentation`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
-| [`~InferenceClient.image_to_image`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
+| [`~InferenceClient.image_to_image`] | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
 | [`~InferenceClient.image_to_text`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | [`~InferenceClient.object_detection`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | [`~InferenceClient.question_answering`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
diff --git a/src/huggingface_hub/inference/_providers/__init__.py b/src/huggingface_hub/inference/_providers/__init__.py
index 3f549f4a62..405087d485 100644
--- a/src/huggingface_hub/inference/_providers/__init__.py
+++ b/src/huggingface_hub/inference/_providers/__init__.py
@@ -12,6 +12,7 @@
 from .cohere import CohereConversationalTask
 from .fal_ai import (
     FalAIAutomaticSpeechRecognitionTask,
+    FalAIImageToImageTask,
     FalAITextToImageTask,
     FalAITextToSpeechTask,
     FalAITextToVideoTask,
@@ -78,6 +79,7 @@
         "text-to-image": FalAITextToImageTask(),
         "text-to-speech": FalAITextToSpeechTask(),
         "text-to-video": FalAITextToVideoTask(),
+        "image-to-image": FalAIImageToImageTask(),
     },
     "featherless-ai": {
         "conversational": FeatherlessConversationalTask(),
diff --git a/src/huggingface_hub/inference/_providers/fal_ai.py b/src/huggingface_hub/inference/_providers/fal_ai.py
index 8dd463b6b1..a37728be78 100644
--- a/src/huggingface_hub/inference/_providers/fal_ai.py
+++ b/src/huggingface_hub/inference/_providers/fal_ai.py
@@ -6,7 +6,7 @@
 from huggingface_hub import constants
 from huggingface_hub.hf_api import InferenceProviderMapping
-from huggingface_hub.inference._common import RequestParameters, _as_dict
+from huggingface_hub.inference._common import RequestParameters, _as_dict, _as_url
 from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
 from huggingface_hub.utils import get_session, hf_raise_for_status
 from huggingface_hub.utils.logging import get_logger
@@ -32,6 +32,60 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str:
         return f"/{mapped_model}"
 
 
+class FalAIQueueTask(TaskProviderHelper, ABC):
+    def __init__(self, task: str):
+        super().__init__(provider="fal-ai", base_url="https://queue.fal.run", task=task)
+
+    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
+        headers = super()._prepare_headers(headers, api_key)
+        if not api_key.startswith("hf_"):
+            headers["authorization"] = f"Key {api_key}"
+        return headers
+
+    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
+        if api_key.startswith("hf_"):
+            # Use the queue subdomain for HF routing
+            return f"/{mapped_model}?_subdomain=queue"
+        return f"/{mapped_model}"
+
+    def get_response(
+        self,
+        response: Union[bytes, Dict],
+        request_params: Optional[RequestParameters] = None,
+    ) -> Any:
+        response_dict = _as_dict(response)
+
+        request_id = response_dict.get("request_id")
+        if not request_id:
+            raise ValueError("No request ID found in the response")
ValueError("No request ID found in the response") + if request_params is None: + raise ValueError( + f"A `RequestParameters` object should be provided to get {self.task} responses with Fal AI." + ) + + # extract the base url and query params + parsed_url = urlparse(request_params.url) + # a bit hacky way to concatenate the provider name without parsing `parsed_url.path` + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{'/fal-ai' if parsed_url.netloc == 'router.huggingface.co' else ''}" + query_param = f"?{parsed_url.query}" if parsed_url.query else "" + + # extracting the provider model id for status and result urls + # from the response as it might be different from the mapped model in `request_params.url` + model_id = urlparse(response_dict.get("response_url")).path + status_url = f"{base_url}{str(model_id)}/status{query_param}" + result_url = f"{base_url}{str(model_id)}{query_param}" + + status = response_dict.get("status") + logger.info("Generating the output.. this can take several minutes.") + while status != "COMPLETED": + time.sleep(_POLLING_INTERVAL) + status_response = get_session().get(status_url, headers=request_params.headers) + hf_raise_for_status(status_response) + status = status_response.json().get("status") + + return get_session().get(result_url, headers=request_params.headers).json() + + class FalAIAutomaticSpeechRecognitionTask(FalAITask): def __init__(self): super().__init__("automatic-speech-recognition") @@ -110,23 +164,10 @@ def get_response(self, response: Union[bytes, Dict], request_params: Optional[Re return get_session().get(url).content -class FalAITextToVideoTask(FalAITask): +class FalAITextToVideoTask(FalAIQueueTask): def __init__(self): super().__init__("text-to-video") - def _prepare_base_url(self, api_key: str) -> str: - if api_key.startswith("hf_"): - return super()._prepare_base_url(api_key) - else: - logger.info(f"Calling '{self.provider}' provider directly.") - return "https://queue.fal.run" - - def _prepare_route(self, mapped_model: str, api_key: str) -> str: - if api_key.startswith("hf_"): - # Use the queue subdomain for HF routing - return f"/{mapped_model}?_subdomain=queue" - return f"/{mapped_model}" - def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: @@ -137,36 +178,38 @@ def get_response( response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None, ) -> Any: - response_dict = _as_dict(response) + output = super().get_response(response, request_params) + url = _as_dict(output)["video"]["url"] + return get_session().get(url).content - request_id = response_dict.get("request_id") - if not request_id: - raise ValueError("No request ID found in the response") - if request_params is None: - raise ValueError( - "A `RequestParameters` object should be provided to get text-to-video responses with Fal AI." 
-            )
 
-        # extract the base url and query params
-        parsed_url = urlparse(request_params.url)
-        # a bit hacky way to concatenate the provider name without parsing `parsed_url.path`
-        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{'/fal-ai' if parsed_url.netloc == 'router.huggingface.co' else ''}"
-        query_param = f"?{parsed_url.query}" if parsed_url.query else ""
+
+class FalAIImageToImageTask(FalAIQueueTask):
+    def __init__(self):
+        super().__init__("image-to-image")
 
-        # extracting the provider model id for status and result urls
-        # from the response as it might be different from the mapped model in `request_params.url`
-        model_id = urlparse(response_dict.get("response_url")).path
-        status_url = f"{base_url}{str(model_id)}/status{query_param}"
-        result_url = f"{base_url}{str(model_id)}{query_param}"
+    def _prepare_payload_as_dict(
+        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[Dict]:
+        image_url = _as_url(inputs, default_mime_type="image/jpeg")
+        payload: Dict[str, Any] = {
+            "image_url": image_url,
+            **filter_none(parameters),
+        }
+        if provider_mapping_info.adapter_weights_path is not None:
+            lora_path = constants.HUGGINGFACE_CO_URL_TEMPLATE.format(
+                repo_id=provider_mapping_info.hf_model_id,
+                revision="main",
+                filename=provider_mapping_info.adapter_weights_path,
+            )
+            payload["loras"] = [{"path": lora_path, "scale": 1}]
 
-        status = response_dict.get("status")
-        logger.info("Generating the video.. this can take several minutes.")
-        while status != "COMPLETED":
-            time.sleep(_POLLING_INTERVAL)
-            status_response = get_session().get(status_url, headers=request_params.headers)
-            hf_raise_for_status(status_response)
-            status = status_response.json().get("status")
+        return payload
 
-        response = get_session().get(result_url, headers=request_params.headers).json()
-        url = _as_dict(response)["video"]["url"]
+    def get_response(
+        self,
+        response: Union[bytes, Dict],
+        request_params: Optional[RequestParameters] = None,
+    ) -> Any:
+        output = super().get_response(response, request_params)
+        url = _as_dict(output)["images"][0]["url"]
         return get_session().get(url).content
diff --git a/tests/test_inference_providers.py b/tests/test_inference_providers.py
index 7819cea852..eea23ed5f4 100644
--- a/tests/test_inference_providers.py
+++ b/tests/test_inference_providers.py
@@ -21,6 +21,7 @@
 from huggingface_hub.inference._providers.fal_ai import (
     _POLLING_INTERVAL,
     FalAIAutomaticSpeechRecognitionTask,
+    FalAIImageToImageTask,
     FalAITextToImageTask,
     FalAITextToSpeechTask,
     FalAITextToVideoTask,
@@ -408,6 +409,73 @@ def test_text_to_video_response(self, mocker):
         mock_sleep.assert_called_once_with(_POLLING_INTERVAL)
         assert response == b"video_content"
 
+    def test_image_to_image_payload(self):
+        helper = FalAIImageToImageTask()
+        mapping_info = InferenceProviderMapping(
+            provider="fal-ai",
+            hf_model_id="stabilityai/sdxl-refiner-1.0",
+            providerId="fal-ai/sdxl-refiner",
+            task="image-to-image",
+            status="live",
+        )
+        payload = helper._prepare_payload_as_dict("https://example.com/image.png", {"prompt": "a cat"}, mapping_info)
+        assert payload == {"image_url": "https://example.com/image.png", "prompt": "a cat"}
+
+        payload = helper._prepare_payload_as_dict(
+            b"dummy_image_data", {"prompt": "replace the cat with a dog"}, mapping_info
+        )
+        assert payload == {
+            "image_url": f"data:image/jpeg;base64,{base64.b64encode(b'dummy_image_data').decode()}",
+            "prompt": "replace the cat with a dog",
+        }
+
+    def test_image_to_image_response(self, mocker):
+        helper = FalAIImageToImageTask()
+        mock_session = mocker.patch("huggingface_hub.inference._providers.fal_ai.get_session")
+        mock_sleep = mocker.patch("huggingface_hub.inference._providers.fal_ai.time.sleep")
+        mock_session.return_value.get.side_effect = [
+            # First call: status
+            mocker.Mock(json=lambda: {"status": "COMPLETED"}, headers={"Content-Type": "application/json"}),
+            # Second call: get result
+            mocker.Mock(json=lambda: {"images": [{"url": "image_url"}]}, headers={"Content-Type": "application/json"}),
+            # Third call: get image content
+            mocker.Mock(content=b"image_content"),
+        ]
+        api_key = helper._prepare_api_key("hf_token")
+        headers = helper._prepare_headers({}, api_key)
+        url = helper._prepare_url(api_key, "username/repo_name")
+
+        request_params = RequestParameters(
+            url=url,
+            headers=headers,
+            task="image-to-image",
+            model="username/repo_name",
+            data=None,
+            json=None,
+        )
+        response = helper.get_response(
+            b'{"request_id": "test_request_id", "status": "PROCESSING", "response_url": "https://queue.fal.run/username_provider/repo_name_provider/requests/test_request_id", "status_url": "https://queue.fal.run/username_provider/repo_name_provider/requests/test_request_id/status"}',
+            request_params,
+        )
+
+        # Verify the correct URLs were called
+        assert mock_session.return_value.get.call_count == 3
+        mock_session.return_value.get.assert_has_calls(
+            [
+                mocker.call(
+                    "https://router.huggingface.co/fal-ai/username_provider/repo_name_provider/requests/test_request_id/status?_subdomain=queue",
+                    headers=request_params.headers,
+                ),
+                mocker.call(
+                    "https://router.huggingface.co/fal-ai/username_provider/repo_name_provider/requests/test_request_id?_subdomain=queue",
+                    headers=request_params.headers,
+                ),
+                mocker.call("image_url"),
+            ]
+        )
+        mock_sleep.assert_called_once_with(_POLLING_INTERVAL)
+        assert response == b"image_content"
+
 
 class TestFeatherlessAIProvider:
     def test_prepare_route_chat_completion_url(self):
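
For reviewers, a minimal usage sketch of how the new task is reached through `InferenceClient`. The token, file names, prompt, and model id are illustrative and not part of this diff; it assumes the target model has a live fal-ai image-to-image mapping.

```python
# Minimal usage sketch (not part of the diff): exercises the new
# fal-ai image-to-image path added above.
from huggingface_hub import InferenceClient

client = InferenceClient(provider="fal-ai", api_key="hf_***")  # illustrative token

# Inputs that are not already URLs are converted by `_as_url` into a
# base64 `data:` URL and sent to fal.ai as `image_url` in the payload.
image = client.image_to_image(
    "cat.png",                             # local path, raw bytes, or URL
    prompt="Turn the cat into a tiger",
    model="stabilityai/sdxl-refiner-1.0",  # illustrative model id
)
image.save("tiger.png")
```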
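For context on the refactor: `FalAIQueueTask` factors out fal.ai's asynchronous queue protocol so that text-to-video and image-to-image share it, and the key prefix decides the routing (an `hf_` key goes through `router.huggingface.co` with `?_subdomain=queue`, any other key hits `queue.fal.run` directly with a `Key` authorization header). Below is a standalone sketch of the same three-step flow the helper implements, with an illustrative route and payload and plain `requests` in place of the client:

```python
# Illustrative sketch of the fal.ai queue protocol mirrored by FalAIQueueTask:
# submit a job, poll its status URL until COMPLETED, then fetch the result.
import time
import requests

headers = {"authorization": "Key fal-***"}  # direct provider key (non-hf_ path)

# 1. Submit: the queue responds with request_id, status_url, and response_url,
#    the same fields the unit test above fakes in its response body.
submitted = requests.post(
    "https://queue.fal.run/fal-ai/sdxl-refiner",  # illustrative route
    json={"image_url": "https://example.com/image.png", "prompt": "a cat"},
    headers=headers,
).json()

# 2. Poll until the job is done (the helper sleeps _POLLING_INTERVAL between polls).
while requests.get(submitted["status_url"], headers=headers).json()["status"] != "COMPLETED":
    time.sleep(0.5)

# 3. Fetch the result; for image-to-image the payload carries `images[0]["url"]`.
result = requests.get(submitted["response_url"], headers=headers).json()
print(result["images"][0]["url"])
```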