Skip to content

Commit 238567a

Browse files
committed
add fal-ai as a provider
1 parent aa7d1ca commit 238567a

File tree

5 files changed

+35
-12
lines changed

5 files changed

+35
-12
lines changed

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { FAL_AI_API_BASE_URL, FAL_AI_MODEL_IDS } from "../providers/fal-ai";
12
import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate";
23
import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova";
34
import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together";
@@ -9,7 +10,8 @@ import { isUrl } from "./isUrl";
910
const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
1011

1112
/**
12-
* Loaded from huggingface.co/api/tasks if needed
13+
* Lazy-loaded from huggingface.co/api/tasks when needed
14+
* Used to determine the default model to use when it's not user defined
1315
*/
1416
let tasks: Record<string, { models: { id: string }[] }> | null = null;
1517

@@ -36,7 +38,7 @@ export async function makeRequestOptions(
3638

3739
const headers: Record<string, string> = {};
3840
if (accessToken) {
39-
headers["Authorization"] = `Bearer ${accessToken}`;
41+
headers["Authorization"] = provider === "fal-ai" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
4042
}
4143

4244
if (!model && !tasks && taskHint) {
@@ -74,6 +76,9 @@ export async function makeRequestOptions(
7476
case "together":
7577
model = TOGETHER_MODEL_IDS[model]?.id ?? model;
7678
break;
79+
case "fal-ai":
80+
model = FAL_AI_MODEL_IDS[model];
81+
break;
7782
default:
7883
break;
7984
}
@@ -120,8 +125,9 @@ export async function makeRequestOptions(
120125
/// TODO we will proxy the request server-side (using our own keys) and handle billing for it on the user's HF account.
121126
throw new Error("Inference proxying is not implemented yet");
122127
} else {
123-
/// This is an external key
124128
switch (provider) {
129+
case 'fal-ai':
130+
return `${FAL_AI_API_BASE_URL}/${model}`;
125131
case "replicate":
126132
return `${REPLICATE_API_BASE_URL}/v1/models/${model}/predictions`;
127133
case "sambanova":
@@ -160,10 +166,10 @@ export async function makeRequestOptions(
160166
body: binary
161167
? args.data
162168
: JSON.stringify({
163-
...((otherArgs.model && isUrl(otherArgs.model)) || provider === "replicate"
164-
? omit(otherArgs, "model")
165-
: { ...otherArgs, model }),
166-
}),
169+
...((otherArgs.model && isUrl(otherArgs.model)) || provider === "replicate" || provider === "fal-ai"
170+
? omit(otherArgs, "model")
171+
: { ...otherArgs, model }),
172+
}),
167173
...(credentials ? { credentials } : undefined),
168174
signal: options?.signal,
169175
};
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import type { ModelId } from "../types";
2+
3+
export const FAL_AI_API_BASE_URL = "https://fal.run"
4+
5+
type FalAiId = string;
6+
7+
export const FAL_AI_MODEL_IDS: Record<ModelId, FalAiId> = {
8+
"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
9+
"black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",
10+
"black-forest-labs/FLUX.1-Redux-dev": "fal-ai/flux/dev/redux",
11+
"openai/whisper-large-v3": "fal-ai/wizper",
12+
"TencentARC/PhotoMaker": "fal-ai/photomaker",
13+
}

packages/inference/src/tasks/custom/request.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import type { InferenceTask, Options, RequestArgs } from "../../types";
22
import { makeRequestOptions } from "../../lib/makeRequestOptions";
33

44
/**
5-
* Primitive to make custom calls to Inference Endpoints
5+
* Primitive to make custom calls to the inference provider
66
*/
77
export async function request<T>(
88
args: RequestArgs,
@@ -35,8 +35,8 @@ export async function request<T>(
3535
if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
3636
throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
3737
}
38-
if (output.error) {
39-
throw new Error(JSON.stringify(output.error));
38+
if (output.error || output.detail) {
39+
throw new Error(JSON.stringify(output.error ?? output.detail));
4040
} else {
4141
throw new Error(output);
4242
}

packages/inference/src/tasks/cv/textToImage.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ interface OutputUrlImageGeneration {
5757
* Recommended model: stabilityai/stable-diffusion-2
5858
*/
5959
export async function textToImage(args: TextToImageArgs, options?: Options): Promise<TextToImageOutput> {
60-
if (args.provider === "together") {
60+
if (args.provider === "together" || args.provider === "fal-ai") {
6161
args.prompt = args.inputs;
6262
args.inputs = "";
6363
args.response_format = "base64";
@@ -70,6 +70,10 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
7070
taskHint: "text-to-image",
7171
});
7272
if (res && typeof res === "object") {
73+
if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
74+
const image = await fetch(res.images[0].url);
75+
return await image.blob();
76+
}
7377
if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
7478
const base64Data = res.data[0].b64_json;
7579
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);

packages/inference/src/types.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ export interface Options {
4545

4646
export type InferenceTask = Exclude<PipelineType, "other">;
4747

48-
export const INFERENCE_PROVIDERS = ["replicate", "sambanova", "together", "hf-inference"] as const;
48+
export const INFERENCE_PROVIDERS = ["fal-ai", "replicate", "sambanova", "together", "hf-inference"] as const;
4949
export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
5050

5151
export interface BaseArgs {

0 commit comments

Comments
 (0)