
Commit e14dee8

[Inference] Add image-to-image support for fal-ai (#1563)
1 parent 3b23dfa commit e14dee8

3 files changed, +136 -58 lines changed

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
 		"text-to-image": new FalAI.FalAITextToImageTask(),
 		"text-to-speech": new FalAI.FalAITextToSpeechTask(),
 		"text-to-video": new FalAI.FalAITextToVideoTask(),
+		"image-to-image": new FalAI.FalAIImageToImageTask(),
 		"automatic-speech-recognition": new FalAI.FalAIAutomaticSpeechRecognitionTask(),
 	},
 	"featherless-ai": {

packages/inference/src/providers/fal-ai.ts

Lines changed: 132 additions & 57 deletions
@@ -18,9 +18,10 @@ import { base64FromBytes } from "../utils/base64FromBytes.js";
 
 import type { AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
 import { isUrl } from "../lib/isUrl.js";
-import type { BodyParams, HeaderParams, ModelId, RequestArgs, UrlParams } from "../types.js";
+import type { BodyParams, HeaderParams, InferenceTask, ModelId, RequestArgs, UrlParams } from "../types.js";
 import { delay } from "../utils/delay.js";
 import { omit } from "../utils/omit.js";
+import type { ImageToImageTaskHelper } from "./providerHelper.js";
 import {
 	type AutomaticSpeechRecognitionTaskHelper,
 	TaskProviderHelper,
@@ -34,6 +35,7 @@ import {
 	InferenceClientProviderApiError,
 	InferenceClientProviderOutputError,
 } from "../errors.js";
+import type { ImageToImageArgs } from "../tasks/index.js";
 
 export interface FalAiQueueOutput {
 	request_id: string;
@@ -82,6 +84,75 @@ abstract class FalAITask extends TaskProviderHelper {
 	}
 }
 
+abstract class FalAiQueueTask extends FalAITask {
+	abstract task: InferenceTask;
+
+	async getResponseFromQueueApi(
+		response: FalAiQueueOutput,
+		url?: string,
+		headers?: Record<string, string>
+	): Promise<unknown> {
+		if (!url || !headers) {
+			throw new InferenceClientInputError(`URL and headers are required for ${this.task} task`);
+		}
+		const requestId = response.request_id;
+		if (!requestId) {
+			throw new InferenceClientProviderOutputError(
+				`Received malformed response from Fal.ai ${this.task} API: no request ID found in the response`
+			);
+		}
+		let status = response.status;
+
+		const parsedUrl = new URL(url);
+		const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${
+			parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""
+		}`;
+
+		// extracting the provider model id for status and result urls
+		// from the response as it might be different from the mapped model in `url`
+		const modelId = new URL(response.response_url).pathname;
+		const queryParams = parsedUrl.search;
+
+		const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
+		const resultUrl = `${baseUrl}${modelId}${queryParams}`;
+
+		while (status !== "COMPLETED") {
+			await delay(500);
+			const statusResponse = await fetch(statusUrl, { headers });
+
+			if (!statusResponse.ok) {
+				throw new InferenceClientProviderApiError(
+					"Failed to fetch response status from fal-ai API",
+					{ url: statusUrl, method: "GET" },
+					{
+						requestId: statusResponse.headers.get("x-request-id") ?? "",
+						status: statusResponse.status,
+						body: await statusResponse.text(),
+					}
+				);
+			}
+			try {
+				status = (await statusResponse.json()).status;
+			} catch (error) {
+				throw new InferenceClientProviderOutputError(
+					"Failed to parse status response from fal-ai API: received malformed response"
+				);
+			}
+		}
+
+		const resultResponse = await fetch(resultUrl, { headers });
+		let result: unknown;
+		try {
+			result = await resultResponse.json();
+		} catch (error) {
+			throw new InferenceClientProviderOutputError(
+				"Failed to parse result response from fal-ai API: received malformed response"
+			);
+		}
+		return result;
+	}
+}
+
 function buildLoraPath(modelId: ModelId, adapterWeightsPath: string): string {
 	return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`;
 }
@@ -130,21 +201,29 @@ export class FalAITextToImageTask extends FalAITask implements TextToImageTaskHelper {
 	}
 }
 
-export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHelper {
+export class FalAIImageToImageTask extends FalAiQueueTask implements ImageToImageTaskHelper {
+	task: InferenceTask;
 	constructor() {
 		super("https://queue.fal.run");
+		this.task = "image-to-image";
 	}
+
 	override makeRoute(params: UrlParams): string {
 		if (params.authMethod !== "provider-key") {
 			return `/${params.model}?_subdomain=queue`;
 		}
 		return `/${params.model}`;
 	}
-	override preparePayload(params: BodyParams): Record<string, unknown> {
+
+	async preparePayloadAsync(args: ImageToImageArgs): Promise<RequestArgs> {
+		const mimeType = args.inputs instanceof Blob ? args.inputs.type : "image/png";
 		return {
-			...omit(params.args, ["inputs", "parameters"]),
-			...(params.args.parameters as Record<string, unknown>),
-			prompt: params.args.inputs,
+			...omit(args, ["inputs", "parameters"]),
+			image_url: `data:${mimeType};base64,${base64FromBytes(
+				new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await (args.inputs as Blob).arrayBuffer())
+			)}`,
+			...args.parameters,
+			...args,
 		};
 	}
 
@@ -153,63 +232,59 @@ export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHelper {
 		url?: string,
 		headers?: Record<string, string>
 	): Promise<Blob> {
-		if (!url || !headers) {
-			throw new InferenceClientInputError("URL and headers are required for text-to-video task");
-		}
-		const requestId = response.request_id;
-		if (!requestId) {
+		const result = await this.getResponseFromQueueApi(response, url, headers);
+
+		if (
+			typeof result === "object" &&
+			!!result &&
+			"images" in result &&
+			Array.isArray(result.images) &&
+			result.images.length > 0 &&
+			typeof result.images[0] === "object" &&
+			!!result.images[0] &&
+			"url" in result.images[0] &&
+			typeof result.images[0].url === "string" &&
+			isUrl(result.images[0].url)
+		) {
+			const urlResponse = await fetch(result.images[0].url);
+			return await urlResponse.blob();
+		} else {
 			throw new InferenceClientProviderOutputError(
-				"Received malformed response from Fal.ai text-to-video API: no request ID found in the response"
+				`Received malformed response from Fal.ai image-to-image API: expected { images: Array<{ url: string }> } result format, got instead: ${JSON.stringify(
+					result
+				)}`
 			);
 		}
-		let status = response.status;
-
-		const parsedUrl = new URL(url);
-		const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${
-			parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""
-		}`;
-
-		// extracting the provider model id for status and result urls
-		// from the response as it might be different from the mapped model in `url`
-		const modelId = new URL(response.response_url).pathname;
-		const queryParams = parsedUrl.search;
-
-		const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
-		const resultUrl = `${baseUrl}${modelId}${queryParams}`;
-
-		while (status !== "COMPLETED") {
-			await delay(500);
-			const statusResponse = await fetch(statusUrl, { headers });
+	}
+}
 
-			if (!statusResponse.ok) {
-				throw new InferenceClientProviderApiError(
-					"Failed to fetch response status from fal-ai API",
-					{ url: statusUrl, method: "GET" },
-					{
-						requestId: statusResponse.headers.get("x-request-id") ?? "",
-						status: statusResponse.status,
-						body: await statusResponse.text(),
-					}
-				);
-			}
-			try {
-				status = (await statusResponse.json()).status;
-			} catch (error) {
-				throw new InferenceClientProviderOutputError(
-					"Failed to parse status response from fal-ai API: received malformed response"
-				);
-			}
+export class FalAITextToVideoTask extends FalAiQueueTask implements TextToVideoTaskHelper {
+	task: InferenceTask;
+	constructor() {
+		super("https://queue.fal.run");
+		this.task = "text-to-video";
+	}
+	override makeRoute(params: UrlParams): string {
+		if (params.authMethod !== "provider-key") {
+			return `/${params.model}?_subdomain=queue`;
 		}
+		return `/${params.model}`;
+	}
+	override preparePayload(params: BodyParams): Record<string, unknown> {
+		return {
+			...omit(params.args, ["inputs", "parameters"]),
+			...(params.args.parameters as Record<string, unknown>),
+			prompt: params.args.inputs,
+		};
+	}
+
+	override async getResponse(
+		response: FalAiQueueOutput,
+		url?: string,
+		headers?: Record<string, string>
+	): Promise<Blob> {
+		const result = await this.getResponseFromQueueApi(response, url, headers);
 
-		const resultResponse = await fetch(resultUrl, { headers });
-		let result: unknown;
-		try {
-			result = await resultResponse.json();
-		} catch (error) {
-			throw new InferenceClientProviderOutputError(
-				"Failed to parse result response from fal-ai API: received malformed response"
-			);
-		}
 		if (
 			typeof result === "object" &&
 			!!result &&
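
For reference, a minimal standalone sketch of how the input image ends up in the request body: `preparePayloadAsync` above inlines it as a base64 `image_url` data URI. The `prompt` field below is a hypothetical parameter for illustration only, and the library itself uses its `base64FromBytes` helper rather than `btoa`.

// Sketch: build a fal.ai image-to-image request body from an input Blob,
// mirroring preparePayloadAsync above. `prompt` is a hypothetical parameter.
async function buildFalAiImageToImageBody(image: Blob, prompt: string): Promise<Record<string, unknown>> {
	const mimeType = image.type || "image/png";
	const bytes = new Uint8Array(await image.arrayBuffer());
	// Encode the raw bytes as base64 (the library uses base64FromBytes instead of btoa).
	let binary = "";
	for (const byte of bytes) {
		binary += String.fromCharCode(byte);
	}
	return {
		prompt,
		image_url: `data:${mimeType};base64,${btoa(binary)}`,
	};
}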

packages/inference/src/tasks/cv/imageToImage.ts

Lines changed: 3 additions & 1 deletion
@@ -3,6 +3,7 @@ import { resolveProvider } from "../../lib/getInferenceProviderMapping.js";
 import { getProviderHelper } from "../../lib/getProviderHelper.js";
 import type { BaseArgs, Options } from "../../types.js";
 import { innerRequest } from "../../utils/request.js";
+import { makeRequestOptions } from "../../lib/makeRequestOptions.js";
 
 export type ImageToImageArgs = BaseArgs & ImageToImageInput;
 
@@ -18,5 +19,6 @@ export async function imageToImage(args: ImageToImageArgs, options?: Options): P
 		...options,
 		task: "image-to-image",
 	});
-	return providerHelper.getResponse(res);
+	const { url, info } = await makeRequestOptions(args, providerHelper, { ...options, task: "image-to-image" });
+	return providerHelper.getResponse(res, url, info.headers as Record<string, string>);
 }
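
With these changes in place, an end-to-end call might look like the sketch below. The model id, token, and prompt are placeholders, not taken from this commit; the parameters a given model accepts will vary.

// Sketch: calling the new fal-ai image-to-image path via @huggingface/inference.
import { imageToImage } from "@huggingface/inference";

const inputImage = await (await fetch("https://example.com/cat.png")).blob();

const edited = await imageToImage({
	accessToken: "hf_...", // hypothetical token; requests are routed through router.huggingface.co
	provider: "fal-ai",
	model: "some-org/some-image-editing-model", // hypothetical model id
	inputs: inputImage,
	parameters: { prompt: "add sunglasses to the cat" }, // hypothetical parameter
});
// `edited` is a Blob fetched from the result URL returned by the fal.ai queue API.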

0 commit comments
