diff --git a/docs/source/en/guides/inference.md b/docs/source/en/guides/inference.md index 36a93f049e..ff6e4349d9 100644 --- a/docs/source/en/guides/inference.md +++ b/docs/source/en/guides/inference.md @@ -218,7 +218,7 @@ For more details, refer to the [Inference Providers pricing documentation](https | [`~InferenceClient.fill_mask`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.image_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.image_segmentation`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| [`~InferenceClient.image_to_image`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| [`~InferenceClient.image_to_image`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | | [`~InferenceClient.image_to_text`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | [`~InferenceClient.object_detection`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | | [`~InferenceClient.question_answering`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index 3439dafd89..01360d6f59 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -1338,6 +1338,7 @@ def image_to_image( api_key=self.token, ) response = self._inner_post(request_parameters) + response = provider_helper.get_response(response, request_parameters) return _bytes_to_image(response) def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput: diff --git a/src/huggingface_hub/inference/_common.py b/src/huggingface_hub/inference/_common.py index 574f726b67..08732e1c59 100644 --- a/src/huggingface_hub/inference/_common.py +++ b/src/huggingface_hub/inference/_common.py @@ -18,6 +18,7 @@ import io import json import logging +import mimetypes from contextlib import contextmanager from dataclasses import 
def _as_url(content: ContentT, default_mime_type: str) -> str:
    """Return `content` as a URL string.

    Strings that already are URLs (`http://`, `https://` or `data:`) are returned unchanged. Any other
    content is base64-encoded with `_b64_encode` and wrapped in a `data:` URL.

    Args:
        content (`ContentT`):
            A URL, a local file path (`str` or `Path`) or any other content accepted by `_b64_encode`.
        default_mime_type (`str`):
            MIME type used when none can be guessed from a file name. Content that is not a `str` or a
            `Path` has no file name, so it always gets this MIME type.

    Returns:
        `str`: The input URL, or a `data:{mime_type};base64,{encoded_data}` URL.
    """
    # A "data:" string is already in the target format: encoding it again would treat it as a file path.
    if isinstance(content, str) and content.startswith(("http://", "https://", "data:")):
        return content

    guessed_mime_type = None
    if isinstance(content, (str, Path)):
        # strict=False also accepts common but non-IANA-registered types.
        guessed_mime_type = mimetypes.guess_type(content, strict=False)[0]

    encoded_data = _b64_encode(content)
    return f"data:{guessed_mime_type or default_mime_type};base64,{encoded_data}"
class ReplicateImageToImageTask(ReplicateTask):
    """Replicate provider helper registered for the "image-to-image" task."""

    def __init__(self):
        super().__init__("image-to-image")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
        """Build the request body for an image-to-image call.

        The image is converted with `_as_url` (JPEG is the fallback MIME type) and stored under
        `input["input_image"]`, next to the parameters returned by `filter_none`. When the provider id
        contains a colon, everything after the first colon is sent as `"version"`.
        """
        model_input: Dict[str, Any] = {"input_image": _as_url(inputs, default_mime_type="image/jpeg")}
        # Applied last, so a caller-supplied parameter wins over the default key on a name clash.
        model_input.update(filter_none(parameters))
        request_body: Dict[str, Any] = {"input": model_input}

        _, separator, version = provider_mapping_info.provider_id.partition(":")
        if separator:
            request_body["version"] = version
        return request_body
@pytest.mark.parametrize(
    "content_input, default_mime_type, expected, is_exact_match",
    [
        ("https://my-url.com/cat.gif", "image/jpeg", "https://my-url.com/cat.gif", True),
        ("assets/image.png", "image/jpeg", "data:image/png;base64,", False),
        (Path("assets/image.png"), "image/jpeg", "data:image/png;base64,", False),
        ("assets/image.foo", "image/jpeg", "data:image/jpeg;base64,", False),
        # Raw content has no file name: the default MIME type is used and the payload is fully known.
        (b"some image bytes", "image/jpeg", "data:image/jpeg;base64,c29tZSBpbWFnZSBieXRlcw==", True),
        (io.BytesIO(b"some image bytes"), "image/jpeg", "data:image/jpeg;base64,c29tZSBpbWFnZSBieXRlcw==", True),
    ],
)
def test_as_url(content_input, default_mime_type, expected, is_exact_match, tmp_path: Path):
    """Check `_as_url` on URLs, local paths (`str` and `Path`), raw bytes and binary streams."""
    if isinstance(content_input, (str, Path)) and not str(content_input).startswith("http"):
        # Local paths must exist on disk: create an empty file under the per-test temporary directory.
        file_path = tmp_path / content_input
        file_path.parent.mkdir(exist_ok=True, parents=True)
        file_path.touch()
        # Keep the parametrized type so that the `str` and the `Path` branches are both exercised.
        content_input = str(file_path) if isinstance(content_input, str) else file_path

    result = _as_url(content_input, default_mime_type)
    if is_exact_match:
        assert result == expected
    else:
        assert result.startswith(expected)
def test_image_to_image_payload(self):
    """The image is sent as a JPEG data URL; `version` is added only when the provider id pins one."""
    task_helper = ReplicateImageToImageTask()
    raw_image = b"dummy image data"
    encoded_image = base64.b64encode(raw_image).decode("utf-8")
    expected_input = {"input_image": f"data:image/jpeg;base64,{encoded_image}", "num_inference_steps": 20}

    def build_payload(provider_id):
        # Only the provider id differs between the two scenarios below.
        return task_helper._prepare_payload_as_dict(
            raw_image,
            {"num_inference_steps": 20},
            InferenceProviderMapping(
                provider="replicate",
                hf_model_id="google/gemini-pro-vision",
                providerId=provider_id,
                task="image-to-image",
                status="live",
            ),
        )

    # No model version
    assert build_payload("google/gemini-pro-vision") == {"input": expected_input}

    # With model version
    assert build_payload("google/gemini-pro-vision:123456") == {"input": expected_input, "version": "123456"}