1+ import { FAL_AI_API_BASE_URL , FAL_AI_MODEL_IDS } from "../providers/fal-ai" ;
12import { REPLICATE_API_BASE_URL , REPLICATE_MODEL_IDS } from "../providers/replicate" ;
23import { SAMBANOVA_API_BASE_URL , SAMBANOVA_MODEL_IDS } from "../providers/sambanova" ;
34import { TOGETHER_API_BASE_URL , TOGETHER_MODEL_IDS } from "../providers/together" ;
@@ -9,7 +10,8 @@ import { isUrl } from "./isUrl";
910const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co" ;
1011
1112/**
12- * Loaded from huggingface.co/api/tasks if needed
13+ * Lazy-loaded from huggingface.co/api/tasks when needed
14+ * Used to determine the default model to use when it's not user defined
1315 */
1416let tasks : Record < string , { models : { id : string } [ ] } > | null = null ;
1517
@@ -36,7 +38,7 @@ export async function makeRequestOptions(
3638
3739 const headers : Record < string , string > = { } ;
3840 if ( accessToken ) {
39- headers [ "Authorization" ] = `Bearer ${ accessToken } ` ;
41+ headers [ "Authorization" ] = provider === "fal-ai" ? `Key ${ accessToken } ` : `Bearer ${ accessToken } ` ;
4042 }
4143
4244 if ( ! model && ! tasks && taskHint ) {
@@ -74,6 +76,9 @@ export async function makeRequestOptions(
7476 case "together" :
7577 model = TOGETHER_MODEL_IDS [ model ] ?. id ?? model ;
7678 break ;
79+ case "fal-ai" :
80+ model = FAL_AI_MODEL_IDS [ model ] ;
81+ break ;
7782 default :
7883 break ;
7984 }
@@ -120,8 +125,9 @@ export async function makeRequestOptions(
120125 /// TODO we wil proxy the request server-side (using our own keys) and handle billing for it on the user's HF account.
121126 throw new Error ( "Inference proxying is not implemented yet" ) ;
122127 } else {
123- /// This is an external key
124128 switch ( provider ) {
129+ case 'fal-ai' :
130+ return `${ FAL_AI_API_BASE_URL } /${ model } ` ;
125131 case "replicate" :
126132 return `${ REPLICATE_API_BASE_URL } /v1/models/${ model } /predictions` ;
127133 case "sambanova" :
@@ -160,10 +166,10 @@ export async function makeRequestOptions(
160166 body : binary
161167 ? args . data
162168 : JSON . stringify ( {
163- ...( ( otherArgs . model && isUrl ( otherArgs . model ) ) || provider === "replicate"
164- ? omit ( otherArgs , "model" )
165- : { ...otherArgs , model } ) ,
166- } ) ,
169+ ...( ( otherArgs . model && isUrl ( otherArgs . model ) ) || provider === "replicate" || provider === "fal-ai "
170+ ? omit ( otherArgs , "model" )
171+ : { ...otherArgs , model } ) ,
172+ } ) ,
167173 ...( credentials ? { credentials } : undefined ) ,
168174 signal : options ?. signal ,
169175 } ;
0 commit comments