Skip to content

Commit b2721aa

Browse files
committed
image-to-image fal-ai
1 parent 86ec6ef commit b2721aa

File tree

3 files changed

+122
-4
lines changed

3 files changed

+122
-4
lines changed

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
6464
"text-to-image": new FalAI.FalAITextToImageTask(),
6565
"text-to-speech": new FalAI.FalAITextToSpeechTask(),
6666
"text-to-video": new FalAI.FalAITextToVideoTask(),
67+
"image-to-image": new FalAI.FalAIImageToImageTask(),
6768
"automatic-speech-recognition": new FalAI.FalAIAutomaticSpeechRecognitionTask(),
6869
},
6970
"featherless-ai": {

packages/inference/src/providers/fal-ai.ts

Lines changed: 118 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ import { isUrl } from "../lib/isUrl.js";
2121
import type { BodyParams, HeaderParams, ModelId, RequestArgs, UrlParams } from "../types.js";
2222
import { delay } from "../utils/delay.js";
2323
import { omit } from "../utils/omit.js";
24+
import type {
25+
ImageToImageTaskHelper
26+
} from "./providerHelper.js";
2427
import {
2528
type AutomaticSpeechRecognitionTaskHelper,
2629
TaskProviderHelper,
@@ -34,6 +37,7 @@ import {
3437
InferenceClientProviderApiError,
3538
InferenceClientProviderOutputError,
3639
} from "../errors.js";
40+
import type { ImageToImageArgs } from "../tasks/index.js";
3741

3842
export interface FalAiQueueOutput {
3943
request_id: string;
@@ -130,6 +134,118 @@ export class FalAITextToImageTask extends FalAITask implements TextToImageTaskHe
130134
}
131135
}
132136

137+
/**
 * Fal.ai provider helper for the "image-to-image" task.
 *
 * The request is submitted to the Fal.ai queue endpoint. `getResponse` then polls the
 * queue status URL until the job reports `COMPLETED`, fetches the result, and downloads
 * the first image referenced in it.
 */
export class FalAIImageToImageTask extends FalAITask implements ImageToImageTaskHelper {
	constructor() {
		super("https://queue.fal.run");
	}

	/**
	 * Builds the request path for the target model.
	 *
	 * @param params - URL parameters; only `authMethod` and `model` are read here.
	 * @returns `/<model>` when authenticating with a provider key, otherwise
	 * `/<model>?_subdomain=queue`.
	 */
	override makeRoute(params: UrlParams): string {
		if (params.authMethod !== "provider-key") {
			// NOTE(review): presumably tells the Hugging Face router to target the queue subdomain — confirm.
			return `/${params.model}?_subdomain=queue`;
		}
		return `/${params.model}`;
	}

	/**
	 * Converts the input image into a base64 data URL and exposes it as `image_url`.
	 *
	 * @param args - Task arguments; `inputs` is read as either an `ArrayBuffer` or a `Blob`.
	 * @returns The request arguments with `image_url` added and `parameters` flattened
	 * into the top level.
	 */
	async preparePayloadAsync(args: ImageToImageArgs): Promise<RequestArgs> {
		// Non-Blob inputs carry no MIME type, so "image/png" is used as the default.
		const mimeType = args.inputs instanceof Blob ? args.inputs.type : "image/png";
		const imageBytes = new Uint8Array(
			args.inputs instanceof ArrayBuffer ? args.inputs : await (args.inputs as Blob).arrayBuffer()
		);
		return {
			...omit(args, ["inputs", "parameters"]),
			image_url: `data:${mimeType};base64,${base64FromBytes(imageBytes)}`,
			...args.parameters,
			// NOTE(review): this spread re-adds `inputs` and `parameters` that `omit` removed above;
			// presumably kept so the object satisfies `RequestArgs` — confirm before removing.
			...args,
		};
	}

	/**
	 * Waits for the queued job to complete and returns the generated image.
	 *
	 * @param response - Queue submission response; `request_id`, `status` and `response_url` are read.
	 * @param url - URL the initial request was sent to; used to derive the status and result URLs.
	 * @param headers - Headers sent with the status and result requests.
	 * @returns The first generated image as a `Blob`.
	 * @throws InferenceClientInputError when `url` or `headers` is missing.
	 * @throws InferenceClientProviderApiError when a status request returns a non-OK response.
	 * @throws InferenceClientProviderOutputError when the request ID is missing, a response body
	 * is not valid JSON, or the result does not match `{ images: Array<{ url: string }> }`.
	 */
	override async getResponse(
		response: FalAiQueueOutput,
		url?: string,
		headers?: Record<string, string>
	): Promise<Blob> {
		if (!url || !headers) {
			throw new InferenceClientInputError("URL and headers are required for image-to-image task");
		}
		const requestId = response.request_id;
		if (!requestId) {
			throw new InferenceClientProviderOutputError(
				"Received malformed response from Fal.ai image-to-image API: no request ID found in the response"
			);
		}
		let status = response.status;

		const parsedUrl = new URL(url);
		// Requests going through the Hugging Face router need the "/fal-ai" path prefix.
		const routerPrefix = parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : "";
		const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${routerPrefix}`;

		// extracting the provider model id for status and result urls
		// from the response as it might be different from the mapped model in `url`
		const modelId = new URL(response.response_url).pathname;
		const queryParams = parsedUrl.search;

		const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
		const resultUrl = `${baseUrl}${modelId}${queryParams}`;

		// Poll every 500 ms until the queue reports completion.
		while (status !== "COMPLETED") {
			await delay(500);
			const statusResponse = await fetch(statusUrl, { headers });

			if (!statusResponse.ok) {
				throw new InferenceClientProviderApiError(
					"Failed to fetch response status from fal-ai API",
					{ url: statusUrl, method: "GET" },
					{
						requestId: statusResponse.headers.get("x-request-id") ?? "",
						status: statusResponse.status,
						body: await statusResponse.text(),
					}
				);
			}
			try {
				status = (await statusResponse.json()).status;
			} catch (error) {
				throw new InferenceClientProviderOutputError(
					"Failed to parse status response from fal-ai API: received malformed response"
				);
			}
		}

		const resultResponse = await fetch(resultUrl, { headers });
		let result: unknown;
		try {
			result = await resultResponse.json();
		} catch (error) {
			throw new InferenceClientProviderOutputError(
				"Failed to parse result response from fal-ai API: received malformed response"
			);
		}
		// Validate the result shape step by step, since `result` is untyped JSON.
		if (
			typeof result === "object" &&
			!!result &&
			"images" in result &&
			Array.isArray(result.images) &&
			result.images.length > 0 &&
			typeof result.images[0] === "object" &&
			!!result.images[0] &&
			"url" in result.images[0] &&
			typeof result.images[0].url === "string" &&
			isUrl(result.images[0].url)
		) {
			// Only the first image of the result is downloaded and returned.
			const urlResponse = await fetch(result.images[0].url);
			return await urlResponse.blob();
		} else {
			throw new InferenceClientProviderOutputError(
				`Received malformed response from Fal.ai image-to-image API: expected { images: Array<{ url: string }> } result format, got instead: ${JSON.stringify(
					result
				)}`
			);
		}
	}
}
248+
133249
export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHelper {
134250
constructor() {
135251
super("https://queue.fal.run");
@@ -165,9 +281,8 @@ export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHe
165281
let status = response.status;
166282

167283
const parsedUrl = new URL(url);
168-
const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${
169-
parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""
170-
}`;
284+
const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""
285+
}`;
171286

172287
// extracting the provider model id for status and result urls
173288
// from the response as it might be different from the mapped model in `url`

packages/inference/src/tasks/cv/imageToImage.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { resolveProvider } from "../../lib/getInferenceProviderMapping.js";
33
import { getProviderHelper } from "../../lib/getProviderHelper.js";
44
import type { BaseArgs, Options } from "../../types.js";
55
import { innerRequest } from "../../utils/request.js";
6+
import { makeRequestOptions } from "../../lib/makeRequestOptions.js";
67

78
export type ImageToImageArgs = BaseArgs & ImageToImageInput;
89

@@ -18,5 +19,6 @@ export async function imageToImage(args: ImageToImageArgs, options?: Options): P
1819
...options,
1920
task: "image-to-image",
2021
});
21-
return providerHelper.getResponse(res);
22+
const { url, info } = await makeRequestOptions(args, providerHelper, { ...options, task: "image-to-image" });
23+
return providerHelper.getResponse(res, url, info.headers as Record<string, string>);
2224
}

0 commit comments

Comments
 (0)