2 changes: 1 addition & 1 deletion docs/source/en/guides/inference.md
@@ -218,7 +218,7 @@ For more details, refer to the [Inference Providers pricing documentation](https
| [`~InferenceClient.fill_mask`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| [`~InferenceClient.image_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| [`~InferenceClient.image_segmentation`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| [`~InferenceClient.image_to_image`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| [`~InferenceClient.image_to_image`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| [`~InferenceClient.image_to_text`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| [`~InferenceClient.object_detection`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| [`~InferenceClient.question_answering`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
1 change: 1 addition & 0 deletions src/huggingface_hub/inference/_client.py
@@ -1338,6 +1338,7 @@ def image_to_image(
api_key=self.token,
)
response = self._inner_post(request_parameters)
response = provider_helper.get_response(response, request_parameters)
return _bytes_to_image(response)

def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
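For context, here is a minimal usage sketch of the sync path this change completes. The token, file names, and model id below are placeholders, not values from the PR:

```python
from huggingface_hub import InferenceClient

# Placeholder token and model id; any Replicate image-to-image model mapped on the Hub would do.
client = InferenceClient(provider="replicate", api_key="hf_xxx")

image = client.image_to_image(
    "cat.png",                                   # local path, raw bytes, or URL (ContentT)
    prompt="Turn the cat into a tiger",
    model="some-org/some-image-editing-model",   # hypothetical model id
)
image.save("tiger.png")
```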
1 change: 1 addition & 0 deletions src/huggingface_hub/inference/_generated/_async_client.py
@@ -1384,6 +1384,7 @@ async def image_to_image(
api_key=self.token,
)
response = await self._inner_post(request_parameters)
response = provider_helper.get_response(response, request_parameters)
return _bytes_to_image(response)

async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
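The async client mirrors this; a sketch with the same placeholder values:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient(provider="replicate", api_key="hf_xxx")  # placeholder token
    image = await client.image_to_image(
        "cat.png",                                   # placeholder input image
        prompt="Turn the cat into a tiger",
        model="some-org/some-image-editing-model",   # hypothetical model id
    )
    image.save("tiger.png")


asyncio.run(main())
```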
3 changes: 2 additions & 1 deletion src/huggingface_hub/inference/_providers/__init__.py
@@ -34,7 +34,7 @@
from .novita import NovitaConversationalTask, NovitaTextGenerationTask, NovitaTextToVideoTask
from .nscale import NscaleConversationalTask, NscaleTextToImageTask
from .openai import OpenAIConversationalTask
from .replicate import ReplicateTask, ReplicateTextToImageTask, ReplicateTextToSpeechTask
from .replicate import ReplicateImageToImageTask, ReplicateTask, ReplicateTextToImageTask, ReplicateTextToSpeechTask
from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask

@@ -141,6 +141,7 @@
"conversational": OpenAIConversationalTask(),
},
"replicate": {
"image-to-image": ReplicateImageToImageTask(),
"text-to-image": ReplicateTextToImageTask(),
"text-to-speech": ReplicateTextToSpeechTask(),
"text-to-video": ReplicateTask("text-to-video"),
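For orientation, the new registry entry is what lets the client resolve a task helper for a given provider. A tiny sketch, assuming the nested mapping edited above is exposed as the module-level `PROVIDERS` dict (treat the name as an assumption):

```python
# Sketch only: assumes the registry above is the module-level PROVIDERS dict.
from huggingface_hub.inference._providers import PROVIDERS
from huggingface_hub.inference._providers.replicate import ReplicateImageToImageTask

helper = PROVIDERS["replicate"]["image-to-image"]
assert isinstance(helper, ReplicateImageToImageTask)
```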
37 changes: 36 additions & 1 deletion src/huggingface_hub/inference/_providers/replicate.py
@@ -1,4 +1,6 @@
from typing import Any, Dict, Optional, Union
import base64
from pathlib import Path
from typing import Any, BinaryIO, Dict, Optional, Union

from huggingface_hub.hf_api import InferenceProviderMapping
from huggingface_hub.inference._common import RequestParameters, _as_dict
@@ -70,3 +72,36 @@ def _prepare_payload_as_dict(
payload: Dict = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) # type: ignore[assignment]
payload["input"]["text"] = payload["input"].pop("prompt") # rename "prompt" to "text" for TTS
return payload


class ReplicateImageToImageTask(ReplicateTask):
def __init__(self):
super().__init__("image-to-image")

def _prepare_payload_as_dict(
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
) -> Optional[Dict]:
if isinstance(inputs, str) and inputs.startswith(("http://", "https://")):
image_url = inputs
else:
image_bytes: bytes
if isinstance(inputs, (str, Path)):
with open(inputs, "rb") as f:
image_bytes = f.read()
elif isinstance(inputs, bytes):
image_bytes = inputs
elif isinstance(inputs, BinaryIO):
image_bytes = inputs.read()
else:
raise TypeError(f"Unsupported input type for image: {type(inputs)}")
Review thread:

Contributor: could use `_open_as_binary` here

Contributor: or even better `def _b64_encode(content: ContentT) -> str:`

Contributor (@Wauplin, Jul 1, 2025): we could define a `def _as_url(content: ContentT) -> str:` helper that takes any str / URL / path / binary input and returns either a base64 data URL or a plain URL. With centralized logic we could also infer the mime type (see `image/jpeg` below) if it really becomes necessary. (A centralized helper could also accept a PIL.Image object once #3191 is addressed.)

Contributor (PR author): nice idea! in 11db080, I added a helper `_as_url` plus a way to infer the mime type from a file path when available, using the `mimetypes` module.

encoded_image = base64.b64encode(image_bytes).decode("utf-8")
image_url = f"data:image/jpeg;base64,{encoded_image}"

payload: Dict[str, Any] = {"input": {"input_image": image_url, **filter_none(parameters)}}

mapped_model = provider_mapping_info.provider_id
if ":" in mapped_model:
version = mapped_model.split(":", 1)[1]
payload["version"] = version
return payload
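
As a rough illustration of the `_as_url` helper suggested in the review thread above (a sketch only; the actual helper added in 11db080 may differ in name, signature, and details):

```python
import base64
import mimetypes
from pathlib import Path
from typing import BinaryIO, Union

# Assumed alias mirroring huggingface_hub's ContentT (URL string, path, raw bytes, or file object).
ContentLike = Union[str, Path, bytes, BinaryIO]


def _as_url(content: ContentLike, default_mime_type: str = "image/jpeg") -> str:
    """Return plain URLs unchanged; otherwise encode the content as a base64 data URL."""
    if isinstance(content, str) and content.startswith(("http://", "https://")):
        return content

    mime_type = default_mime_type
    if isinstance(content, (str, Path)):
        # Infer the mime type from the file extension when possible.
        guessed, _ = mimetypes.guess_type(str(content))
        if guessed is not None:
            mime_type = guessed
        with open(content, "rb") as f:
            data = f.read()
    elif isinstance(content, bytes):
        data = content
    else:  # file-like object opened in binary mode
        data = content.read()

    encoded = base64.b64encode(data).decode("utf-8")
    return f"data:{mime_type};base64,{encoded}"
```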
44 changes: 43 additions & 1 deletion tests/test_inference_providers.py
@@ -42,7 +42,11 @@
from huggingface_hub.inference._providers.novita import NovitaConversationalTask, NovitaTextGenerationTask
from huggingface_hub.inference._providers.nscale import NscaleConversationalTask, NscaleTextToImageTask
from huggingface_hub.inference._providers.openai import OpenAIConversationalTask
from huggingface_hub.inference._providers.replicate import ReplicateTask, ReplicateTextToSpeechTask
from huggingface_hub.inference._providers.replicate import (
ReplicateImageToImageTask,
ReplicateTask,
ReplicateTextToSpeechTask,
)
from huggingface_hub.inference._providers.sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
from huggingface_hub.inference._providers.together import TogetherTextToImageTask

@@ -1057,6 +1061,44 @@ def test_get_response_single_output(self, mocker):
mock.return_value.get.assert_called_once_with("https://example.com/image.jpg")
assert response == mock.return_value.get.return_value.content

def test_image_to_image_payload(self):
helper = ReplicateImageToImageTask()
dummy_image = b"dummy image data"
encoded_image = base64.b64encode(dummy_image).decode("utf-8")
image_uri = f"data:image/jpeg;base64,{encoded_image}"

# No model version
payload = helper._prepare_payload_as_dict(
dummy_image,
{"num_inference_steps": 20},
InferenceProviderMapping(
provider="replicate",
hf_model_id="google/gemini-pro-vision",
providerId="google/gemini-pro-vision",
task="image-to-image",
status="live",
),
)
assert payload == {
"input": {"input_image": image_uri, "num_inference_steps": 20},
}

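# With a model version appended to the provider id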
payload = helper._prepare_payload_as_dict(
dummy_image,
{"num_inference_steps": 20},
InferenceProviderMapping(
provider="replicate",
hf_model_id="google/gemini-pro-vision",
providerId="google/gemini-pro-vision:123456",
task="image-to-image",
status="live",
),
)
assert payload == {
"input": {"input_image": image_uri, "num_inference_steps": 20},
"version": "123456",
}


class TestSambanovaProvider:
def test_prepare_url_conversational(self):