@@ -2,8 +2,8 @@ import { FAL_AI_API_BASE_URL, FAL_AI_MODEL_IDS } from "../providers/fal-ai";
 import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate";
 import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova";
 import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together";
-import { INFERENCE_PROVIDERS, type InferenceTask, type Options, type RequestArgs } from "../types";
-import { omit } from "../utils/omit";
+import type { InferenceProvider } from "../types";
+import type { InferenceTask, Options, RequestArgs } from "../types";
 import { HF_HUB_URL } from "./getDefaultTask";
 import { isUrl } from "./isUrl";
 
@@ -31,62 +31,49 @@ export async function makeRequestOptions(
 		chatCompletion?: boolean;
 	}
 ): Promise<{ url: string; info: RequestInit }> {
-	const { accessToken, endpointUrl, provider, ...otherArgs } = args;
-	let { model } = args;
+	const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...otherArgs } = args;
+	const provider = maybeProvider ?? "hf-inference";
+
 	const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion } =
 		options ?? {};
 
-	const headers: Record<string, string> = {};
-	if (accessToken) {
-		headers["Authorization"] = provider === "fal-ai" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
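+	// endpointUrl and forceTask are only supported with the default "hf-inference" provider.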
+	if (endpointUrl && provider !== "hf-inference") {
+		throw new Error(`Cannot use endpointUrl with a third-party provider.`);
 	}
-
-	if (!model && !tasks && taskHint) {
-		const res = await fetch(`${HF_HUB_URL}/api/tasks`);
-
-		if (res.ok) {
-			tasks = await res.json();
-		}
+	if (forceTask && provider !== "hf-inference") {
+		throw new Error(`Cannot use forceTask with a third-party provider.`);
 	}
-
-	if (!model && tasks && taskHint) {
-		const taskInfo = tasks[taskHint];
-		if (taskInfo) {
-			model = taskInfo.models[0].id;
-		}
+	if (maybeModel && isUrl(maybeModel)) {
+		throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
 	}
 
-	if (!model) {
-		throw new Error("No model provided, and no default model found for this task");
-	}
-	if (provider) {
-		if (!INFERENCE_PROVIDERS.includes(provider)) {
-			throw new Error("Unknown Inference provider");
-		}
-		if (!accessToken) {
-			throw new Error("Specifying an Inference provider requires an accessToken");
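+	// Resolve the model: map an explicit model to the provider's ID, or fall back to the task's recommended model.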
+	let model: string;
+	if (!maybeModel) {
+		if (taskHint) {
+			model = mapModel({ model: await loadDefaultModel(taskHint), provider });
+		} else {
+			throw new Error("No model provided, and no default model found for this task");
+			/// TODO: change error message ^
 		}
+	} else {
+		model = mapModel({ model: maybeModel, provider });
+	}
 
-	const modelId = (() => {
-		switch (provider) {
-			case "replicate":
-				return REPLICATE_MODEL_IDS[model];
-			case "sambanova":
-				return SAMBANOVA_MODEL_IDS[model];
-			case "together":
-				return TOGETHER_MODEL_IDS[model]?.id;
-			case "fal-ai":
-				return FAL_AI_MODEL_IDS[model];
-			default:
-				return model;
-		}
-	})();
-
-	if (!modelId) {
-		throw new Error(`Model ${model} is not supported for provider ${provider}`);
-	}
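+	// An explicit endpointUrl takes precedence; otherwise the URL is routed per provider by makeUrl().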
+	const url = endpointUrl
+		? chatCompletion
+			? endpointUrl + `/v1/chat/completions`
+			: endpointUrl
+		: makeUrl({
+				model,
+				provider: provider ?? "hf-inference",
+				taskHint,
+				chatCompletion: chatCompletion ?? false,
+				forceTask,
+		  });
 
-	model = modelId;
+	const headers: Record<string, string> = {};
+	if (accessToken) {
+		headers["Authorization"] = provider === "fal-ai" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
 	}
 
 	const binary = "data" in args && !!args.data;
@@ -95,73 +82,20 @@ export async function makeRequestOptions(
 		headers["Content-Type"] = "application/json";
 	}
 
-	if (wait_for_model) {
-		headers["X-Wait-For-Model"] = "true";
-	}
-	if (use_cache === false) {
-		headers["X-Use-Cache"] = "false";
-	}
-	if (dont_load_model) {
-		headers["X-Load-Model"] = "0";
-	}
-	if (provider === "replicate") {
-		headers["Prefer"] = "wait";
-	}
-
-	let url = (() => {
-		if (endpointUrl && isUrl(model)) {
-			throw new TypeError("Both model and endpointUrl cannot be URLs");
-		}
-		if (isUrl(model)) {
-			console.warn("Using a model URL is deprecated, please use the `endpointUrl` parameter instead");
-			return model;
-		}
-		if (endpointUrl) {
-			return endpointUrl;
+	if (provider === "hf-inference") {
+		if (wait_for_model) {
+			headers["X-Wait-For-Model"] = "true";
 		}
-		if (forceTask) {
-			return `${HF_INFERENCE_API_BASE_URL}/pipeline/${forceTask}/${model}`;
+		if (use_cache === false) {
+			headers["X-Use-Cache"] = "false";
 		}
-		if (provider) {
-			if (!accessToken) {
-				throw new Error("Specifying an Inference provider requires an accessToken");
-			}
-			if (accessToken.startsWith("hf_")) {
-				/// TODO we wil proxy the request server-side (using our own keys) and handle billing for it on the user's HF account.
-				throw new Error("Inference proxying is not implemented yet");
-			} else {
-				switch (provider) {
-					case "fal-ai":
-						return `${FAL_AI_API_BASE_URL}/${model}`;
-					case "replicate":
-						if (model.includes(":")) {
-							// Versioned models are in the form of `owner/model:version`
-							return `${REPLICATE_API_BASE_URL}/v1/predictions`;
-						} else {
-							// Unversioned models are in the form of `owner/model`
-							return `${REPLICATE_API_BASE_URL}/v1/models/${model}/predictions`;
-						}
-					case "sambanova":
-						return SAMBANOVA_API_BASE_URL;
-					case "together":
-						if (taskHint === "text-to-image") {
-							return `${TOGETHER_API_BASE_URL}/v1/images/generations`;
-						}
-						return TOGETHER_API_BASE_URL;
-					default:
-						break;
-				}
-			}
+		if (dont_load_model) {
+			headers["X-Load-Model"] = "0";
 		}
-
-		return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
-	})();
-
-	if (chatCompletion && !url.endsWith("/chat/completions")) {
-		url += "/v1/chat/completions";
 	}
-	if (provider === "together" && taskHint === "text-generation" && !chatCompletion) {
-		url += "/v1/completions";
+
+	if (provider === "replicate") {
+		headers["Prefer"] = "wait";
 	}
 
 	/**
@@ -188,13 +122,102 @@ export async function makeRequestOptions(
 		body: binary
 			? args.data
 			: JSON.stringify({
-					...((otherArgs.model && isUrl(otherArgs.model)) || provider === "replicate" || provider === "fal-ai"
-						? omit(otherArgs, "model")
-						: { ...otherArgs, model }),
+					...otherArgs,
+					...(chatCompletion || provider === "together" ? { model } : undefined),
 			  }),
 		...(credentials ? { credentials } : undefined),
 		signal: options?.signal,
 	};
 
 	return { url, info };
 }
+
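+/// Maps a Hub model ID to the equivalent ID on the selected provider (identity for "hf-inference").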
+function mapModel(params: { model: string; provider: InferenceProvider }): string {
+	const model = (() => {
+		switch (params.provider) {
+			case "fal-ai":
+				return FAL_AI_MODEL_IDS[params.model];
+			case "replicate":
+				return REPLICATE_MODEL_IDS[params.model];
+			case "sambanova":
+				return SAMBANOVA_MODEL_IDS[params.model];
+			case "together":
+				return TOGETHER_MODEL_IDS[params.model]?.id;
+			case "hf-inference":
+				return params.model;
+		}
+	})();
+
+	if (!model) {
+		throw new Error(`Model ${params.model} is not supported for provider ${params.provider}`);
+	}
+	return model;
+}
+
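+/// Builds the request URL for the selected provider; the default case targets the HF Inference API.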
+function makeUrl(params: {
+	model: string;
+	provider: InferenceProvider;
+	taskHint: InferenceTask | undefined;
+	chatCompletion: boolean;
+	forceTask?: string | InferenceTask;
+}): string {
+	switch (params.provider) {
+		case "fal-ai":
+			return `${FAL_AI_API_BASE_URL}/${params.model}`;
+		case "replicate": {
+			if (params.model.includes(":")) {
+				/// Versioned model
+				return `${REPLICATE_API_BASE_URL}/v1/predictions`;
+			}
+			/// Evergreen / Canonical model
+			return `${REPLICATE_API_BASE_URL}/v1/models/${params.model}/predictions`;
+		}
+		case "sambanova":
+			/// Sambanova API matches OpenAI-like APIs: model is defined in the request body
+			if (params.taskHint === "text-generation" && params.chatCompletion) {
+				return `${SAMBANOVA_API_BASE_URL}/v1/chat/completions`;
+			}
+			return SAMBANOVA_API_BASE_URL;
+		case "together": {
+			/// Together API matches OpenAI-like APIs: model is defined in the request body
+			if (params.taskHint === "text-to-image") {
+				return `${TOGETHER_API_BASE_URL}/v1/images/generations`;
+			}
+			if (params.taskHint === "text-generation") {
+				if (params.chatCompletion) {
+					return `${TOGETHER_API_BASE_URL}/v1/chat/completions`;
+				}
+				return `${TOGETHER_API_BASE_URL}/v1/completions`;
+			}
+			return TOGETHER_API_BASE_URL;
+		}
+		default: {
+			const url = params.forceTask
+				? `${HF_INFERENCE_API_BASE_URL}/pipeline/${params.forceTask}/${params.model}`
+				: `${HF_INFERENCE_API_BASE_URL}/models/${params.model}`;
+			if (params.taskHint === "text-generation" && params.chatCompletion) {
+				return url + `/v1/chat/completions`;
+			}
+			return url;
+		}
+	}
+}
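+
+/// Returns the Hub's recommended model for a task, fetching and caching the task list on first use.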
+async function loadDefaultModel(task: InferenceTask): Promise<string> {
+	if (!tasks) {
+		tasks = await loadTaskInfo();
+	}
+	const taskInfo = tasks[task];
+	if ((taskInfo?.models.length ?? 0) <= 0) {
+		throw new Error(`No default model defined for task ${task}, please define the model explicitly.`);
+	}
+	return taskInfo.models[0].id;
+}
+
+async function loadTaskInfo(): Promise<Record<string, { models: { id: string }[] }>> {
+	const res = await fetch(`${HF_HUB_URL}/api/tasks`);
+
+	if (!res.ok) {
+		throw new Error("Failed to load task definitions from Hugging Face Hub.");
+	}
+	return await res.json();
+}