Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/inference/src/lib/getProviderHelper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
"text-to-image": new FalAI.FalAITextToImageTask(),
"text-to-speech": new FalAI.FalAITextToSpeechTask(),
"text-to-video": new FalAI.FalAITextToVideoTask(),
"image-to-image": new FalAI.FalAIImageToImageTask(),
"automatic-speech-recognition": new FalAI.FalAIAutomaticSpeechRecognitionTask(),
},
"featherless-ai": {
Expand Down
189 changes: 132 additions & 57 deletions packages/inference/src/providers/fal-ai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ import { base64FromBytes } from "../utils/base64FromBytes.js";

import type { AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
import { isUrl } from "../lib/isUrl.js";
import type { BodyParams, HeaderParams, ModelId, RequestArgs, UrlParams } from "../types.js";
import type { BodyParams, HeaderParams, InferenceTask, ModelId, RequestArgs, UrlParams } from "../types.js";
import { delay } from "../utils/delay.js";
import { omit } from "../utils/omit.js";
import type { ImageToImageTaskHelper } from "./providerHelper.js";
import {
type AutomaticSpeechRecognitionTaskHelper,
TaskProviderHelper,
Expand All @@ -34,6 +35,7 @@ import {
InferenceClientProviderApiError,
InferenceClientProviderOutputError,
} from "../errors.js";
import type { ImageToImageArgs } from "../tasks/index.js";

export interface FalAiQueueOutput {
request_id: string;
Expand Down Expand Up @@ -82,6 +84,75 @@ abstract class FalAITask extends TaskProviderHelper {
}
}

/**
 * Base class for Fal.ai tasks that go through the asynchronous queue API
 * (https://queue.fal.run): submit a request, poll its status until COMPLETED,
 * then fetch the final result.
 */
abstract class FalAiQueueTask extends FalAITask {
	/** Task name used in error messages (e.g. "text-to-video"). */
	abstract task: InferenceTask;

	/**
	 * Poll the Fal.ai queue until the request completes, then fetch and parse the result.
	 *
	 * @param response - Initial queue submission response (must contain `request_id` and `response_url`).
	 * @param url - The URL the original request was sent to; used to derive status/result URLs.
	 * @param headers - Headers (auth) to reuse for the status/result requests.
	 * @returns The parsed JSON result, shape depends on the concrete task.
	 * @throws InferenceClientInputError when `url` or `headers` is missing.
	 * @throws InferenceClientProviderApiError when a status or result request fails at the HTTP level.
	 * @throws InferenceClientProviderOutputError when a response body cannot be parsed as JSON.
	 */
	async getResponseFromQueueApi(
		response: FalAiQueueOutput,
		url?: string,
		headers?: Record<string, string>
	): Promise<unknown> {
		if (!url || !headers) {
			throw new InferenceClientInputError(`URL and headers are required for ${this.task} task`);
		}
		const requestId = response.request_id;
		if (!requestId) {
			throw new InferenceClientProviderOutputError(
				`Received malformed response from Fal.ai ${this.task} API: no request ID found in the response`
			);
		}
		let status = response.status;

		const parsedUrl = new URL(url);
		// When routed through the HF router, queue endpoints live under the /fal-ai prefix.
		const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${
			parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""
		}`;

		// extracting the provider model id for status and result urls
		// from the response as it might be different from the mapped model in `url`
		const modelId = new URL(response.response_url).pathname;
		const queryParams = parsedUrl.search;

		const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
		const resultUrl = `${baseUrl}${modelId}${queryParams}`;

		// NOTE(review): this loop has no timeout/attempt cap — if the provider never
		// reports COMPLETED (or reports FAILED), it polls forever. Consider a deadline.
		while (status !== "COMPLETED") {
			await delay(500);
			const statusResponse = await fetch(statusUrl, { headers });

			if (!statusResponse.ok) {
				throw new InferenceClientProviderApiError(
					"Failed to fetch response status from fal-ai API",
					{ url: statusUrl, method: "GET" },
					{
						requestId: statusResponse.headers.get("x-request-id") ?? "",
						status: statusResponse.status,
						body: await statusResponse.text(),
					}
				);
			}
			try {
				status = (await statusResponse.json()).status;
			} catch (error) {
				throw new InferenceClientProviderOutputError(
					"Failed to parse status response from fal-ai API: received malformed response"
				);
			}
		}

		const resultResponse = await fetch(resultUrl, { headers });
		// Surface HTTP-level failures explicitly instead of letting them show up as
		// a misleading JSON parse error below (mirrors the status-fetch check above).
		if (!resultResponse.ok) {
			throw new InferenceClientProviderApiError(
				"Failed to fetch result from fal-ai API",
				{ url: resultUrl, method: "GET" },
				{
					requestId: resultResponse.headers.get("x-request-id") ?? "",
					status: resultResponse.status,
					body: await resultResponse.text(),
				}
			);
		}
		let result: unknown;
		try {
			result = await resultResponse.json();
		} catch (error) {
			throw new InferenceClientProviderOutputError(
				"Failed to parse result response from fal-ai API: received malformed response"
			);
		}
		return result;
	}
}

/** Resolve the Hub download URL of a LoRA adapter weights file within a model repo. */
function buildLoraPath(modelId: ModelId, adapterWeightsPath: string): string {
	return [HF_HUB_URL, modelId, "resolve", "main", adapterWeightsPath].join("/");
}
Expand Down Expand Up @@ -130,21 +201,29 @@ export class FalAITextToImageTask extends FalAITask implements TextToImageTaskHe
}
}

export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHelper {
export class FalAIImageToImageTask extends FalAiQueueTask implements ImageToImageTaskHelper {
task: InferenceTask;
constructor() {
super("https://queue.fal.run");
this.task = "image-to-image";
}

override makeRoute(params: UrlParams): string {
if (params.authMethod !== "provider-key") {
return `/${params.model}?_subdomain=queue`;
}
return `/${params.model}`;
}
override preparePayload(params: BodyParams): Record<string, unknown> {

async preparePayloadAsync(args: ImageToImageArgs): Promise<RequestArgs> {
const mimeType = args.inputs instanceof Blob ? args.inputs.type : "image/png";
return {
...omit(params.args, ["inputs", "parameters"]),
...(params.args.parameters as Record<string, unknown>),
prompt: params.args.inputs,
...omit(args, ["inputs", "parameters"]),
image_url: `data:${mimeType};base64,${base64FromBytes(
new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await (args.inputs as Blob).arrayBuffer())
)}`,
...args.parameters,
...args,
};
}

Expand All @@ -153,63 +232,59 @@ export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHe
url?: string,
headers?: Record<string, string>
): Promise<Blob> {
if (!url || !headers) {
throw new InferenceClientInputError("URL and headers are required for text-to-video task");
}
const requestId = response.request_id;
if (!requestId) {
const result = await this.getResponseFromQueueApi(response, url, headers);

if (
typeof result === "object" &&
!!result &&
"images" in result &&
Array.isArray(result.images) &&
result.images.length > 0 &&
typeof result.images[0] === "object" &&
!!result.images[0] &&
"url" in result.images[0] &&
typeof result.images[0].url === "string" &&
isUrl(result.images[0].url)
) {
const urlResponse = await fetch(result.images[0].url);
return await urlResponse.blob();
} else {
throw new InferenceClientProviderOutputError(
"Received malformed response from Fal.ai text-to-video API: no request ID found in the response"
`Received malformed response from Fal.ai image-to-image API: expected { images: Array<{ url: string }> } result format, got instead: ${JSON.stringify(
result
)}`
);
}
let status = response.status;

const parsedUrl = new URL(url);
const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${
parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""
}`;

// extracting the provider model id for status and result urls
// from the response as it might be different from the mapped model in `url`
const modelId = new URL(response.response_url).pathname;
const queryParams = parsedUrl.search;

const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
const resultUrl = `${baseUrl}${modelId}${queryParams}`;

while (status !== "COMPLETED") {
await delay(500);
const statusResponse = await fetch(statusUrl, { headers });
}
}

if (!statusResponse.ok) {
throw new InferenceClientProviderApiError(
"Failed to fetch response status from fal-ai API",
{ url: statusUrl, method: "GET" },
{
requestId: statusResponse.headers.get("x-request-id") ?? "",
status: statusResponse.status,
body: await statusResponse.text(),
}
);
}
try {
status = (await statusResponse.json()).status;
} catch (error) {
throw new InferenceClientProviderOutputError(
"Failed to parse status response from fal-ai API: received malformed response"
);
}
export class FalAITextToVideoTask extends FalAiQueueTask implements TextToVideoTaskHelper {
task: InferenceTask;
constructor() {
super("https://queue.fal.run");
this.task = "text-to-video";
}
override makeRoute(params: UrlParams): string {
if (params.authMethod !== "provider-key") {
return `/${params.model}?_subdomain=queue`;
}
return `/${params.model}`;
}
override preparePayload(params: BodyParams): Record<string, unknown> {
return {
...omit(params.args, ["inputs", "parameters"]),
...(params.args.parameters as Record<string, unknown>),
prompt: params.args.inputs,
};
}

override async getResponse(
response: FalAiQueueOutput,
url?: string,
headers?: Record<string, string>
): Promise<Blob> {
const result = await this.getResponseFromQueueApi(response, url, headers);

const resultResponse = await fetch(resultUrl, { headers });
let result: unknown;
try {
result = await resultResponse.json();
} catch (error) {
throw new InferenceClientProviderOutputError(
"Failed to parse result response from fal-ai API: received malformed response"
);
}
if (
typeof result === "object" &&
!!result &&
Expand Down
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/cv/imageToImage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { resolveProvider } from "../../lib/getInferenceProviderMapping.js";
import { getProviderHelper } from "../../lib/getProviderHelper.js";
import type { BaseArgs, Options } from "../../types.js";
import { innerRequest } from "../../utils/request.js";
import { makeRequestOptions } from "../../lib/makeRequestOptions.js";

export type ImageToImageArgs = BaseArgs & ImageToImageInput;

Expand All @@ -18,5 +19,6 @@ export async function imageToImage(args: ImageToImageArgs, options?: Options): P
...options,
task: "image-to-image",
});
return providerHelper.getResponse(res);
const { url, info } = await makeRequestOptions(args, providerHelper, { ...options, task: "image-to-image" });
return providerHelper.getResponse(res, url, info.headers as Record<string, string>);
}