From 40453feb7d164920a20affb933738581f1e5738a Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 30 Jun 2025 16:52:26 +0200 Subject: [PATCH 1/6] add image-to-image support for replicate --- docs/source/en/guides/inference.md | 2 +- src/huggingface_hub/inference/_client.py | 1 + .../inference/_generated/_async_client.py | 1 + .../inference/_providers/__init__.py | 3 +- .../inference/_providers/replicate.py | 24 ++++++++ tests/test_inference_providers.py | 58 ++++++++++++++++++- 6 files changed, 86 insertions(+), 3 deletions(-) diff --git a/docs/source/en/guides/inference.md b/docs/source/en/guides/inference.md index 36a93f049e..ff6e4349d9 100644 --- a/docs/source/en/guides/inference.md +++ b/docs/source/en/guides/inference.md @@ -218,7 +218,7 @@ For more details, refer to the [Inference Providers pricing documentation](https | [`~InferenceClient.fill_mask`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.image_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.image_segmentation`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| [`~InferenceClient.image_to_image`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| [`~InferenceClient.image_to_image`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | | [`~InferenceClient.image_to_text`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.object_detection`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | | [`~InferenceClient.question_answering`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index 3439dafd89..01360d6f59 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -1338,6 +1338,7 @@ def image_to_image( api_key=self.token, ) response = self._inner_post(request_parameters) + response = provider_helper.get_response(response, request_parameters) return _bytes_to_image(response) def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput: diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py index 162d89369f..2ca5632069 100644 --- a/src/huggingface_hub/inference/_generated/_async_client.py +++ b/src/huggingface_hub/inference/_generated/_async_client.py @@ -1384,6 +1384,7 @@ async def image_to_image( api_key=self.token, ) response = await self._inner_post(request_parameters) + response = provider_helper.get_response(response, request_parameters) return _bytes_to_image(response) async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput: diff --git a/src/huggingface_hub/inference/_providers/__init__.py b/src/huggingface_hub/inference/_providers/__init__.py index 8d73b837fc..3f549f4a62 100644 --- a/src/huggingface_hub/inference/_providers/__init__.py +++ b/src/huggingface_hub/inference/_providers/__init__.py @@ -34,7 +34,7 @@ from .novita import NovitaConversationalTask, NovitaTextGenerationTask, NovitaTextToVideoTask from .nscale import NscaleConversationalTask, NscaleTextToImageTask from .openai import OpenAIConversationalTask -from .replicate import ReplicateTask, ReplicateTextToImageTask, ReplicateTextToSpeechTask +from .replicate import ReplicateImageToImageTask, ReplicateTask, ReplicateTextToImageTask, ReplicateTextToSpeechTask from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask @@ -141,6 +141,7 @@ "conversational": OpenAIConversationalTask(), }, "replicate": { + "image-to-image": ReplicateImageToImageTask(), "text-to-image": ReplicateTextToImageTask(), "text-to-speech": ReplicateTextToSpeechTask(), "text-to-video": ReplicateTask("text-to-video"), diff --git a/src/huggingface_hub/inference/_providers/replicate.py b/src/huggingface_hub/inference/_providers/replicate.py index 2ba3127647..73af9466b6 100644 --- a/src/huggingface_hub/inference/_providers/replicate.py +++ b/src/huggingface_hub/inference/_providers/replicate.py @@ -70,3 +70,27 @@ def _prepare_payload_as_dict( payload: Dict = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) # type: ignore[assignment] payload["input"]["text"] = payload["input"].pop("prompt") # rename "prompt" to "text" for TTS return payload + + +class ReplicateImageToImageTask(ReplicateTask): + def __init__(self): + super().__init__("image-to-image") + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + import base64 + + if not isinstance(inputs, bytes): + raise TypeError(f"Expected `bytes` for an image-to-image task, but got `{type(inputs)}`.") + + encoded_image = base64.b64encode(inputs).decode("utf-8") + image_uri = f"data:image/jpeg;base64,{encoded_image}" + + payload: Dict[str, Any] = {"input": {"input_image": image_uri, **filter_none(parameters)}} + + mapped_model = provider_mapping_info.provider_id + if ":" in mapped_model: + version = mapped_model.split(":", 1)[1] + payload["version"] = version + return payload diff --git a/tests/test_inference_providers.py b/tests/test_inference_providers.py index 3b3a8f671c..c26ffb52a3 100644 --- a/tests/test_inference_providers.py +++ b/tests/test_inference_providers.py @@ -42,7 +42,11 @@ from huggingface_hub.inference._providers.novita import NovitaConversationalTask, NovitaTextGenerationTask from huggingface_hub.inference._providers.nscale import NscaleConversationalTask, NscaleTextToImageTask from huggingface_hub.inference._providers.openai import OpenAIConversationalTask -from huggingface_hub.inference._providers.replicate import ReplicateTask, ReplicateTextToSpeechTask +from huggingface_hub.inference._providers.replicate import ( + ReplicateImageToImageTask, + ReplicateTask, + ReplicateTextToSpeechTask, +) from huggingface_hub.inference._providers.sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask from huggingface_hub.inference._providers.together import TogetherTextToImageTask @@ -1057,6 +1061,58 @@ def test_get_response_single_output(self, mocker): mock.return_value.get.assert_called_once_with("https://example.com/image.jpg") assert response == mock.return_value.get.return_value.content + def test_image_to_image_payload(self): + helper = ReplicateImageToImageTask() + dummy_image = b"dummy image data" + encoded_image = base64.b64encode(dummy_image).decode("utf-8") + image_uri = f"data:image/jpeg;base64,{encoded_image}" + + # No model version + payload = helper._prepare_payload_as_dict( + dummy_image, + {"num_inference_steps": 20}, + InferenceProviderMapping( + provider="replicate", + hf_model_id="google/gemini-pro-vision", + providerId="google/gemini-pro-vision", + task="image-to-image", + status="live", + ), + ) + assert payload == { + "input": {"input_image": image_uri, "num_inference_steps": 20}, + } + + payload = helper._prepare_payload_as_dict( + dummy_image, + {"num_inference_steps": 20}, + InferenceProviderMapping( + provider="replicate", + hf_model_id="google/gemini-pro-vision", + providerId="google/gemini-pro-vision:123456", + task="image-to-image", + status="live", + ), + ) + assert payload == { + "input": {"input_image": image_uri, "num_inference_steps": 20}, + "version": "123456", + } + + # Test with wrong input type + with pytest.raises(TypeError, match="Expected `bytes` for an image-to-image task"): + helper._prepare_payload_as_dict( + "this is not bytes", + {}, + InferenceProviderMapping( + provider="replicate", + hf_model_id="google/gemini-pro-vision", + providerId="google/gemini-pro-vision:123456", + task="image-to-image", + status="live", + ), + ) + class TestSambanovaProvider: def test_prepare_url_conversational(self): From 8605f419d9f689083197aae1cfb329ea71970106 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 30 Jun 2025 17:09:01 +0200 Subject: [PATCH 2/6] fixes --- .../inference/_providers/replicate.py | 31 +++++++++++++------ tests/test_inference_providers.py | 14 --------- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/src/huggingface_hub/inference/_providers/replicate.py b/src/huggingface_hub/inference/_providers/replicate.py index 73af9466b6..b7212d4ce5 100644 --- a/src/huggingface_hub/inference/_providers/replicate.py +++ b/src/huggingface_hub/inference/_providers/replicate.py @@ -1,4 +1,6 @@ -from typing import Any, Dict, Optional, Union +import base64 +from pathlib import Path +from typing import Any, BinaryIO, Dict, Optional, Union from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters, _as_dict @@ -79,15 +81,24 @@ def __init__(self): def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: - import base64 - - if not isinstance(inputs, bytes): - raise TypeError(f"Expected `bytes` for an image-to-image task, but got `{type(inputs)}`.") - - encoded_image = base64.b64encode(inputs).decode("utf-8") - image_uri = f"data:image/jpeg;base64,{encoded_image}" - - payload: Dict[str, Any] = {"input": {"input_image": image_uri, **filter_none(parameters)}} + if isinstance(inputs, str) and inputs.startswith(("http://", "https://")): + image_url = inputs + else: + image_bytes: bytes + if isinstance(inputs, (str, Path)): + with open(inputs, "rb") as f: + image_bytes = f.read() + elif isinstance(inputs, bytes): + image_bytes = inputs + elif isinstance(inputs, BinaryIO): + image_bytes = inputs.read() + else: + raise TypeError(f"Unsupported input type for image: {type(inputs)}") + + encoded_image = base64.b64encode(image_bytes).decode("utf-8") + image_url = f"data:image/jpeg;base64,{encoded_image}" + + payload: Dict[str, Any] = {"input": {"input_image": image_url, **filter_none(parameters)}} mapped_model = provider_mapping_info.provider_id if ":" in mapped_model: diff --git a/tests/test_inference_providers.py b/tests/test_inference_providers.py index c26ffb52a3..19910a0e26 100644 --- a/tests/test_inference_providers.py +++ b/tests/test_inference_providers.py @@ -1099,20 +1099,6 @@ def test_image_to_image_payload(self): "version": "123456", } - # Test with wrong input type - with pytest.raises(TypeError, match="Expected `bytes` for an image-to-image task"): - helper._prepare_payload_as_dict( - "this is not bytes", - {}, - InferenceProviderMapping( - provider="replicate", - hf_model_id="google/gemini-pro-vision", - providerId="google/gemini-pro-vision:123456", - task="image-to-image", - status="live", - ), - ) - class TestSambanovaProvider: def test_prepare_url_conversational(self): From 11db080453684b626daf0ef395dd14e1a74ca6fc Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Wed, 2 Jul 2025 12:19:46 +0200 Subject: [PATCH 3/6] infer mime type --- src/huggingface_hub/inference/_common.py | 13 ++++++++ .../inference/_providers/replicate.py | 23 ++------------ tests/test_inference_client.py | 31 ++++++++++++++++++- 3 files changed, 46 insertions(+), 21 deletions(-) diff --git a/src/huggingface_hub/inference/_common.py b/src/huggingface_hub/inference/_common.py index 574f726b67..29b6506cdc 100644 --- a/src/huggingface_hub/inference/_common.py +++ b/src/huggingface_hub/inference/_common.py @@ -18,6 +18,7 @@ import io import json import logging +import mimetypes from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path @@ -197,6 +198,18 @@ def _b64_encode(content: ContentT) -> str: return base64.b64encode(data_as_bytes).decode() +def _as_url(content: ContentT, default_mime_type: str) -> str: + if isinstance(content, str) and (content.startswith("https://") or content.startswith("http://")): + return content + + mime_type: Optional[str] = None + if isinstance(content, (str, Path)): + mime_type, _ = mimetypes.guess_type(str(content)) + final_mime_type = mime_type or default_mime_type + encoded_data = _b64_encode(content) + return f"data:{final_mime_type};base64,{encoded_data}" + + def _b64_to_image(encoded_image: str) -> "Image": """Parse a base64-encoded string into a PIL Image.""" Image = _import_pil_image() diff --git a/src/huggingface_hub/inference/_providers/replicate.py b/src/huggingface_hub/inference/_providers/replicate.py index b7212d4ce5..8a1037b6f2 100644 --- a/src/huggingface_hub/inference/_providers/replicate.py +++ b/src/huggingface_hub/inference/_providers/replicate.py @@ -1,9 +1,7 @@ -import base64 -from pathlib import Path -from typing import Any, BinaryIO, Dict, Optional, Union +from typing import Any, Dict, Optional, Union from huggingface_hub.hf_api import InferenceProviderMapping -from huggingface_hub.inference._common import RequestParameters, _as_dict +from huggingface_hub.inference._common import RequestParameters, _as_dict, _as_url from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none from huggingface_hub.utils import get_session @@ -81,22 +79,7 @@ def __init__(self): def _prepare_payload_as_dict( self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping ) -> Optional[Dict]: - if isinstance(inputs, str) and inputs.startswith(("http://", "https://")): - image_url = inputs - else: - image_bytes: bytes - if isinstance(inputs, (str, Path)): - with open(inputs, "rb") as f: - image_bytes = f.read() - elif isinstance(inputs, bytes): - image_bytes = inputs - elif isinstance(inputs, BinaryIO): - image_bytes = inputs.read() - else: - raise TypeError(f"Unsupported input type for image: {type(inputs)}") - - encoded_image = base64.b64encode(image_bytes).decode("utf-8") - image_url = f"data:image/jpeg;base64,{encoded_image}" + image_url = _as_url(inputs, default_mime_type="image/jpeg") payload: Dict[str, Any] = {"input": {"input_image": image_url, **filter_none(parameters)}} diff --git a/tests/test_inference_client.py b/tests/test_inference_client.py index 53628bb02a..99c78a61a4 100644 --- a/tests/test_inference_client.py +++ b/tests/test_inference_client.py @@ -47,7 +47,11 @@ ) from huggingface_hub.errors import HfHubHTTPError, ValidationError from huggingface_hub.inference._client import _open_as_binary -from huggingface_hub.inference._common import _stream_chat_completion_response, _stream_text_generation_response +from huggingface_hub.inference._common import ( + _as_url, + _stream_chat_completion_response, + _stream_text_generation_response, +) from huggingface_hub.inference._providers import get_provider_helper from huggingface_hub.inference._providers.hf_inference import _build_chat_completion_url @@ -1163,3 +1167,28 @@ def test_chat_completion_url_resolution( assert request_params.url == expected_request_url assert request_params.json is not None assert request_params.json.get("model") == expected_payload_model + + +@pytest.mark.parametrize( + "content_input, default_mime_type, expected, is_exact_match", + [ + ("https://my-url.com/cat.gif", "image/jpeg", "https://my-url.com/cat.gif", True), + ("assets/image.png", "image/jpeg", "data:image/png;base64,", False), + (Path("assets/image.png"), "image/jpeg", "data:image/png;base64,", False), + ("assets/image.foo", "image/jpeg", "data:image/jpeg;base64,", False), + (b"some image bytes", "image/jpeg", "", True), + (io.BytesIO(b"some image bytes"), "image/jpeg", "", True), + ], +) +def test_as_url(content_input, default_mime_type, expected, is_exact_match, tmp_path: Path): + if isinstance(content_input, (str, Path)) and not str(content_input).startswith("http"): + file_path = tmp_path / content_input + file_path.parent.mkdir(exist_ok=True, parents=True) + file_path.touch() + content_input = file_path + + result = _as_url(content_input, default_mime_type) + if is_exact_match: + assert result == expected + else: + assert result.startswith(expected) From d1b295cfc5953a9655d1b9cfea82441d747144d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?c=C3=A9lina?= Date: Wed, 2 Jul 2025 14:15:05 +0200 Subject: [PATCH 4/6] Update src/huggingface_hub/inference/_common.py Co-authored-by: Lucain --- src/huggingface_hub/inference/_common.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/huggingface_hub/inference/_common.py b/src/huggingface_hub/inference/_common.py index 29b6506cdc..f2f8da4e69 100644 --- a/src/huggingface_hub/inference/_common.py +++ b/src/huggingface_hub/inference/_common.py @@ -202,10 +202,11 @@ def _as_url(content: ContentT, default_mime_type: str) -> str: if isinstance(content, str) and (content.startswith("https://") or content.startswith("http://")): return content - mime_type: Optional[str] = None - if isinstance(content, (str, Path)): - mime_type, _ = mimetypes.guess_type(str(content)) - final_mime_type = mime_type or default_mime_type + mime_type = ( + mimetypes.guess_type(content, strict=True)[0] + if isinstance(content, (str, Path)) + else None + ) or default_mime_type encoded_data = _b64_encode(content) return f"data:{final_mime_type};base64,{encoded_data}" From 229fe720819274b7113f3121ca81af4be9a06670 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?c=C3=A9lina?= Date: Wed, 2 Jul 2025 14:16:12 +0200 Subject: [PATCH 5/6] Update src/huggingface_hub/inference/_common.py --- src/huggingface_hub/inference/_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/huggingface_hub/inference/_common.py b/src/huggingface_hub/inference/_common.py index f2f8da4e69..783a772c77 100644 --- a/src/huggingface_hub/inference/_common.py +++ b/src/huggingface_hub/inference/_common.py @@ -203,7 +203,7 @@ def _as_url(content: ContentT, default_mime_type: str) -> str: return content mime_type = ( - mimetypes.guess_type(content, strict=True)[0] + mimetypes.guess_type(content, strict=False)[0] if isinstance(content, (str, Path)) else None ) or default_mime_type From 3529921823c992535dd1911ed59534e6c151d595 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Wed, 2 Jul 2025 14:19:30 +0200 Subject: [PATCH 6/6] fix --- src/huggingface_hub/inference/_common.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/huggingface_hub/inference/_common.py b/src/huggingface_hub/inference/_common.py index 783a772c77..08732e1c59 100644 --- a/src/huggingface_hub/inference/_common.py +++ b/src/huggingface_hub/inference/_common.py @@ -203,12 +203,10 @@ def _as_url(content: ContentT, default_mime_type: str) -> str: return content mime_type = ( - mimetypes.guess_type(content, strict=False)[0] - if isinstance(content, (str, Path)) - else None + mimetypes.guess_type(content, strict=False)[0] if isinstance(content, (str, Path)) else None ) or default_mime_type encoded_data = _b64_encode(content) - return f"data:{final_mime_type};base64,{encoded_data}" + return f"data:{mime_type};base64,{encoded_data}" def _b64_to_image(encoded_image: str) -> "Image":