diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts index 619bce2133..d7cc87fb59 100644 --- a/packages/inference/src/lib/getProviderHelper.ts +++ b/packages/inference/src/lib/getProviderHelper.ts @@ -64,6 +64,7 @@ export const PROVIDERS: Record + ): Promise { + if (!url || !headers) { + throw new InferenceClientInputError(`URL and headers are required for ${this.task} task`); + } + const requestId = response.request_id; + if (!requestId) { + throw new InferenceClientProviderOutputError( + `Received malformed response from Fal.ai ${this.task} API: no request ID found in the response` + ); + } + let status = response.status; + + const parsedUrl = new URL(url); + const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${ + parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : "" + }`; + + // extracting the provider model id for status and result urls + // from the response as it might be different from the mapped model in `url` + const modelId = new URL(response.response_url).pathname; + const queryParams = parsedUrl.search; + + const statusUrl = `${baseUrl}${modelId}/status${queryParams}`; + const resultUrl = `${baseUrl}${modelId}${queryParams}`; + + while (status !== "COMPLETED") { + await delay(500); + const statusResponse = await fetch(statusUrl, { headers }); + + if (!statusResponse.ok) { + throw new InferenceClientProviderApiError( + "Failed to fetch response status from fal-ai API", + { url: statusUrl, method: "GET" }, + { + requestId: statusResponse.headers.get("x-request-id") ?? "", + status: statusResponse.status, + body: await statusResponse.text(), + } + ); + } + try { + status = (await statusResponse.json()).status; + } catch (error) { + throw new InferenceClientProviderOutputError( + "Failed to parse status response from fal-ai API: received malformed response" + ); + } + } + + const resultResponse = await fetch(resultUrl, { headers }); + let result: unknown; + try { + result = await resultResponse.json(); + } catch (error) { + throw new InferenceClientProviderOutputError( + "Failed to parse result response from fal-ai API: received malformed response" + ); + } + return result; + } +} + function buildLoraPath(modelId: ModelId, adapterWeightsPath: string): string { return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`; } @@ -130,21 +201,29 @@ export class FalAITextToImageTask extends FalAITask implements TextToImageTaskHe } } -export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHelper { +export class FalAIImageToImageTask extends FalAiQueueTask implements ImageToImageTaskHelper { + task: InferenceTask; constructor() { super("https://queue.fal.run"); + this.task = "image-to-image"; } + override makeRoute(params: UrlParams): string { if (params.authMethod !== "provider-key") { return `/${params.model}?_subdomain=queue`; } return `/${params.model}`; } - override preparePayload(params: BodyParams): Record { + + async preparePayloadAsync(args: ImageToImageArgs): Promise { + const mimeType = args.inputs instanceof Blob ? args.inputs.type : "image/png"; return { - ...omit(params.args, ["inputs", "parameters"]), - ...(params.args.parameters as Record), - prompt: params.args.inputs, + ...omit(args, ["inputs", "parameters"]), + image_url: `data:${mimeType};base64,${base64FromBytes( + new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await (args.inputs as Blob).arrayBuffer()) + )}`, + ...args.parameters, + ...args, }; } @@ -153,63 +232,59 @@ export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHe url?: string, headers?: Record ): Promise { - if (!url || !headers) { - throw new InferenceClientInputError("URL and headers are required for text-to-video task"); - } - const requestId = response.request_id; - if (!requestId) { + const result = await this.getResponseFromQueueApi(response, url, headers); + + if ( + typeof result === "object" && + !!result && + "images" in result && + Array.isArray(result.images) && + result.images.length > 0 && + typeof result.images[0] === "object" && + !!result.images[0] && + "url" in result.images[0] && + typeof result.images[0].url === "string" && + isUrl(result.images[0].url) + ) { + const urlResponse = await fetch(result.images[0].url); + return await urlResponse.blob(); + } else { throw new InferenceClientProviderOutputError( - "Received malformed response from Fal.ai text-to-video API: no request ID found in the response" + `Received malformed response from Fal.ai image-to-image API: expected { images: Array<{ url: string }> } result format, got instead: ${JSON.stringify( + result + )}` ); } - let status = response.status; - - const parsedUrl = new URL(url); - const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${ - parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : "" - }`; - - // extracting the provider model id for status and result urls - // from the response as it might be different from the mapped model in `url` - const modelId = new URL(response.response_url).pathname; - const queryParams = parsedUrl.search; - - const statusUrl = `${baseUrl}${modelId}/status${queryParams}`; - const resultUrl = `${baseUrl}${modelId}${queryParams}`; - - while (status !== "COMPLETED") { - await delay(500); - const statusResponse = await fetch(statusUrl, { headers }); + } +} - if (!statusResponse.ok) { - throw new InferenceClientProviderApiError( - "Failed to fetch response status from fal-ai API", - { url: statusUrl, method: "GET" }, - { - requestId: statusResponse.headers.get("x-request-id") ?? "", - status: statusResponse.status, - body: await statusResponse.text(), - } - ); - } - try { - status = (await statusResponse.json()).status; - } catch (error) { - throw new InferenceClientProviderOutputError( - "Failed to parse status response from fal-ai API: received malformed response" - ); - } +export class FalAITextToVideoTask extends FalAiQueueTask implements TextToVideoTaskHelper { + task: InferenceTask; + constructor() { + super("https://queue.fal.run"); + this.task = "text-to-video"; + } + override makeRoute(params: UrlParams): string { + if (params.authMethod !== "provider-key") { + return `/${params.model}?_subdomain=queue`; } + return `/${params.model}`; + } + override preparePayload(params: BodyParams): Record { + return { + ...omit(params.args, ["inputs", "parameters"]), + ...(params.args.parameters as Record), + prompt: params.args.inputs, + }; + } + + override async getResponse( + response: FalAiQueueOutput, + url?: string, + headers?: Record + ): Promise { + const result = await this.getResponseFromQueueApi(response, url, headers); - const resultResponse = await fetch(resultUrl, { headers }); - let result: unknown; - try { - result = await resultResponse.json(); - } catch (error) { - throw new InferenceClientProviderOutputError( - "Failed to parse result response from fal-ai API: received malformed response" - ); - } if ( typeof result === "object" && !!result && diff --git a/packages/inference/src/tasks/cv/imageToImage.ts b/packages/inference/src/tasks/cv/imageToImage.ts index 4405dd2cb2..1266cc5452 100644 --- a/packages/inference/src/tasks/cv/imageToImage.ts +++ b/packages/inference/src/tasks/cv/imageToImage.ts @@ -3,6 +3,7 @@ import { resolveProvider } from "../../lib/getInferenceProviderMapping.js"; import { getProviderHelper } from "../../lib/getProviderHelper.js"; import type { BaseArgs, Options } from "../../types.js"; import { innerRequest } from "../../utils/request.js"; +import { makeRequestOptions } from "../../lib/makeRequestOptions.js"; export type ImageToImageArgs = BaseArgs & ImageToImageInput; @@ -18,5 +19,6 @@ export async function imageToImage(args: ImageToImageArgs, options?: Options): P ...options, task: "image-to-image", }); - return providerHelper.getResponse(res); + const { url, info } = await makeRequestOptions(args, providerHelper, { ...options, task: "image-to-image" }); + return providerHelper.getResponse(res, url, info.headers as Record); }