Skip to content

Commit 5841172

Browse files
hanouticelinaWauplin
authored andcommitted
[Inference Providers] add image-to-image support for Replicate provider (huggingface#3188)
* add image-to-image support for replicate * fixes * infer mime type * Update src/huggingface_hub/inference/_common.py Co-authored-by: Lucain <[email protected]> * Update src/huggingface_hub/inference/_common.py * fix --------- Co-authored-by: Lucain <[email protected]>
1 parent 11ec5c7 commit 5841172

File tree

8 files changed

+109
-5
lines changed

8 files changed

+109
-5
lines changed

docs/source/en/guides/inference.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ For more details, refer to the [Inference Providers pricing documentation](https
218218
| [`~InferenceClient.fill_mask`] |||||||||||||||
219219
| [`~InferenceClient.image_classification`] |||||||||||||||
220220
| [`~InferenceClient.image_segmentation`] |||||||||||||||
221-
| [`~InferenceClient.image_to_image`] |||||||||||| |||
221+
| [`~InferenceClient.image_to_image`] |||||||||||| |||
222222
| [`~InferenceClient.image_to_text`] |||||||||||||||
223223
| [`~InferenceClient.object_detection`] ||||||||||||||| ||
224224
| [`~InferenceClient.question_answering`] |||||||||||||||

src/huggingface_hub/inference/_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,6 +1338,7 @@ def image_to_image(
13381338
api_key=self.token,
13391339
)
13401340
response = self._inner_post(request_parameters)
1341+
response = provider_helper.get_response(response, request_parameters)
13411342
return _bytes_to_image(response)
13421343

13431344
def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:

src/huggingface_hub/inference/_common.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import io
1919
import json
2020
import logging
21+
import mimetypes
2122
from contextlib import contextmanager
2223
from dataclasses import dataclass
2324
from pathlib import Path
@@ -197,6 +198,17 @@ def _b64_encode(content: ContentT) -> str:
197198
return base64.b64encode(data_as_bytes).decode()
198199

199200

201+
def _as_url(content: ContentT, default_mime_type: str) -> str:
202+
if isinstance(content, str) and (content.startswith("https://") or content.startswith("http://")):
203+
return content
204+
205+
mime_type = (
206+
mimetypes.guess_type(content, strict=False)[0] if isinstance(content, (str, Path)) else None
207+
) or default_mime_type
208+
encoded_data = _b64_encode(content)
209+
return f"data:{mime_type};base64,{encoded_data}"
210+
211+
200212
def _b64_to_image(encoded_image: str) -> "Image":
201213
"""Parse a base64-encoded string into a PIL Image."""
202214
Image = _import_pil_image()

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,6 +1384,7 @@ async def image_to_image(
13841384
api_key=self.token,
13851385
)
13861386
response = await self._inner_post(request_parameters)
1387+
response = provider_helper.get_response(response, request_parameters)
13871388
return _bytes_to_image(response)
13881389

13891390
async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:

src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from .novita import NovitaConversationalTask, NovitaTextGenerationTask, NovitaTextToVideoTask
3535
from .nscale import NscaleConversationalTask, NscaleTextToImageTask
3636
from .openai import OpenAIConversationalTask
37-
from .replicate import ReplicateTask, ReplicateTextToImageTask, ReplicateTextToSpeechTask
37+
from .replicate import ReplicateImageToImageTask, ReplicateTask, ReplicateTextToImageTask, ReplicateTextToSpeechTask
3838
from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
3939
from .swarmind import SwarmindConversationalTask, SwarmindTextGenerationTask
4040
from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask
@@ -142,6 +142,7 @@
142142
"conversational": OpenAIConversationalTask(),
143143
},
144144
"replicate": {
145+
"image-to-image": ReplicateImageToImageTask(),
145146
"text-to-image": ReplicateTextToImageTask(),
146147
"text-to-speech": ReplicateTextToSpeechTask(),
147148
"text-to-video": ReplicateTask("text-to-video"),

src/huggingface_hub/inference/_providers/replicate.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Any, Dict, Optional, Union
22

33
from huggingface_hub.hf_api import InferenceProviderMapping
4-
from huggingface_hub.inference._common import RequestParameters, _as_dict
4+
from huggingface_hub.inference._common import RequestParameters, _as_dict, _as_url
55
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
66
from huggingface_hub.utils import get_session
77

@@ -70,3 +70,21 @@ def _prepare_payload_as_dict(
7070
payload: Dict = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) # type: ignore[assignment]
7171
payload["input"]["text"] = payload["input"].pop("prompt") # rename "prompt" to "text" for TTS
7272
return payload
73+
74+
75+
class ReplicateImageToImageTask(ReplicateTask):
76+
def __init__(self):
77+
super().__init__("image-to-image")
78+
79+
def _prepare_payload_as_dict(
80+
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
81+
) -> Optional[Dict]:
82+
image_url = _as_url(inputs, default_mime_type="image/jpeg")
83+
84+
payload: Dict[str, Any] = {"input": {"input_image": image_url, **filter_none(parameters)}}
85+
86+
mapped_model = provider_mapping_info.provider_id
87+
if ":" in mapped_model:
88+
version = mapped_model.split(":", 1)[1]
89+
payload["version"] = version
90+
return payload

tests/test_inference_client.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,11 @@
4747
)
4848
from huggingface_hub.errors import HfHubHTTPError, ValidationError
4949
from huggingface_hub.inference._client import _open_as_binary
50-
from huggingface_hub.inference._common import _stream_chat_completion_response, _stream_text_generation_response
50+
from huggingface_hub.inference._common import (
51+
_as_url,
52+
_stream_chat_completion_response,
53+
_stream_text_generation_response,
54+
)
5155
from huggingface_hub.inference._providers import get_provider_helper
5256
from huggingface_hub.inference._providers.hf_inference import _build_chat_completion_url
5357

@@ -1163,3 +1167,28 @@ def test_chat_completion_url_resolution(
11631167
assert request_params.url == expected_request_url
11641168
assert request_params.json is not None
11651169
assert request_params.json.get("model") == expected_payload_model
1170+
1171+
1172+
@pytest.mark.parametrize(
1173+
"content_input, default_mime_type, expected, is_exact_match",
1174+
[
1175+
("https://my-url.com/cat.gif", "image/jpeg", "https://my-url.com/cat.gif", True),
1176+
("assets/image.png", "image/jpeg", "data:image/png;base64,", False),
1177+
(Path("assets/image.png"), "image/jpeg", "data:image/png;base64,", False),
1178+
("assets/image.foo", "image/jpeg", "data:image/jpeg;base64,", False),
1179+
(b"some image bytes", "image/jpeg", "", True),
1180+
(io.BytesIO(b"some image bytes"), "image/jpeg", "", True),
1181+
],
1182+
)
1183+
def test_as_url(content_input, default_mime_type, expected, is_exact_match, tmp_path: Path):
1184+
if isinstance(content_input, (str, Path)) and not str(content_input).startswith("http"):
1185+
file_path = tmp_path / content_input
1186+
file_path.parent.mkdir(exist_ok=True, parents=True)
1187+
file_path.touch()
1188+
content_input = file_path
1189+
1190+
result = _as_url(content_input, default_mime_type)
1191+
if is_exact_match:
1192+
assert result == expected
1193+
else:
1194+
assert result.startswith(expected)

tests/test_inference_providers.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,11 @@
4242
from huggingface_hub.inference._providers.novita import NovitaConversationalTask, NovitaTextGenerationTask
4343
from huggingface_hub.inference._providers.nscale import NscaleConversationalTask, NscaleTextToImageTask
4444
from huggingface_hub.inference._providers.openai import OpenAIConversationalTask
45-
from huggingface_hub.inference._providers.replicate import ReplicateTask, ReplicateTextToSpeechTask
45+
from huggingface_hub.inference._providers.replicate import (
46+
ReplicateImageToImageTask,
47+
ReplicateTask,
48+
ReplicateTextToSpeechTask,
49+
)
4650
from huggingface_hub.inference._providers.sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
4751
from huggingface_hub.inference._providers.together import TogetherTextToImageTask
4852

@@ -1057,6 +1061,44 @@ def test_get_response_single_output(self, mocker):
10571061
mock.return_value.get.assert_called_once_with("https://example.com/image.jpg")
10581062
assert response == mock.return_value.get.return_value.content
10591063

1064+
def test_image_to_image_payload(self):
1065+
helper = ReplicateImageToImageTask()
1066+
dummy_image = b"dummy image data"
1067+
encoded_image = base64.b64encode(dummy_image).decode("utf-8")
1068+
image_uri = f"data:image/jpeg;base64,{encoded_image}"
1069+
1070+
# No model version
1071+
payload = helper._prepare_payload_as_dict(
1072+
dummy_image,
1073+
{"num_inference_steps": 20},
1074+
InferenceProviderMapping(
1075+
provider="replicate",
1076+
hf_model_id="google/gemini-pro-vision",
1077+
providerId="google/gemini-pro-vision",
1078+
task="image-to-image",
1079+
status="live",
1080+
),
1081+
)
1082+
assert payload == {
1083+
"input": {"input_image": image_uri, "num_inference_steps": 20},
1084+
}
1085+
1086+
payload = helper._prepare_payload_as_dict(
1087+
dummy_image,
1088+
{"num_inference_steps": 20},
1089+
InferenceProviderMapping(
1090+
provider="replicate",
1091+
hf_model_id="google/gemini-pro-vision",
1092+
providerId="google/gemini-pro-vision:123456",
1093+
task="image-to-image",
1094+
status="live",
1095+
),
1096+
)
1097+
assert payload == {
1098+
"input": {"input_image": image_uri, "num_inference_steps": 20},
1099+
"version": "123456",
1100+
}
1101+
10601102

10611103
class TestSambanovaProvider:
10621104
def test_prepare_url_conversational(self):

0 commit comments

Comments
 (0)