diff --git a/packages/inference/src/providers/replicate.ts b/packages/inference/src/providers/replicate.ts index 5a732dcf5e..7785039346 100644 --- a/packages/inference/src/providers/replicate.ts +++ b/packages/inference/src/providers/replicate.ts @@ -19,7 +19,6 @@ import { isUrl } from "../lib/isUrl"; import type { BodyParams, HeaderParams, UrlParams } from "../types"; import { omit } from "../utils/omit"; import { TaskProviderHelper, type TextToImageTaskHelper, type TextToVideoTaskHelper } from "./providerHelper"; - export interface ReplicateOutput { output?: string | string[]; } @@ -63,6 +62,21 @@ abstract class ReplicateTask extends TaskProviderHelper { } export class ReplicateTextToImageTask extends ReplicateTask implements TextToImageTaskHelper { + override preparePayload(params: BodyParams): Record { + return { + input: { + ...omit(params.args, ["inputs", "parameters"]), + ...(params.args.parameters as Record), + prompt: params.args.inputs, + lora_weights: + params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath + ? `https://huggingface.co/${params.mapping.hfModelId}` + : undefined, + }, + version: params.model.includes(":") ? params.model.split(":")[1] : undefined, + }; + } + override async getResponse( res: ReplicateOutput | Blob, url?: string, diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts index 62f41c5cb3..231806db4a 100644 --- a/packages/inference/test/InferenceClient.spec.ts +++ b/packages/inference/test/InferenceClient.spec.ts @@ -1160,6 +1160,17 @@ describe.skip("InferenceClient", () => { expect(res).toBeInstanceOf(Blob); }); + // Runs black-forest-labs/flux-dev-lora under the hood + // with fofr/flux-80s-cyberpunk as the LoRA weights + it("textToImage - all Flux LoRAs", async () => { + const res = await client.textToImage({ + model: "fofr/flux-80s-cyberpunk", + provider: "replicate", + inputs: "style of 80s cyberpunk, a portrait photo", + }); + expect(res).toBeInstanceOf(Blob); + }); + it("textToImage canonical - stabilityai/stable-diffusion-3.5-large-turbo", async () => { const res = await client.textToImage({ model: "stabilityai/stable-diffusion-3.5-large-turbo",