Skip to content

Commit d6423bb

Browse files
[inference] add snippets for image-to-video (#1678)
Add inference snippets for the image-to-video task --------- Co-authored-by: Celina Hanouti <[email protected]>
1 parent 6c33d26 commit d6423bb

File tree

17 files changed

+205
-12
lines changed

17 files changed

+205
-12
lines changed

packages/inference/src/snippets/getInferenceSnippets.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,7 @@ const snippets: Partial<
389389
"image-text-to-text": snippetGenerator("conversational"),
390390
"image-to-image": snippetGenerator("imageToImage", prepareImageToImageInput),
391391
"image-to-text": snippetGenerator("basicImage"),
392+
"image-to-video": snippetGenerator("imageToVideo", prepareImageToImageInput),
392393
"object-detection": snippetGenerator("basicImage"),
393394
"question-answering": snippetGenerator("questionAnswering", prepareQuestionAnsweringInput),
394395
"sentence-similarity": snippetGenerator("basic"),
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
const image = fs.readFileSync("{{inputs.asObj.inputs}}");
2+
3+
async function query(data) {
4+
const response = await fetch(
5+
"{{ fullUrl }}",
6+
{
7+
headers: {
8+
Authorization: "{{ authorizationHeader }}",
9+
"Content-Type": "image/jpeg",
10+
{% if billTo %}
11+
"X-HF-Bill-To": "{{ billTo }}",
12+
{% endif %} },
13+
method: "POST",
14+
body: {
15+
"image_url": `data:image/png;base64,${data.image.encode("base64")}`,
16+
"prompt": data.prompt,
17+
}
18+
}
19+
);
20+
const result = await response.json();
21+
return result;
22+
}
23+
24+
query({
25+
"image": image,
26+
"prompt": "{{inputs.asObj.parameters.prompt}}",
27+
}).then((response) => {
28+
// Use video
29+
});
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import { InferenceClient } from "@huggingface/inference";
2+
3+
const client = new InferenceClient("{{ accessToken }}");
4+
5+
const data = fs.readFileSync("{{inputs.asObj.inputs}}");
6+
7+
const video = await client.imageToVideo({
8+
{% if endpointUrl %}
9+
endpointUrl: "{{ endpointUrl }}",
10+
{% endif %}
11+
provider: "{{provider}}",
12+
model: "{{model.id}}",
13+
inputs: data,
14+
parameters: { prompt: "{{inputs.asObj.parameters.prompt}}", },
15+
}{% if billTo %}, {
16+
billTo: "{{ billTo }}",
17+
}{% endif %});
18+
19+
/// Use the generated video (it's a Blob)
20+
// For example, you can save it to a file or display it in a video element
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
{%if provider == "fal-ai" %}
2+
import fal_client
3+
import base64
4+
5+
def on_queue_update(update):
6+
if isinstance(update, fal_client.InProgress):
7+
for log in update.logs:
8+
print(log["message"])
9+
10+
with open("{{inputs.asObj.inputs}}", "rb") as image_file:
11+
image_base_64 = base64.b64encode(image_file.read()).decode('utf-8')
12+
13+
result = fal_client.subscribe(
14+
"{{model.id}}",
15+
arguments={
16+
"image_url": f"data:image/png;base64,{image_base_64}",
17+
"prompt": "{{inputs.asObj.parameters.prompt}}",
18+
},
19+
with_logs=True,
20+
on_queue_update=on_queue_update,
21+
)
22+
print(result)
23+
{%endif%}

packages/inference/src/snippets/templates/python/huggingface_hub/imageToImage.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ image = client.image_to_image(
66
input_image,
77
prompt="{{ inputs.asObj.parameters.prompt }}",
88
model="{{ model.id }}",
9-
)
9+
)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
with open("{{ inputs.asObj.inputs }}", "rb") as image_file:
2+
input_image = image_file.read()
3+
4+
video = client.image_to_video(
5+
input_image,
6+
prompt="{{ inputs.asObj.parameters.prompt }}",
7+
model="{{ model.id }}",
8+
)

packages/inference/src/snippets/templates/python/requests/imageToImage.jinja

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
with open("{{inputs.asObj.inputs}}", "rb") as image_file:
2-
image_base_64 = base64.b64encode(image_file.read()).decode('utf-8')
31

42
def query(payload):
53
with open(payload["inputs"], "rb") as f:
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
2+
def query(payload):
3+
with open(payload["inputs"], "rb") as f:
4+
img = f.read()
5+
payload["inputs"] = base64.b64encode(img).decode("utf-8")
6+
response = requests.post(API_URL, headers=headers, json=payload)
7+
return response.content
8+
9+
video_bytes = query({
10+
{{ inputs.asJsonString }}
11+
})

packages/tasks-gen/scripts/generate-snippets-fixtures.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,17 @@ const TEST_CASES: {
140140
},
141141
providers: ["fal-ai", "replicate", "hf-inference"],
142142
},
143+
{
144+
testName: "image-to-video",
145+
task: "image-to-video",
146+
model: {
147+
id: "Wan-AI/Wan2.2-I2V-A14B",
148+
pipeline_tag: "image-to-video",
149+
tags: [],
150+
inference: "",
151+
},
152+
providers: ["fal-ai"],
153+
},
143154
{
144155
testName: "tabular",
145156
task: "tabular-classification",

packages/tasks-gen/snippets-fixtures/image-to-image/python/requests/0.fal-ai.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@
77
"Authorization": f"Bearer {os.environ['HF_TOKEN']}",
88
}
99

10-
with open("cat.png", "rb") as image_file:
11-
image_base_64 = base64.b64encode(image_file.read()).decode('utf-8')
12-
1310
def query(payload):
1411
with open(payload["inputs"], "rb") as f:
1512
img = f.read()

0 commit comments

Comments (0)