Skip to content
16 changes: 15 additions & 1 deletion packages/inference/src/providers/replicate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import { isUrl } from "../lib/isUrl";
import type { BodyParams, HeaderParams, UrlParams } from "../types";
import { omit } from "../utils/omit";
import { TaskProviderHelper, type TextToImageTaskHelper, type TextToVideoTaskHelper } from "./providerHelper";

export interface ReplicateOutput {
output?: string | string[];
}
Expand Down Expand Up @@ -63,6 +62,21 @@ abstract class ReplicateTask extends TaskProviderHelper {
}

export class ReplicateTextToImageTask extends ReplicateTask implements TextToImageTaskHelper {
override preparePayload(params: BodyParams): Record<string, unknown> {
return {
input: {
...omit(params.args, ["inputs", "parameters"]),
...(params.args.parameters as Record<string, unknown>),
prompt: params.args.inputs,
lora_weights:
params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath
? `https://huggingface.co/${params.mapping.hfModelId}`
: undefined,
},
version: params.model.includes(":") ? params.model.split(":")[1] : undefined,
};
}

override async getResponse(
res: ReplicateOutput | Blob,
url?: string,
Expand Down
11 changes: 11 additions & 0 deletions packages/inference/test/InferenceClient.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1160,6 +1160,17 @@ describe.skip("InferenceClient", () => {
expect(res).toBeInstanceOf(Blob);
});

// Runs black-forest-labs/flux-dev-lora under the hood
// with fofr/flux-80s-cyberpunk as the LoRA weights
it("textToImage - all Flux LoRAs", async () => {
const res = await client.textToImage({
model: "fofr/flux-80s-cyberpunk",
provider: "replicate",
inputs: "style of 80s cyberpunk, a portrait photo",
});
expect(res).toBeInstanceOf(Blob);
});

it("textToImage canonical - stabilityai/stable-diffusion-3.5-large-turbo", async () => {
const res = await client.textToImage({
model: "stabilityai/stable-diffusion-3.5-large-turbo",
Expand Down