Skip to content

Commit 963dfe1

Browse files
committed
Add support for image-to-video task type for Replicate
1 parent 4e3da78 commit 963dfe1

File tree

3 files changed

+60
-0
lines changed

3 files changed

+60
-0
lines changed

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
140140
"text-to-speech": new Replicate.ReplicateTextToSpeechTask(),
141141
"text-to-video": new Replicate.ReplicateTextToVideoTask(),
142142
"image-to-image": new Replicate.ReplicateImageToImageTask(),
143+
"image-to-video": new Replicate.ReplicateImageToVideoTask(),
143144
},
144145
sambanova: {
145146
conversational: new Sambanova.SambanovaConversationalTask(),

packages/inference/src/providers/replicate.ts

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,50 @@ export class ReplicateImageToImageTask extends ReplicateTask implements ImageToI
213213
throw new InferenceClientProviderOutputError("Received malformed response from Replicate image-to-image API");
214214
}
215215
}
216+
217+
export class ReplicateImageToVideoTask extends ReplicateTask {
218+
// Synchronous: expects base64 string in args.inputs
219+
override preparePayload(params: BodyParams): Record<string, unknown> {
220+
const { args } = params;
221+
const { inputs, parameters } = args;
222+
return {
223+
input: {
224+
...omit(args, ["inputs", "parameters"]),
225+
...(parameters as Record<string, unknown>),
226+
inputs,
227+
},
228+
version: params.model.includes(":") ? params.model.split(":")[1] : undefined,
229+
};
230+
}
231+
232+
// Asynchronous: handles Blob to base64 conversion
233+
async preparePayloadAsync(args: { inputs: Blob } & Record<string, unknown>): Promise<RequestArgs> {
234+
const { inputs, ...restArgs } = args;
235+
const bytes = new Uint8Array(await inputs.arrayBuffer());
236+
const base64 = base64FromBytes(bytes);
237+
const imageInput = `data:${inputs.type || "image/png"};base64,${base64}`;
238+
return {
239+
...restArgs,
240+
inputs: imageInput,
241+
};
242+
}
243+
244+
// Handle the response from Replicate
245+
override async getResponse(response: ReplicateOutput): Promise<Blob> {
246+
if (
247+
typeof response === "object" &&
248+
!!response &&
249+
"output" in response
250+
) {
251+
if (Array.isArray(response.output) && response.output.length > 0 && typeof response.output[0] === "string") {
252+
const urlResponse = await fetch(response.output[0]);
253+
return await urlResponse.blob();
254+
}
255+
if (typeof response.output === "string" && isUrl(response.output)) {
256+
const urlResponse = await fetch(response.output);
257+
return await urlResponse.blob();
258+
}
259+
}
260+
throw new InferenceClientProviderOutputError("Received malformed response from Replicate image-to-video API");
261+
}
262+
}

packages/inference/test/InferenceClient.spec.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,6 +1289,18 @@ describe.skip("InferenceClient", () => {
12891289
});
12901290
expect(res).toBeInstanceOf(Blob);
12911291
});
1292+
1293+
it("imageToVideo - MeiGen-AI/MeiGen-MultiTalk", async () => {
1294+
const res = await client.imageToVideo({
1295+
model: "MeiGen-AI/MeiGen-MultiTalk",
1296+
provider: "replicate",
1297+
inputs: new Blob([readTestFile("bird_canny.png")], { type: "image/png" }),
1298+
parameters: {
1299+
prompt: "A bird flying in the sky",
1300+
},
1301+
});
1302+
expect(res).toBeInstanceOf(Blob);
1303+
});
12921304
},
12931305
TIMEOUT
12941306
);

0 commit comments

Comments
 (0)