Skip to content

Commit 6755e81

Browse files
committed
add image-to-image for replicate
1 parent 86ec6ef commit 6755e81

File tree

3 files changed

+79
-1
lines changed

3 files changed

+79
-1
lines changed

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
138138
"text-to-image": new Replicate.ReplicateTextToImageTask(),
139139
"text-to-speech": new Replicate.ReplicateTextToSpeechTask(),
140140
"text-to-video": new Replicate.ReplicateTextToVideoTask(),
141+
"image-to-image": new Replicate.ReplicateImageToImageTask(),
141142
},
142143
sambanova: {
143144
conversational: new Sambanova.SambanovaConversationalTask(),

packages/inference/src/providers/replicate.ts

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ import { InferenceClientProviderOutputError } from "../errors.js";
1818
import { isUrl } from "../lib/isUrl.js";
1919
import type { BodyParams, HeaderParams, UrlParams } from "../types.js";
2020
import { omit } from "../utils/omit.js";
21-
import { TaskProviderHelper, type TextToImageTaskHelper, type TextToVideoTaskHelper } from "./providerHelper.js";
21+
import { TaskProviderHelper, type ImageToImageTaskHelper, type TextToImageTaskHelper, type TextToVideoTaskHelper } from "./providerHelper.js";
22+
import type { ImageToImageArgs } from "../tasks/cv/imageToImage.js";
23+
import { base64FromBytes } from "../utils/base64FromBytes.js";
2224
export interface ReplicateOutput {
2325
output?: string | string[];
2426
}
@@ -152,3 +154,65 @@ export class ReplicateTextToVideoTask extends ReplicateTask implements TextToVid
152154
throw new InferenceClientProviderOutputError("Received malformed response from Replicate text-to-video API");
153155
}
154156
}
157+
158+
/**
 * Replicate provider helper for the image-to-image task.
 *
 * - `preparePayloadAsync` encodes the input Blob as a base64 data URL.
 * - `preparePayload` builds the `{ input, version }` request body.
 * - `getResponse` downloads the image referenced by the provider response.
 */
export class ReplicateImageToImageTask extends ReplicateTask implements ImageToImageTaskHelper {
	/**
	 * Builds the request body sent to Replicate.
	 *
	 * `input` is assembled, in order, from: every top-level arg except `inputs`
	 * and `parameters`, then the entries of `parameters`, then `prompt`
	 * (defaulting to an empty string when absent) and finally `image`.
	 *
	 * @param params - Body parameters; `args.inputs` is forwarded as-is under `input.image`.
	 * @returns An object with `input` and `version` (the part of the model id after
	 * the first `:`, or `undefined` when the model id contains no `:`).
	 */
	override preparePayload(params: BodyParams): Record<string, unknown> {
		const parameters = params.args.parameters as Record<string, unknown> | undefined;

		return {
			input: {
				...omit(params.args, ["inputs", "parameters"]),
				...parameters,
				prompt: parameters?.prompt ?? "",
				// NOTE(review): presumably already the data URL produced by preparePayloadAsync
				// by the time this runs — confirm against the call order in the task entry point.
				image: params.args.inputs,
			},
			version: params.model.includes(":") ? params.model.split(":")[1] : undefined,
		};
	}

	/**
	 * Converts the input image Blob into a base64 data URL.
	 *
	 * @param args - Image-to-image arguments; `inputs` must expose `arrayBuffer()` and `type`.
	 * @returns The remaining args unchanged, with `inputs` replaced by the data URL and
	 * the same data URL also stored under `parameters.image`.
	 */
	async preparePayloadAsync(args: ImageToImageArgs): Promise<import("../types.js").RequestArgs> {
		const { inputs, parameters, ...restArgs } = args;

		const bytes = new Uint8Array(await inputs.arrayBuffer());
		// Fall back to JPEG when the Blob carries no MIME type.
		const mimeType = inputs.type || "image/jpeg";
		const imageInput = `data:${mimeType};base64,${base64FromBytes(bytes)}`;

		return {
			...restArgs,
			inputs: imageInput,
			parameters: {
				...parameters,
				image: imageInput,
			},
		};
	}

	/**
	 * Downloads the generated image referenced by the provider response.
	 *
	 * Accepts `output` either as a non-empty array whose first element is a string,
	 * or as a single string that passes `isUrl`.
	 *
	 * @param response - Response body returned by Replicate.
	 * @returns The downloaded image as a Blob.
	 * @throws InferenceClientProviderOutputError when `output` has neither accepted
	 * shape, or when the download answers with a non-success HTTP status.
	 */
	override async getResponse(response: ReplicateOutput): Promise<Blob> {
		const output =
			typeof response === "object" && !!response && "output" in response ? response.output : undefined;

		let url: string | undefined;
		if (Array.isArray(output)) {
			// Only the first element is used; an empty array leaves `url` undefined.
			const first: unknown = output[0];
			if (typeof first === "string") {
				url = first;
			}
		} else if (typeof output === "string" && isUrl(output)) {
			url = output;
		}

		if (url === undefined) {
			throw new InferenceClientProviderOutputError("Received malformed response from Replicate image-to-image API");
		}

		const urlResponse = await fetch(url);
		if (!urlResponse.ok) {
			// Without this check an HTTP error body would be returned to the caller as if it were the image.
			throw new InferenceClientProviderOutputError(
				`Failed to download image from Replicate image-to-image API output (HTTP ${urlResponse.status})`
			);
		}
		return await urlResponse.blob();
	}
}

packages/inference/test/InferenceClient.spec.ts

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import {
77
chatCompletion,
88
chatCompletionStream,
99
HfInference,
10+
imageToImage,
1011
InferenceClient,
1112
textGeneration,
1213
textToImage,
@@ -1277,6 +1278,18 @@ describe.skip("InferenceClient", () => {
12771278

12781279
expect(res).toBeInstanceOf(Blob);
12791280
});
1281+
1282+
it("imageToImage - FLUX Kontext Dev", async () => {
1283+
const res = await client.imageToImage({
1284+
model: "black-forest-labs/flux-kontext-dev",
1285+
provider: "replicate",
1286+
inputs: new Blob([readTestFile("stormtrooper_depth.png")], { type: "image/png" }),
1287+
parameters: {
1288+
prompt: "Change the stormtrooper armor to golden color while keeping the same pose and helmet design",
1289+
},
1290+
});
1291+
expect(res).toBeInstanceOf(Blob);
1292+
});
12801293
},
12811294
TIMEOUT
12821295
);

0 commit comments

Comments
 (0)