Skip to content

Commit 6234469

Browse files
committed
better snippet for image-to-image
1 parent 060cd21 commit 6234469

File tree

3 files changed

+82
-0
lines changed

3 files changed

+82
-0
lines changed

packages/tasks/src/snippets/inputs.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ const inputsImageClassification = () => `"cats.jpg"`;
8686

8787
const inputsImageToText = () => `"cats.jpg"`;
8888

89+
const inputsImageToImage = () => `{
90+
"image": "cat.png",
91+
"prompt": "Turn the cat into a tiger."
92+
}`;
93+
8994
const inputsImageSegmentation = () => `"cats.jpg"`;
9095

9196
const inputsObjectDetection = () => `"cats.jpg"`;
@@ -118,6 +123,7 @@ const modelInputSnippets: {
118123
"fill-mask": inputsFillMask,
119124
"image-classification": inputsImageClassification,
120125
"image-to-text": inputsImageToText,
126+
"image-to-image": inputsImageToImage,
121127
"image-segmentation": inputsImageSegmentation,
122128
"object-detection": inputsObjectDetection,
123129
"question-answering": inputsQuestionAnswering,

packages/tasks/src/snippets/python.spec.ts

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,48 @@ output = query({
175175
})`);
176176
});
177177

178+
it("image-to-image", async () => {
179+
const model: ModelDataMinimal = {
180+
id: "stabilityai/stable-diffusion-xl-refiner-1.0",
181+
pipeline_tag: "image-to-image",
182+
tags: [],
183+
inference: "",
184+
};
185+
const snippets = getPythonInferenceSnippet(model, "api_token") as InferenceSnippet[];
186+
187+
expect(snippets.length).toEqual(2);
188+
189+
expect(snippets[0].client).toEqual("huggingface_hub");
190+
expect(snippets[0].content).toEqual(`from huggingface_hub import InferenceClient
191+
client = InferenceClient("stabilityai/stable-diffusion-xl-refiner-1.0", token="api_token")
192+
193+
# output is a PIL.Image object
194+
image = client.image_to_image("cat.png", prompt="Turn the cat into a tiger.")`);
195+
196+
expect(snippets[1].client).toEqual("requests");
197+
expect(snippets[1].content).toEqual(`import requests
198+
199+
API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-refiner-1.0"
200+
headers = {"Authorization": "Bearer api_token"}
201+
202+
def query(payload):
203+
with open(payload["inputs"], "rb") as f:
204+
img = f.read()
205+
payload["inputs"] = base64.b64encode(img).decode("utf-8")
206+
response = requests.post(API_URL, headers=headers, json=payload)
207+
return response.content
208+
209+
image_bytes = query({
210+
"inputs": "cat.png",
211+
"parameters": {"prompt": "Turn the cat into a tiger."},
212+
})
213+
214+
# You can access the image with PIL.Image for example
215+
import io
216+
from PIL import Image
217+
image = Image.open(io.BytesIO(image_bytes))`);
218+
});
219+
178220
it("text-to-image", async () => {
179221
const model: ModelDataMinimal = {
180222
id: "black-forest-labs/FLUX.1-schnell",

packages/tasks/src/snippets/python.ts

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,39 @@ output = query({
182182
];
183183
};
184184

185+
const snippetImageToImage = (model: ModelDataMinimal, accessToken: string): InferenceSnippet[] => {
186+
const inputsAsStr = getModelInputSnippet(model) as string;
187+
const inputsAsObj = JSON.parse(inputsAsStr);
188+
189+
return [
190+
{
191+
client: "huggingface_hub",
192+
content: `${snippetImportInferenceClient(model, accessToken)}
193+
# output is a PIL.Image object
194+
image = client.image_to_image("${inputsAsObj.image}", prompt="${inputsAsObj.prompt}")`,
195+
},
196+
{
197+
client: "requests",
198+
content: `def query(payload):
199+
with open(payload["inputs"], "rb") as f:
200+
img = f.read()
201+
payload["inputs"] = base64.b64encode(img).decode("utf-8")
202+
response = requests.post(API_URL, headers=headers, json=payload)
203+
return response.content
204+
205+
image_bytes = query({
206+
"inputs": "${inputsAsObj.image}",
207+
"parameters": {"prompt": "${inputsAsObj.prompt}"},
208+
})
209+
210+
# You can access the image with PIL.Image for example
211+
import io
212+
from PIL import Image
213+
image = Image.open(io.BytesIO(image_bytes))`,
214+
},
215+
];
216+
};
217+
185218
const snippetTabular = (model: ModelDataMinimal): InferenceSnippet => ({
186219
content: `def query(payload):
187220
response = requests.post(API_URL, headers=headers, json=payload)
@@ -312,6 +345,7 @@ const pythonSnippets: Partial<
312345
"image-segmentation": snippetFile,
313346
"document-question-answering": snippetDocumentQuestionAnswering,
314347
"image-to-text": snippetFile,
348+
"image-to-image": snippetImageToImage,
315349
"zero-shot-image-classification": snippetZeroShotImageClassification,
316350
};
317351

0 commit comments

Comments
 (0)