From 0d89ef4ee3cf3a318b296f7cdcfa78b63f59d2e7 Mon Sep 17 00:00:00 2001
From: SBrandeis
Date: Thu, 26 Jun 2025 19:35:31 +0200
Subject: [PATCH 1/3] image-to-image fal-ai
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../inference/src/lib/getProviderHelper.ts |   1 +
 packages/inference/src/providers/fal-ai.ts | 121 +++++++++++++++++-
 .../inference/src/tasks/cv/imageToImage.ts |   4 +-
 3 files changed, 122 insertions(+), 4 deletions(-)

diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts
index 619bce2133..d7cc87fb59 100644
--- a/packages/inference/src/lib/getProviderHelper.ts
+++ b/packages/inference/src/lib/getProviderHelper.ts
@@ -64,6 +64,7 @@ export const PROVIDERS: Record {
+		const mimeType = args.inputs instanceof Blob ? args.inputs.type : "image/png";
+		return {
+			...omit(args, ["inputs", "parameters"]),
+			image_url: `data:${mimeType};base64,${base64FromBytes(
+				new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await (args.inputs as Blob).arrayBuffer())
+			)}`,
+			...args.parameters,
+			...args,
+		};
+	}
+
+	override async getResponse(
+		response: FalAiQueueOutput,
+		url?: string,
+		headers?: Record<string, string>
+	): Promise<Blob> {
+		if (!url || !headers) {
+			throw new InferenceClientInputError("URL and headers are required for image-to-image task");
+		}
+		const requestId = response.request_id;
+		if (!requestId) {
+			throw new InferenceClientProviderOutputError(
+				"Received malformed response from Fal.ai text-to-video API: no request ID found in the response"
+			);
+		}
+		let status = response.status;
+
+		const parsedUrl = new URL(url);
+		const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""
+			}`;
+
+		// extracting the provider model id for status and result urls
+		// from the response as it might be different from the mapped model in `url`
+		const modelId = new URL(response.response_url).pathname;
+		const queryParams = parsedUrl.search;
+
+		const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
+		const resultUrl = `${baseUrl}${modelId}${queryParams}`;
+
+		while (status !== "COMPLETED") {
+			await delay(500);
+			const statusResponse = await fetch(statusUrl, { headers });
+
+			if (!statusResponse.ok) {
+				console
+				throw new InferenceClientProviderApiError(
+					"Failed to fetch response status from fal-ai API",
+					{ url: statusUrl, method: "GET" },
+					{
+						requestId: statusResponse.headers.get("x-request-id") ??
"", + status: statusResponse.status, + body: await statusResponse.text(), + } + ); + } + try { + status = (await statusResponse.json()).status; + } catch (error) { + throw new InferenceClientProviderOutputError( + "Failed to parse status response from fal-ai API: received malformed response" + ); + } + } + + const resultResponse = await fetch(resultUrl, { headers }); + let result: unknown; + try { + result = await resultResponse.json(); + } catch (error) { + throw new InferenceClientProviderOutputError( + "Failed to parse result response from fal-ai API: received malformed response" + ); + } + console.log("result", result); + if ( + typeof result === "object" && + !!result && + "images" in result && + Array.isArray(result.images) && + result.images.length > 0 && + typeof result.images[0] === "object" && + !!result.images[0] && + "url" in result.images[0] && + typeof result.images[0].url === "string" && + isUrl(result.images[0].url) + ) { + const urlResponse = await fetch(result.images[0].url); + return await urlResponse.blob(); + } else { + throw new InferenceClientProviderOutputError( + `Received malformed response from Fal.ai image-to-image API: expected { images: Array<{ url: string }> } result format, got instead: ${JSON.stringify( + result + )}` + ); + } + } +} + export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHelper { constructor() { super("https://queue.fal.run"); @@ -165,9 +281,8 @@ export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHe let status = response.status; const parsedUrl = new URL(url); - const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${ - parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : "" - }`; + const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? 
"/fal-ai" : "" + }`; // extracting the provider model id for status and result urls // from the response as it might be different from the mapped model in `url` diff --git a/packages/inference/src/tasks/cv/imageToImage.ts b/packages/inference/src/tasks/cv/imageToImage.ts index 4405dd2cb2..1266cc5452 100644 --- a/packages/inference/src/tasks/cv/imageToImage.ts +++ b/packages/inference/src/tasks/cv/imageToImage.ts @@ -3,6 +3,7 @@ import { resolveProvider } from "../../lib/getInferenceProviderMapping.js"; import { getProviderHelper } from "../../lib/getProviderHelper.js"; import type { BaseArgs, Options } from "../../types.js"; import { innerRequest } from "../../utils/request.js"; +import { makeRequestOptions } from "../../lib/makeRequestOptions.js"; export type ImageToImageArgs = BaseArgs & ImageToImageInput; @@ -18,5 +19,6 @@ export async function imageToImage(args: ImageToImageArgs, options?: Options): P ...options, task: "image-to-image", }); - return providerHelper.getResponse(res); + const { url, info } = await makeRequestOptions(args, providerHelper, { ...options, task: "image-to-image" }); + return providerHelper.getResponse(res, url, info.headers as Record); } From e7a27f37fec8192b0c378ac57fa632573e94df31 Mon Sep 17 00:00:00 2001 From: SBrandeis Date: Thu, 26 Jun 2025 19:42:56 +0200 Subject: [PATCH 2/3] lint --- packages/inference/src/providers/fal-ai.ts | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/packages/inference/src/providers/fal-ai.ts b/packages/inference/src/providers/fal-ai.ts index 52b193a823..5f4a8e86d3 100644 --- a/packages/inference/src/providers/fal-ai.ts +++ b/packages/inference/src/providers/fal-ai.ts @@ -21,9 +21,7 @@ import { isUrl } from "../lib/isUrl.js"; import type { BodyParams, HeaderParams, ModelId, RequestArgs, UrlParams } from "../types.js"; import { delay } from "../utils/delay.js"; import { omit } from "../utils/omit.js"; -import type { - ImageToImageTaskHelper -} from "./providerHelper.js"; +import type { ImageToImageTaskHelper } from "./providerHelper.js"; import { type AutomaticSpeechRecognitionTaskHelper, TaskProviderHelper, @@ -139,7 +137,6 @@ export class FalAIImageToImageTask extends FalAITask implements ImageToImageTask super("https://queue.fal.run"); } - override makeRoute(params: UrlParams): string { if (params.authMethod !== "provider-key") { return `/${params.model}?_subdomain=queue`; @@ -192,7 +189,6 @@ export class FalAIImageToImageTask extends FalAITask implements ImageToImageTask const statusResponse = await fetch(statusUrl, { headers }); if (!statusResponse.ok) { - console throw new InferenceClientProviderApiError( "Failed to fetch response status from fal-ai API", { url: statusUrl, method: "GET" }, From 904b6da0c7e681f9bf22d865e77baaec8d409967 Mon Sep 17 00:00:00 2001 From: SBrandeis Date: Thu, 26 Jun 2025 19:51:07 +0200 Subject: [PATCH 3/3] factor --- packages/inference/src/providers/fal-ai.ts | 192 +++++++++------------ 1 file changed, 78 insertions(+), 114 deletions(-) diff --git a/packages/inference/src/providers/fal-ai.ts b/packages/inference/src/providers/fal-ai.ts index 5f4a8e86d3..8976a4f652 100644 --- a/packages/inference/src/providers/fal-ai.ts +++ b/packages/inference/src/providers/fal-ai.ts @@ -18,7 +18,7 @@ import { base64FromBytes } from "../utils/base64FromBytes.js"; import type { AutomaticSpeechRecognitionOutput } from "@huggingface/tasks"; import { isUrl } from "../lib/isUrl.js"; -import type { BodyParams, HeaderParams, ModelId, RequestArgs, UrlParams } from "../types.js"; +import type { 
BodyParams, HeaderParams, InferenceTask, ModelId, RequestArgs, UrlParams } from "../types.js"; import { delay } from "../utils/delay.js"; import { omit } from "../utils/omit.js"; import type { ImageToImageTaskHelper } from "./providerHelper.js"; @@ -84,6 +84,75 @@ abstract class FalAITask extends TaskProviderHelper { } } +abstract class FalAiQueueTask extends FalAITask { + abstract task: InferenceTask; + + async getResponseFromQueueApi( + response: FalAiQueueOutput, + url?: string, + headers?: Record + ): Promise { + if (!url || !headers) { + throw new InferenceClientInputError(`URL and headers are required for ${this.task} task`); + } + const requestId = response.request_id; + if (!requestId) { + throw new InferenceClientProviderOutputError( + `Received malformed response from Fal.ai ${this.task} API: no request ID found in the response` + ); + } + let status = response.status; + + const parsedUrl = new URL(url); + const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${ + parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : "" + }`; + + // extracting the provider model id for status and result urls + // from the response as it might be different from the mapped model in `url` + const modelId = new URL(response.response_url).pathname; + const queryParams = parsedUrl.search; + + const statusUrl = `${baseUrl}${modelId}/status${queryParams}`; + const resultUrl = `${baseUrl}${modelId}${queryParams}`; + + while (status !== "COMPLETED") { + await delay(500); + const statusResponse = await fetch(statusUrl, { headers }); + + if (!statusResponse.ok) { + throw new InferenceClientProviderApiError( + "Failed to fetch response status from fal-ai API", + { url: statusUrl, method: "GET" }, + { + requestId: statusResponse.headers.get("x-request-id") ?? "", + status: statusResponse.status, + body: await statusResponse.text(), + } + ); + } + try { + status = (await statusResponse.json()).status; + } catch (error) { + throw new InferenceClientProviderOutputError( + "Failed to parse status response from fal-ai API: received malformed response" + ); + } + } + + const resultResponse = await fetch(resultUrl, { headers }); + let result: unknown; + try { + result = await resultResponse.json(); + } catch (error) { + throw new InferenceClientProviderOutputError( + "Failed to parse result response from fal-ai API: received malformed response" + ); + } + return result; + } +} + function buildLoraPath(modelId: ModelId, adapterWeightsPath: string): string { return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`; } @@ -132,9 +201,11 @@ export class FalAITextToImageTask extends FalAITask implements TextToImageTaskHe } } -export class FalAIImageToImageTask extends FalAITask implements ImageToImageTaskHelper { +export class FalAIImageToImageTask extends FalAiQueueTask implements ImageToImageTaskHelper { + task: InferenceTask; constructor() { super("https://queue.fal.run"); + this.task = "image-to-image"; } override makeRoute(params: UrlParams): string { @@ -161,63 +232,8 @@ export class FalAIImageToImageTask extends FalAITask implements ImageToImageTask url?: string, headers?: Record ): Promise { - if (!url || !headers) { - throw new InferenceClientInputError("URL and headers are required for image-to-image task"); - } - const requestId = response.request_id; - if (!requestId) { - throw new InferenceClientProviderOutputError( - "Received malformed response from Fal.ai text-to-video API: no request ID found in the response" - ); - } - let status = response.status; - - const parsedUrl = new URL(url); - 
const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : "" - }`; - - // extracting the provider model id for status and result urls - // from the response as it might be different from the mapped model in `url` - const modelId = new URL(response.response_url).pathname; - const queryParams = parsedUrl.search; - - const statusUrl = `${baseUrl}${modelId}/status${queryParams}`; - const resultUrl = `${baseUrl}${modelId}${queryParams}`; - - while (status !== "COMPLETED") { - await delay(500); - const statusResponse = await fetch(statusUrl, { headers }); + const result = await this.getResponseFromQueueApi(response, url, headers); - if (!statusResponse.ok) { - throw new InferenceClientProviderApiError( - "Failed to fetch response status from fal-ai API", - { url: statusUrl, method: "GET" }, - { - requestId: statusResponse.headers.get("x-request-id") ?? "", - status: statusResponse.status, - body: await statusResponse.text(), - } - ); - } - try { - status = (await statusResponse.json()).status; - } catch (error) { - throw new InferenceClientProviderOutputError( - "Failed to parse status response from fal-ai API: received malformed response" - ); - } - } - - const resultResponse = await fetch(resultUrl, { headers }); - let result: unknown; - try { - result = await resultResponse.json(); - } catch (error) { - throw new InferenceClientProviderOutputError( - "Failed to parse result response from fal-ai API: received malformed response" - ); - } - console.log("result", result); if ( typeof result === "object" && !!result && @@ -242,9 +258,11 @@ export class FalAIImageToImageTask extends FalAITask implements ImageToImageTask } } -export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHelper { +export class FalAITextToVideoTask extends FalAiQueueTask implements TextToVideoTaskHelper { + task: InferenceTask; constructor() { super("https://queue.fal.run"); + this.task = "text-to-video"; } override makeRoute(params: UrlParams): string { if (params.authMethod !== "provider-key") { @@ -265,62 +283,8 @@ export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHe url?: string, headers?: Record ): Promise { - if (!url || !headers) { - throw new InferenceClientInputError("URL and headers are required for text-to-video task"); - } - const requestId = response.request_id; - if (!requestId) { - throw new InferenceClientProviderOutputError( - "Received malformed response from Fal.ai text-to-video API: no request ID found in the response" - ); - } - let status = response.status; - - const parsedUrl = new URL(url); - const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : "" - }`; - - // extracting the provider model id for status and result urls - // from the response as it might be different from the mapped model in `url` - const modelId = new URL(response.response_url).pathname; - const queryParams = parsedUrl.search; - - const statusUrl = `${baseUrl}${modelId}/status${queryParams}`; - const resultUrl = `${baseUrl}${modelId}${queryParams}`; - - while (status !== "COMPLETED") { - await delay(500); - const statusResponse = await fetch(statusUrl, { headers }); - - if (!statusResponse.ok) { - throw new InferenceClientProviderApiError( - "Failed to fetch response status from fal-ai API", - { url: statusUrl, method: "GET" }, - { - requestId: statusResponse.headers.get("x-request-id") ?? 
"", - status: statusResponse.status, - body: await statusResponse.text(), - } - ); - } - try { - status = (await statusResponse.json()).status; - } catch (error) { - throw new InferenceClientProviderOutputError( - "Failed to parse status response from fal-ai API: received malformed response" - ); - } - } + const result = await this.getResponseFromQueueApi(response, url, headers); - const resultResponse = await fetch(resultUrl, { headers }); - let result: unknown; - try { - result = await resultResponse.json(); - } catch (error) { - throw new InferenceClientProviderOutputError( - "Failed to parse result response from fal-ai API: received malformed response" - ); - } if ( typeof result === "object" && !!result &&