 import { HF_HUB_URL, HF_ROUTER_URL } from "../config";
-import { FAL_AI_API_BASE_URL } from "../providers/fal-ai";
-import { NEBIUS_API_BASE_URL } from "../providers/nebius";
-import { REPLICATE_API_BASE_URL } from "../providers/replicate";
-import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
-import { TOGETHER_API_BASE_URL } from "../providers/together";
-import { NOVITA_API_BASE_URL } from "../providers/novita";
-import { FIREWORKS_AI_API_BASE_URL } from "../providers/fireworks-ai";
-import { HYPERBOLIC_API_BASE_URL } from "../providers/hyperbolic";
-import { BLACKFORESTLABS_AI_API_BASE_URL } from "../providers/black-forest-labs";
-import type { InferenceProvider } from "../types";
-import type { InferenceTask, Options, RequestArgs } from "../types";
+import { BLACK_FOREST_LABS_CONFIG } from "../providers/black-forest-labs";
+import { FAL_AI_CONFIG } from "../providers/fal-ai";
+import { FIREWORKS_AI_CONFIG } from "../providers/fireworks-ai";
+import { HF_INFERENCE_CONFIG } from "../providers/hf-inference";
+import { HYPERBOLIC_CONFIG } from "../providers/hyperbolic";
+import { NEBIUS_CONFIG } from "../providers/nebius";
+import { NOVITA_CONFIG } from "../providers/novita";
+import { REPLICATE_CONFIG } from "../providers/replicate";
+import { SAMBANOVA_CONFIG } from "../providers/sambanova";
+import { TOGETHER_CONFIG } from "../providers/together";
+import type { InferenceProvider, InferenceTask, Options, ProviderConfig, RequestArgs } from "../types";
 import { isUrl } from "./isUrl";
 import { version as packageVersion, name as packageName } from "../../package.json";
 import { getProviderModelId } from "./getProviderModelId";
@@ -22,6 +22,22 @@ const HF_HUB_INFERENCE_PROXY_TEMPLATE = `${HF_ROUTER_URL}/{{PROVIDER}}`;
  */
 let tasks: Record<string, { models: { id: string }[] }> | null = null;
 
+/**
+ * Config to define how to serialize requests for each provider
+ */
+const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
+	"black-forest-labs": BLACK_FOREST_LABS_CONFIG,
+	"fal-ai": FAL_AI_CONFIG,
+	"fireworks-ai": FIREWORKS_AI_CONFIG,
+	"hf-inference": HF_INFERENCE_CONFIG,
+	hyperbolic: HYPERBOLIC_CONFIG,
+	nebius: NEBIUS_CONFIG,
+	novita: NOVITA_CONFIG,
+	replicate: REPLICATE_CONFIG,
+	sambanova: SAMBANOVA_CONFIG,
+	together: TOGETHER_CONFIG,
+};
+
 /**
  * Helper that prepares request arguments
  */
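
For reference, the call sites introduced later in this diff (`providerConfig.baseUrl`, `.makeUrl`, `.makeHeaders`, `.makeBody`) imply a config shape roughly like the sketch below. This is inferred from usage only; the actual `ProviderConfig` type exported from "../types" may differ.

```ts
// Sketch only: shape inferred from how providerConfig is used in this diff,
// not the actual ProviderConfig definition in "../types".
import type { InferenceTask } from "../types";

type AuthMethod = "none" | "hf-token" | "credentials-include" | "provider-key";

interface ProviderConfigSketch {
	baseUrl: string;
	makeUrl(params: { baseUrl: string; model: string; chatCompletion?: boolean; task?: InferenceTask }): string;
	makeHeaders(params: { accessToken?: string; authMethod: AuthMethod }): Record<string, string>;
	makeBody(params: {
		args: Record<string, unknown>;
		model: string;
		task?: InferenceTask;
		chatCompletion?: boolean;
	}): Record<string, unknown>;
}
```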
@@ -37,10 +53,10 @@ export async function makeRequestOptions(
 	}
 ): Promise<{ url: string; info: RequestInit }> {
 	const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...remainingArgs } = args;
-	let otherArgs = remainingArgs;
 	const provider = maybeProvider ?? "hf-inference";
+	const providerConfig = providerConfigs[provider];
 
-	const { includeCredentials, task, chatCompletion } = options ?? {};
+	const { includeCredentials, task, chatCompletion, signal } = options ?? {};
 
 	if (endpointUrl && provider !== "hf-inference") {
 		throw new Error(`Cannot use endpointUrl with a third-party provider.`);
@@ -51,6 +67,9 @@ export async function makeRequestOptions(
 	if (!maybeModel && !task) {
 		throw new Error("No model provided, and no task has been specified.");
 	}
+	if (!providerConfig) {
+		throw new Error(`No provider config found for provider ${provider}`);
+	}
 	// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
 	const hfModel = maybeModel ?? (await loadDefaultModel(task!));
 	const model = await getProviderModelId({ model: hfModel, provider }, args, {
@@ -68,44 +87,52 @@ export async function makeRequestOptions(
 			? "credentials-include"
 			: "none";
 
+	// Make URL
 	const url = endpointUrl
 		? chatCompletion
 			? endpointUrl + `/v1/chat/completions`
 			: endpointUrl
-		: makeUrl({
-				authMethod,
-				chatCompletion: chatCompletion ?? false,
+		: providerConfig.makeUrl({
+				baseUrl:
+					authMethod !== "provider-key"
+						? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider)
+						: providerConfig.baseUrl,
 				model,
-				provider: provider ?? "hf-inference",
+				chatCompletion,
 				task,
 		  });
 
-	const headers: Record<string, string> = {};
-	if (accessToken) {
-		if (provider === "fal-ai" && authMethod === "provider-key") {
-			headers["Authorization"] = `Key ${accessToken}`;
-		} else if (provider === "black-forest-labs" && authMethod === "provider-key") {
-			headers["X-Key"] = accessToken;
-		} else {
-			headers["Authorization"] = `Bearer ${accessToken}`;
-		}
-	}
-
-	// e.g. @huggingface/inference/3.1.3
-	const ownUserAgent = `${packageName}/${packageVersion}`;
-	headers["User-Agent"] = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : undefined]
-		.filter((x) => x !== undefined)
-		.join(" ");
-
+	// Make headers
 	const binary = "data" in args && !!args.data;
+	const headers = providerConfig.makeHeaders({
+		accessToken,
+		authMethod,
+	});
 
+	// Add content-type to headers
 	if (!binary) {
 		headers["Content-Type"] = "application/json";
 	}
 
-	if (provider === "replicate") {
-		headers["Prefer"] = "wait";
-	}
+	// Add user-agent to headers
+	// e.g. @huggingface/inference/3.1.3
+	const ownUserAgent = `${packageName}/${packageVersion}`;
+	const userAgent = [ownUserAgent, typeof navigator !== "undefined" ? navigator.userAgent : undefined]
+		.filter((x) => x !== undefined)
+		.join(" ");
+	headers["User-Agent"] = userAgent;
+
+	// Make body
+	const body = binary
+		? args.data
+		: JSON.stringify(
+				providerConfig.makeBody({
+					args: remainingArgs as Record<string, unknown>,
+					model,
+					task,
+					chatCompletion,
+				})
+		  );
 
 	/**
 	 * For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error
@@ -117,158 +144,17 @@ export async function makeRequestOptions(
 		credentials = "include";
 	}
 
-	/**
-	 * Replicate models wrap all inputs inside { input: ... }
-	 * Versioned Replicate models in the format `owner/model:version` expect the version in the body
-	 */
-	if (provider === "replicate") {
-		const version = model.includes(":") ? model.split(":")[1] : undefined;
-		(otherArgs as unknown) = { input: otherArgs, version };
-	}
-
 	const info: RequestInit = {
 		headers,
 		method: "POST",
-		body: binary
-			? args.data
-			: JSON.stringify({
-					...otherArgs,
-					...(task === "text-to-image" && provider === "hyperbolic"
-						? { model_name: model }
-						: chatCompletion || provider === "together" || provider === "nebius" || provider === "hyperbolic"
-							? { model }
-							: undefined),
-			  }),
+		body,
 		...(credentials ? { credentials } : undefined),
-		signal: options?.signal,
+		signal,
 	};
 
 	return { url, info };
 }
 
-function makeUrl(params: {
-	authMethod: "none" | "hf-token" | "credentials-include" | "provider-key";
-	chatCompletion: boolean;
-	model: string;
-	provider: InferenceProvider;
-	task: InferenceTask | undefined;
-}): string {
-	if (params.authMethod === "none" && params.provider !== "hf-inference") {
-		throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken");
-	}
-
-	const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
-	switch (params.provider) {
-		case "black-forest-labs": {
-			const baseUrl = shouldProxy
-				? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
-				: BLACKFORESTLABS_AI_API_BASE_URL;
-			return `${baseUrl}/${params.model}`;
-		}
-		case "fal-ai": {
-			const baseUrl = shouldProxy
-				? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
-				: FAL_AI_API_BASE_URL;
-			return `${baseUrl}/${params.model}`;
-		}
-		case "nebius": {
-			const baseUrl = shouldProxy
-				? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
-				: NEBIUS_API_BASE_URL;
-
-			if (params.task === "text-to-image") {
-				return `${baseUrl}/v1/images/generations`;
-			}
-			if (params.task === "text-generation") {
-				if (params.chatCompletion) {
-					return `${baseUrl}/v1/chat/completions`;
-				}
-				return `${baseUrl}/v1/completions`;
-			}
-			return baseUrl;
-		}
-		case "replicate": {
-			const baseUrl = shouldProxy
-				? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
-				: REPLICATE_API_BASE_URL;
-			if (params.model.includes(":")) {
-				/// Versioned model
-				return `${baseUrl}/v1/predictions`;
-			}
-			/// Evergreen / Canonical model
-			return `${baseUrl}/v1/models/${params.model}/predictions`;
-		}
-		case "sambanova": {
-			const baseUrl = shouldProxy
-				? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
-				: SAMBANOVA_API_BASE_URL;
-			/// Sambanova API matches OpenAI-like APIs: model is defined in the request body
-			if (params.task === "text-generation" && params.chatCompletion) {
-				return `${baseUrl}/v1/chat/completions`;
-			}
-			return baseUrl;
-		}
-		case "together": {
-			const baseUrl = shouldProxy
-				? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
-				: TOGETHER_API_BASE_URL;
-			/// Together API matches OpenAI-like APIs: model is defined in the request body
-			if (params.task === "text-to-image") {
-				return `${baseUrl}/v1/images/generations`;
-			}
-			if (params.task === "text-generation") {
-				if (params.chatCompletion) {
-					return `${baseUrl}/v1/chat/completions`;
-				}
-				return `${baseUrl}/v1/completions`;
-			}
-			return baseUrl;
-		}
-
-		case "fireworks-ai": {
-			const baseUrl = shouldProxy
-				? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
-				: FIREWORKS_AI_API_BASE_URL;
-			if (params.task === "text-generation" && params.chatCompletion) {
-				return `${baseUrl}/v1/chat/completions`;
-			}
-			return baseUrl;
-		}
-		case "hyperbolic": {
-			const baseUrl = shouldProxy
-				? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
-				: HYPERBOLIC_API_BASE_URL;
-
-			if (params.task === "text-to-image") {
-				return `${baseUrl}/v1/images/generations`;
-			}
-			return `${baseUrl}/v1/chat/completions`;
-		}
-		case "novita": {
-			const baseUrl = shouldProxy
-				? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
-				: NOVITA_API_BASE_URL;
-			if (params.task === "text-generation") {
-				if (params.chatCompletion) {
-					return `${baseUrl}/chat/completions`;
-				}
-				return `${baseUrl}/completions`;
-			}
-			return baseUrl;
-		}
-		default: {
-			const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
-			if (params.task && ["feature-extraction", "sentence-similarity"].includes(params.task)) {
-				/// when deployed on hf-inference, those two tasks are automatically compatible with one another.
-				return `${baseUrl}/pipeline/${params.task}/${params.model}`;
-			}
-			if (params.task === "text-generation" && params.chatCompletion) {
-				return `${baseUrl}/models/${params.model}/v1/chat/completions`;
-			}
-			return `${baseUrl}/models/${params.model}`;
-		}
-	}
-}
 async function loadDefaultModel(task: InferenceTask): Promise<string> {
 	if (!tasks) {
 		tasks = await loadTaskInfo();
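
To make the new structure concrete, here is a hypothetical sketch of how the Replicate-specific branches removed above (versioned-model URL selection, the `Prefer: wait` header, and the `{ input, version }` body wrapper) could be expressed as one of these provider config objects. It is an illustration only; the actual `REPLICATE_CONFIG` exported from "../providers/replicate" may be organized differently.

```ts
// Hypothetical sketch assembled from the Replicate branches removed in this diff;
// not the actual REPLICATE_CONFIG from "../providers/replicate".
const REPLICATE_CONFIG_SKETCH = {
	// Assumed value of the former REPLICATE_API_BASE_URL constant.
	baseUrl: "https://api.replicate.com",
	makeUrl({ baseUrl, model }: { baseUrl: string; model: string }): string {
		// Versioned models (`owner/model:version`) use the generic predictions endpoint,
		// canonical models the model-scoped one.
		return model.includes(":") ? `${baseUrl}/v1/predictions` : `${baseUrl}/v1/models/${model}/predictions`;
	},
	makeHeaders({ accessToken }: { accessToken?: string }): Record<string, string> {
		// `Prefer: wait` makes the call block until the prediction has finished.
		const headers: Record<string, string> = { Prefer: "wait" };
		if (accessToken) {
			headers["Authorization"] = `Bearer ${accessToken}`;
		}
		return headers;
	},
	makeBody({ args, model }: { args: Record<string, unknown>; model: string }): Record<string, unknown> {
		// Replicate wraps all inputs inside { input: ... }; versioned models expect the version in the body.
		const version = model.includes(":") ? model.split(":")[1] : undefined;
		return { input: args, version };
	},
};
```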