@@ -18,9 +18,10 @@ import { base64FromBytes } from "../utils/base64FromBytes.js";
 
 import type { AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
 import { isUrl } from "../lib/isUrl.js";
-import type { BodyParams, HeaderParams, ModelId, RequestArgs, UrlParams } from "../types.js";
+import type { BodyParams, HeaderParams, InferenceTask, ModelId, RequestArgs, UrlParams } from "../types.js";
 import { delay } from "../utils/delay.js";
 import { omit } from "../utils/omit.js";
+import type { ImageToImageTaskHelper } from "./providerHelper.js";
 import {
 	type AutomaticSpeechRecognitionTaskHelper,
 	TaskProviderHelper,
@@ -34,6 +35,7 @@ import {
 	InferenceClientProviderApiError,
 	InferenceClientProviderOutputError,
 } from "../errors.js";
+import type { ImageToImageArgs } from "../tasks/index.js";
 
 export interface FalAiQueueOutput {
 	request_id: string;
@@ -82,6 +84,75 @@ abstract class FalAITask extends TaskProviderHelper {
 	}
 }
 
+abstract class FalAiQueueTask extends FalAITask {
+	abstract task: InferenceTask;
+
+	async getResponseFromQueueApi(
+		response: FalAiQueueOutput,
+		url?: string,
+		headers?: Record<string, string>
+	): Promise<unknown> {
+		if (!url || !headers) {
+			throw new InferenceClientInputError(`URL and headers are required for ${this.task} task`);
+		}
+		const requestId = response.request_id;
+		if (!requestId) {
+			throw new InferenceClientProviderOutputError(
+				`Received malformed response from Fal.ai ${this.task} API: no request ID found in the response`
+			);
+		}
+		let status = response.status;
+
+		const parsedUrl = new URL(url);
+		const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${
+			parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""
+		}`;
+
+		// extracting the provider model id for status and result urls
+		// from the response as it might be different from the mapped model in `url`
+		const modelId = new URL(response.response_url).pathname;
+		const queryParams = parsedUrl.search;
+
+		const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
+		const resultUrl = `${baseUrl}${modelId}${queryParams}`;
+
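+		// poll the queue status endpoint every 500 ms until the request is reported COMPLETED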
+		while (status !== "COMPLETED") {
+			await delay(500);
+			const statusResponse = await fetch(statusUrl, { headers });
+
+			if (!statusResponse.ok) {
+				throw new InferenceClientProviderApiError(
+					"Failed to fetch response status from fal-ai API",
+					{ url: statusUrl, method: "GET" },
+					{
+						requestId: statusResponse.headers.get("x-request-id") ?? "",
+						status: statusResponse.status,
+						body: await statusResponse.text(),
+					}
+				);
+			}
+			try {
+				status = (await statusResponse.json()).status;
+			} catch (error) {
+				throw new InferenceClientProviderOutputError(
+					"Failed to parse status response from fal-ai API: received malformed response"
+				);
+			}
+		}
+
+		const resultResponse = await fetch(resultUrl, { headers });
+		let result: unknown;
+		try {
+			result = await resultResponse.json();
+		} catch (error) {
+			throw new InferenceClientProviderOutputError(
+				"Failed to parse result response from fal-ai API: received malformed response"
+			);
+		}
+		return result;
+	}
+}
+
 function buildLoraPath(modelId: ModelId, adapterWeightsPath: string): string {
 	return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`;
 }
@@ -130,21 +201,42 @@ export class FalAITextToImageTask extends FalAITask implements TextToImageTaskHe
 	}
 }
 
-export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHelper {
+export class FalAIImageToImageTask extends FalAiQueueTask implements ImageToImageTaskHelper {
+	task: InferenceTask;
 	constructor() {
 		super("https://queue.fal.run");
+		this.task = "image-to-image";
 	}
+
 	override makeRoute(params: UrlParams): string {
 		if (params.authMethod !== "provider-key") {
 			return `/${params.model}?_subdomain=queue`;
 		}
 		return `/${params.model}`;
 	}
+
 	override preparePayload(params: BodyParams): Record<string, unknown> {
+		const payload = params.args;
+		if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
+			payload.loras = [
+				{
+					path: buildLoraPath(params.mapping.hfModelId, params.mapping.adapterWeightsPath),
+					scale: 1,
+				},
+			];
+		}
+		return payload;
+	}
+
+	async preparePayloadAsync(args: ImageToImageArgs): Promise<RequestArgs> {
+		const mimeType = args.inputs instanceof Blob ? args.inputs.type : "image/png";
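+		// encode the raw input image bytes as a base64 data URL, sent to the provider as `image_url`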
 		return {
-			...omit(params.args, ["inputs", "parameters"]),
-			...(params.args.parameters as Record<string, unknown>),
-			prompt: params.args.inputs,
+			...omit(args, ["inputs", "parameters"]),
+			image_url: `data:${mimeType};base64,${base64FromBytes(
+				new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await (args.inputs as Blob).arrayBuffer())
+			)}`,
+			...args.parameters,
+			...args,
 		};
 	}
 
@@ -153,63 +245,59 @@ export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHe
 		url?: string,
 		headers?: Record<string, string>
 	): Promise<Blob> {
-		if (!url || !headers) {
-			throw new InferenceClientInputError("URL and headers are required for text-to-video task");
-		}
-		const requestId = response.request_id;
-		if (!requestId) {
+		const result = await this.getResponseFromQueueApi(response, url, headers);
+
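+		// the completed queue result is expected to contain an images array whose first entry has a fetchable URL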
+		if (
+			typeof result === "object" &&
+			!!result &&
+			"images" in result &&
+			Array.isArray(result.images) &&
+			result.images.length > 0 &&
+			typeof result.images[0] === "object" &&
+			!!result.images[0] &&
+			"url" in result.images[0] &&
+			typeof result.images[0].url === "string" &&
+			isUrl(result.images[0].url)
+		) {
+			const urlResponse = await fetch(result.images[0].url);
+			return await urlResponse.blob();
+		} else {
 			throw new InferenceClientProviderOutputError(
-				"Received malformed response from Fal.ai text-to-video API: no request ID found in the response"
+				`Received malformed response from Fal.ai image-to-image API: expected { images: Array<{ url: string }> } result format, got instead: ${JSON.stringify(
+					result
+				)}`
 			);
 		}
-		let status = response.status;
-
-		const parsedUrl = new URL(url);
-		const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${
-			parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""
-		}`;
-
-		// extracting the provider model id for status and result urls
-		// from the response as it might be different from the mapped model in `url`
-		const modelId = new URL(response.response_url).pathname;
-		const queryParams = parsedUrl.search;
-
-		const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
-		const resultUrl = `${baseUrl}${modelId}${queryParams}`;
-
-		while (status !== "COMPLETED") {
-			await delay(500);
-			const statusResponse = await fetch(statusUrl, { headers });
+	}
+}
 
-			if (!statusResponse.ok) {
-				throw new InferenceClientProviderApiError(
-					"Failed to fetch response status from fal-ai API",
-					{ url: statusUrl, method: "GET" },
-					{
-						requestId: statusResponse.headers.get("x-request-id") ?? "",
-						status: statusResponse.status,
-						body: await statusResponse.text(),
-					}
-				);
-			}
-			try {
-				status = (await statusResponse.json()).status;
-			} catch (error) {
-				throw new InferenceClientProviderOutputError(
-					"Failed to parse status response from fal-ai API: received malformed response"
-				);
-			}
+export class FalAITextToVideoTask extends FalAiQueueTask implements TextToVideoTaskHelper {
+	task: InferenceTask;
+	constructor() {
+		super("https://queue.fal.run");
+		this.task = "text-to-video";
+	}
+	override makeRoute(params: UrlParams): string {
+		if (params.authMethod !== "provider-key") {
+			return `/${params.model}?_subdomain=queue`;
 		}
+		return `/${params.model}`;
+	}
+	override preparePayload(params: BodyParams): Record<string, unknown> {
+		return {
+			...omit(params.args, ["inputs", "parameters"]),
+			...(params.args.parameters as Record<string, unknown>),
+			prompt: params.args.inputs,
+		};
+	}
+
+	override async getResponse(
+		response: FalAiQueueOutput,
+		url?: string,
+		headers?: Record<string, string>
+	): Promise<Blob> {
+		const result = await this.getResponseFromQueueApi(response, url, headers);
 
-		const resultResponse = await fetch(resultUrl, { headers });
-		let result: unknown;
-		try {
-			result = await resultResponse.json();
-		} catch (error) {
-			throw new InferenceClientProviderOutputError(
-				"Failed to parse result response from fal-ai API: received malformed response"
-			);
-		}
 		if (
 			typeof result === "object" &&
 			!!result &&