@@ -18,7 +18,7 @@ import { base64FromBytes } from "../utils/base64FromBytes.js";
1818
1919import type { AutomaticSpeechRecognitionOutput } from "@huggingface/tasks" ;
2020import { isUrl } from "../lib/isUrl.js" ;
21- import type { BodyParams , HeaderParams , ModelId , RequestArgs , UrlParams } from "../types.js" ;
21+ import type { BodyParams , HeaderParams , InferenceTask , ModelId , RequestArgs , UrlParams } from "../types.js" ;
2222import { delay } from "../utils/delay.js" ;
2323import { omit } from "../utils/omit.js" ;
2424import type { ImageToImageTaskHelper } from "./providerHelper.js" ;
@@ -84,6 +84,75 @@ abstract class FalAITask extends TaskProviderHelper {
8484 }
8585}
8686
87+ abstract class FalAiQueueTask extends FalAITask {
88+ abstract task : InferenceTask ;
89+
90+ async getResponseFromQueueApi (
91+ response : FalAiQueueOutput ,
92+ url ?: string ,
93+ headers ?: Record < string , string >
94+ ) : Promise < unknown > {
95+ if ( ! url || ! headers ) {
96+ throw new InferenceClientInputError ( `URL and headers are required for ${ this . task } task` ) ;
97+ }
98+ const requestId = response . request_id ;
99+ if ( ! requestId ) {
100+ throw new InferenceClientProviderOutputError (
101+ `Received malformed response from Fal.ai ${ this . task } API: no request ID found in the response`
102+ ) ;
103+ }
104+ let status = response . status ;
105+
106+ const parsedUrl = new URL ( url ) ;
107+ const baseUrl = `${ parsedUrl . protocol } //${ parsedUrl . host } ${
108+ parsedUrl . host === "router.huggingface.co" ? "/fal-ai" : ""
109+ } `;
110+
111+ // extracting the provider model id for status and result urls
112+ // from the response as it might be different from the mapped model in `url`
113+ const modelId = new URL ( response . response_url ) . pathname ;
114+ const queryParams = parsedUrl . search ;
115+
116+ const statusUrl = `${ baseUrl } ${ modelId } /status${ queryParams } ` ;
117+ const resultUrl = `${ baseUrl } ${ modelId } ${ queryParams } ` ;
118+
119+ while ( status !== "COMPLETED" ) {
120+ await delay ( 500 ) ;
121+ const statusResponse = await fetch ( statusUrl , { headers } ) ;
122+
123+ if ( ! statusResponse . ok ) {
124+ throw new InferenceClientProviderApiError (
125+ "Failed to fetch response status from fal-ai API" ,
126+ { url : statusUrl , method : "GET" } ,
127+ {
128+ requestId : statusResponse . headers . get ( "x-request-id" ) ?? "" ,
129+ status : statusResponse . status ,
130+ body : await statusResponse . text ( ) ,
131+ }
132+ ) ;
133+ }
134+ try {
135+ status = ( await statusResponse . json ( ) ) . status ;
136+ } catch ( error ) {
137+ throw new InferenceClientProviderOutputError (
138+ "Failed to parse status response from fal-ai API: received malformed response"
139+ ) ;
140+ }
141+ }
142+
143+ const resultResponse = await fetch ( resultUrl , { headers } ) ;
144+ let result : unknown ;
145+ try {
146+ result = await resultResponse . json ( ) ;
147+ } catch ( error ) {
148+ throw new InferenceClientProviderOutputError (
149+ "Failed to parse result response from fal-ai API: received malformed response"
150+ ) ;
151+ }
152+ return result ;
153+ }
154+ }
155+
87156function buildLoraPath ( modelId : ModelId , adapterWeightsPath : string ) : string {
88157 return `${ HF_HUB_URL } /${ modelId } /resolve/main/${ adapterWeightsPath } ` ;
89158}
@@ -132,9 +201,11 @@ export class FalAITextToImageTask extends FalAITask implements TextToImageTaskHe
132201 }
133202}
134203
135- export class FalAIImageToImageTask extends FalAITask implements ImageToImageTaskHelper {
204+ export class FalAIImageToImageTask extends FalAiQueueTask implements ImageToImageTaskHelper {
205+ task : InferenceTask ;
136206 constructor ( ) {
137207 super ( "https://queue.fal.run" ) ;
208+ this . task = "image-to-image" ;
138209 }
139210
140211 override makeRoute ( params : UrlParams ) : string {
@@ -161,63 +232,8 @@ export class FalAIImageToImageTask extends FalAITask implements ImageToImageTask
161232 url ?: string ,
162233 headers ?: Record < string , string >
163234 ) : Promise < Blob > {
164- if ( ! url || ! headers ) {
165- throw new InferenceClientInputError ( "URL and headers are required for image-to-image task" ) ;
166- }
167- const requestId = response . request_id ;
168- if ( ! requestId ) {
169- throw new InferenceClientProviderOutputError (
170- "Received malformed response from Fal.ai text-to-video API: no request ID found in the response"
171- ) ;
172- }
173- let status = response . status ;
174-
175- const parsedUrl = new URL ( url ) ;
176- const baseUrl = `${ parsedUrl . protocol } //${ parsedUrl . host } ${ parsedUrl . host === "router.huggingface.co" ? "/fal-ai" : ""
177- } `;
178-
179- // extracting the provider model id for status and result urls
180- // from the response as it might be different from the mapped model in `url`
181- const modelId = new URL ( response . response_url ) . pathname ;
182- const queryParams = parsedUrl . search ;
183-
184- const statusUrl = `${ baseUrl } ${ modelId } /status${ queryParams } ` ;
185- const resultUrl = `${ baseUrl } ${ modelId } ${ queryParams } ` ;
186-
187- while ( status !== "COMPLETED" ) {
188- await delay ( 500 ) ;
189- const statusResponse = await fetch ( statusUrl , { headers } ) ;
235+ const result = await this . getResponseFromQueueApi ( response , url , headers ) ;
190236
191- if ( ! statusResponse . ok ) {
192- throw new InferenceClientProviderApiError (
193- "Failed to fetch response status from fal-ai API" ,
194- { url : statusUrl , method : "GET" } ,
195- {
196- requestId : statusResponse . headers . get ( "x-request-id" ) ?? "" ,
197- status : statusResponse . status ,
198- body : await statusResponse . text ( ) ,
199- }
200- ) ;
201- }
202- try {
203- status = ( await statusResponse . json ( ) ) . status ;
204- } catch ( error ) {
205- throw new InferenceClientProviderOutputError (
206- "Failed to parse status response from fal-ai API: received malformed response"
207- ) ;
208- }
209- }
210-
211- const resultResponse = await fetch ( resultUrl , { headers } ) ;
212- let result : unknown ;
213- try {
214- result = await resultResponse . json ( ) ;
215- } catch ( error ) {
216- throw new InferenceClientProviderOutputError (
217- "Failed to parse result response from fal-ai API: received malformed response"
218- ) ;
219- }
220- console . log ( "result" , result ) ;
221237 if (
222238 typeof result === "object" &&
223239 ! ! result &&
@@ -242,9 +258,11 @@ export class FalAIImageToImageTask extends FalAITask implements ImageToImageTask
242258 }
243259}
244260
245- export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHelper {
261+ export class FalAITextToVideoTask extends FalAiQueueTask implements TextToVideoTaskHelper {
262+ task : InferenceTask ;
246263 constructor ( ) {
247264 super ( "https://queue.fal.run" ) ;
265+ this . task = "text-to-video" ;
248266 }
249267 override makeRoute ( params : UrlParams ) : string {
250268 if ( params . authMethod !== "provider-key" ) {
@@ -265,62 +283,8 @@ export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHe
265283 url ?: string ,
266284 headers ?: Record < string , string >
267285 ) : Promise < Blob > {
268- if ( ! url || ! headers ) {
269- throw new InferenceClientInputError ( "URL and headers are required for text-to-video task" ) ;
270- }
271- const requestId = response . request_id ;
272- if ( ! requestId ) {
273- throw new InferenceClientProviderOutputError (
274- "Received malformed response from Fal.ai text-to-video API: no request ID found in the response"
275- ) ;
276- }
277- let status = response . status ;
278-
279- const parsedUrl = new URL ( url ) ;
280- const baseUrl = `${ parsedUrl . protocol } //${ parsedUrl . host } ${ parsedUrl . host === "router.huggingface.co" ? "/fal-ai" : ""
281- } `;
282-
283- // extracting the provider model id for status and result urls
284- // from the response as it might be different from the mapped model in `url`
285- const modelId = new URL ( response . response_url ) . pathname ;
286- const queryParams = parsedUrl . search ;
287-
288- const statusUrl = `${ baseUrl } ${ modelId } /status${ queryParams } ` ;
289- const resultUrl = `${ baseUrl } ${ modelId } ${ queryParams } ` ;
290-
291- while ( status !== "COMPLETED" ) {
292- await delay ( 500 ) ;
293- const statusResponse = await fetch ( statusUrl , { headers } ) ;
294-
295- if ( ! statusResponse . ok ) {
296- throw new InferenceClientProviderApiError (
297- "Failed to fetch response status from fal-ai API" ,
298- { url : statusUrl , method : "GET" } ,
299- {
300- requestId : statusResponse . headers . get ( "x-request-id" ) ?? "" ,
301- status : statusResponse . status ,
302- body : await statusResponse . text ( ) ,
303- }
304- ) ;
305- }
306- try {
307- status = ( await statusResponse . json ( ) ) . status ;
308- } catch ( error ) {
309- throw new InferenceClientProviderOutputError (
310- "Failed to parse status response from fal-ai API: received malformed response"
311- ) ;
312- }
313- }
286+ const result = await this . getResponseFromQueueApi ( response , url , headers ) ;
314287
315- const resultResponse = await fetch ( resultUrl , { headers } ) ;
316- let result : unknown ;
317- try {
318- result = await resultResponse . json ( ) ;
319- } catch ( error ) {
320- throw new InferenceClientProviderOutputError (
321- "Failed to parse result response from fal-ai API: received malformed response"
322- ) ;
323- }
324288 if (
325289 typeof result === "object" &&
326290 ! ! result &&
0 commit comments