Skip to content
14 changes: 13 additions & 1 deletion packages/inference/src/providers/replicate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import { isUrl } from "../lib/isUrl";
import type { BodyParams, HeaderParams, UrlParams } from "../types";
import { omit } from "../utils/omit";
import { TaskProviderHelper, type TextToImageTaskHelper, type TextToVideoTaskHelper } from "./providerHelper";

export interface ReplicateOutput {
output?: string | string[];
}
Expand Down Expand Up @@ -62,7 +61,20 @@ abstract class ReplicateTask extends TaskProviderHelper {
}
}



export class ReplicateTextToImageTask extends ReplicateTask implements TextToImageTaskHelper {
override preparePayload(params: BodyParams): Record<string, unknown> {
const payload = super.preparePayload(params);

// For Flux LoRAs, use black-forest-labs/flux-dev-lora
if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
payload.input.lora_weights = `https://huggingface.co/${params.mapping.hfModelId}`;
}

return payload;
}

override async getResponse(
res: ReplicateOutput | Blob,
url?: string,
Expand Down
11 changes: 11 additions & 0 deletions packages/inference/test/InferenceClient.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1142,6 +1142,17 @@ describe.skip("InferenceClient", () => {
expect(res).toBeInstanceOf(Blob);
});

// Runs black-forest-labs/flux-dev-lora under the hood
// with fofr/flux-80s-cyberpunk as the LoRA weights
it("textToImage - all Flux LoRAs", async () => {
const res = await client.textToImage({
model: "fofr/flux-80s-cyberpunk",
provider: "replicate",
inputs: "style of 80s cyberpunk, a portrait photo",
});
expect(res).toBeInstanceOf(Blob);
});

it("textToImage canonical - stabilityai/stable-diffusion-3.5-large-turbo", async () => {
const res = await client.textToImage({
model: "stabilityai/stable-diffusion-3.5-large-turbo",
Expand Down
Loading