Skip to content

Commit 8fb34f4

Browse files
authored
[Inference] Snippets for image-to-image task (#1565)
# TL;DR: Update the inference snippet generator to properly handle the image-to-image task.
1 parent e14dee8 commit 8fb34f4

File tree

19 files changed

+335
-6
lines changed

19 files changed

+335
-6
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
import fs from "node:fs";

const image = fs.readFileSync("{{inputs.asObj.inputs}}");

// Send the input image (base64 data URI) and parameters as a JSON payload
// and return the parsed JSON response.
async function query(data) {
	const response = await fetch(
		"{{ fullUrl }}",
		{
			headers: {
				Authorization: "{{ authorizationHeader }}",
				// The body below is JSON, not raw image bytes.
				"Content-Type": "application/json",
{% if billTo %}
				"X-HF-Bill-To": "{{ billTo }}",
{% endif %}			},
			method: "POST",
			// fetch() does not serialize plain objects (it would send
			// "[object Object]"), so stringify explicitly. Buffer has no
			// .encode() method; .toString("base64") is the correct call.
			body: JSON.stringify({
				inputs: `data:image/png;base64,${data.inputs.toString("base64")}`,
				parameters: data.parameters,
			}),
		}
	);
	const result = await response.json();
	return result;
}

query({
	inputs: image,
	parameters: {
		prompt: "{{ inputs.asObj.parameters.prompt }}",
	},
}).then((response) => {
	console.log(JSON.stringify(response));
});
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
import { InferenceClient } from "@huggingface/inference";
import fs from "node:fs";

const client = new InferenceClient("{{ accessToken }}");

// Read the input image from disk as raw bytes.
const data = fs.readFileSync("{{inputs.asObj.inputs}}");

const image = await client.imageToImage({
{% if endpointUrl %}
	endpointUrl: "{{ endpointUrl }}",
{% endif %}
	provider: "{{provider}}",
	model: "{{model.id}}",
	inputs: data,
	parameters: { prompt: "{{inputs.asObj.parameters.prompt}}", },
}{% if billTo %}, {
	billTo: "{{ billTo }}",
}{% endif %});
// Use the generated image (it's a Blob)
// For example, you can save it to a file or display it in an image element
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
{%if provider == "fal-ai" %}
import fal_client
import base64

def on_queue_update(update):
    # Stream queue progress logs while the request is being processed.
    if isinstance(update, fal_client.InProgress):
        for log in update.logs:
            print(log["message"])

# Encode the local input image as base64 so it can be inlined as a data URI.
with open("{{inputs.asObj.inputs}}", "rb") as image_file:
    image_base_64 = base64.b64encode(image_file.read()).decode('utf-8')

result = fal_client.subscribe(
    # NOTE(review): model id is hard-coded; this template is rendered per-model,
    # so it should probably use the provider model id variable — TODO confirm.
    "fal-ai/flux-kontext/dev",
    arguments={
        # Fix: the two values were swapped — the text instruction belongs in
        # "prompt" and the base64 data URI of the input image in "image_url"
        # (fal accepts data URIs for image inputs).
        "prompt": "{{ inputs.asObj.parameters.prompt }}",
        "image_url": f"data:image/png;base64,{image_base_64}",
    },
    with_logs=True,
    on_queue_update=on_queue_update,
)
print(result)
{%endif%}
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
# Read the input image from disk as raw bytes to pass to the client.
with open("{{ inputs.asObj.inputs }}", "rb") as image_file:
    input_image = image_file.read()

# output is a PIL.Image object
image = client.image_to_image(
    input_image,
    prompt="{{ inputs.asObj.parameters.prompt }}",
    model="{{ model.id }}",
)

packages/inference/src/snippets/templates/python/requests/imageToImage.jinja

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
with open("{{inputs.asObj.inputs}}", "rb") as image_file:
2+
image_base_64 = base64.b64encode(image_file.read()).decode('utf-8')
3+
14
def query(payload):
25
with open(payload["inputs"], "rb") as f:
36
img = f.read()

packages/tasks-gen/scripts/generate-snippets-fixtures.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,12 +133,12 @@ const TEST_CASES: {
133133
testName: "image-to-image",
134134
task: "image-to-image",
135135
model: {
136-
id: "stabilityai/stable-diffusion-xl-refiner-1.0",
136+
id: "black-forest-labs/FLUX.1-Kontext-dev",
137137
pipeline_tag: "image-to-image",
138138
tags: [],
139139
inference: "",
140140
},
141-
providers: ["hf-inference"],
141+
providers: ["fal-ai", "replicate", "hf-inference"],
142142
},
143143
{
144144
testName: "tabular",
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
import fs from "node:fs";

const image = fs.readFileSync("cat.png");

// POST the input image (base64 data URI) plus parameters as JSON and
// return the parsed JSON response.
async function query(data) {
	const response = await fetch(
		"https://router.huggingface.co/fal-ai/<fal-ai alias for black-forest-labs/FLUX.1-Kontext-dev>?_subdomain=queue",
		{
			headers: {
				Authorization: `Bearer ${process.env.HF_TOKEN}`,
				// The body below is JSON, not raw image bytes.
				"Content-Type": "application/json",
			},
			method: "POST",
			// fetch() does not serialize plain objects; stringify explicitly.
			// Buffer has no .encode() — use .toString("base64").
			body: JSON.stringify({
				inputs: `data:image/png;base64,${data.inputs.toString("base64")}`,
				parameters: data.parameters,
			}),
		}
	);
	const result = await response.json();
	return result;
}

query({
	inputs: image,
	parameters: {
		prompt: "Turn the cat into a tiger.",
	},
}).then((response) => {
	console.log(JSON.stringify(response));
});
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
import fs from "node:fs";

const image = fs.readFileSync("cat.png");

// POST the input image (base64 data URI) plus parameters as JSON and
// return the parsed JSON response.
async function query(data) {
	const response = await fetch(
		"https://router.huggingface.co/hf-inference/models/black-forest-labs/FLUX.1-Kontext-dev",
		{
			headers: {
				Authorization: `Bearer ${process.env.HF_TOKEN}`,
				// The body below is JSON, not raw image bytes.
				"Content-Type": "application/json",
			},
			method: "POST",
			// fetch() does not serialize plain objects; stringify explicitly.
			// Buffer has no .encode() — use .toString("base64").
			body: JSON.stringify({
				inputs: `data:image/png;base64,${data.inputs.toString("base64")}`,
				parameters: data.parameters,
			}),
		}
	);
	const result = await response.json();
	return result;
}

query({
	inputs: image,
	parameters: {
		prompt: "Turn the cat into a tiger.",
	},
}).then((response) => {
	console.log(JSON.stringify(response));
});
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
import fs from "node:fs";

const image = fs.readFileSync("cat.png");

// POST the input image (base64 data URI) plus parameters as JSON and
// return the parsed JSON response.
async function query(data) {
	const response = await fetch(
		"https://router.huggingface.co/replicate/v1/models/<replicate alias for black-forest-labs/FLUX.1-Kontext-dev>/predictions",
		{
			headers: {
				Authorization: `Bearer ${process.env.HF_TOKEN}`,
				// The body below is JSON, not raw image bytes.
				"Content-Type": "application/json",
			},
			method: "POST",
			// fetch() does not serialize plain objects; stringify explicitly.
			// Buffer has no .encode() — use .toString("base64").
			body: JSON.stringify({
				inputs: `data:image/png;base64,${data.inputs.toString("base64")}`,
				parameters: data.parameters,
			}),
		}
	);
	const result = await response.json();
	return result;
}

query({
	inputs: image,
	parameters: {
		prompt: "Turn the cat into a tiger.",
	},
}).then((response) => {
	console.log(JSON.stringify(response));
});
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
import { InferenceClient } from "@huggingface/inference";
// Fix: fs was used below without being imported.
import fs from "node:fs";

const client = new InferenceClient(process.env.HF_TOKEN);

// Read the input image from disk as raw bytes.
const data = fs.readFileSync("cat.png");

const image = await client.imageToImage({
	provider: "fal-ai",
	model: "black-forest-labs/FLUX.1-Kontext-dev",
	inputs: data,
	parameters: { prompt: "Turn the cat into a tiger.", },
});
// Use the generated image (it's a Blob)
// For example, you can save it to a file or display it in an image element

0 commit comments

Comments
 (0)