@@ -18,9 +18,10 @@ import { base64FromBytes } from "../utils/base64FromBytes.js";
18
18
19
19
import type { AutomaticSpeechRecognitionOutput } from "@huggingface/tasks" ;
20
20
import { isUrl } from "../lib/isUrl.js" ;
21
- import type { BodyParams , HeaderParams , ModelId , RequestArgs , UrlParams } from "../types.js" ;
21
+ import type { BodyParams , HeaderParams , InferenceTask , ModelId , RequestArgs , UrlParams } from "../types.js" ;
22
22
import { delay } from "../utils/delay.js" ;
23
23
import { omit } from "../utils/omit.js" ;
24
+ import type { ImageToImageTaskHelper } from "./providerHelper.js" ;
24
25
import {
25
26
type AutomaticSpeechRecognitionTaskHelper ,
26
27
TaskProviderHelper ,
@@ -34,6 +35,7 @@ import {
34
35
InferenceClientProviderApiError ,
35
36
InferenceClientProviderOutputError ,
36
37
} from "../errors.js" ;
38
+ import type { ImageToImageArgs } from "../tasks/index.js" ;
37
39
38
40
export interface FalAiQueueOutput {
39
41
request_id : string ;
@@ -82,6 +84,75 @@ abstract class FalAITask extends TaskProviderHelper {
82
84
}
83
85
}
84
86
87
+ abstract class FalAiQueueTask extends FalAITask {
88
+ abstract task : InferenceTask ;
89
+
90
+ async getResponseFromQueueApi (
91
+ response : FalAiQueueOutput ,
92
+ url ?: string ,
93
+ headers ?: Record < string , string >
94
+ ) : Promise < unknown > {
95
+ if ( ! url || ! headers ) {
96
+ throw new InferenceClientInputError ( `URL and headers are required for ${ this . task } task` ) ;
97
+ }
98
+ const requestId = response . request_id ;
99
+ if ( ! requestId ) {
100
+ throw new InferenceClientProviderOutputError (
101
+ `Received malformed response from Fal.ai ${ this . task } API: no request ID found in the response`
102
+ ) ;
103
+ }
104
+ let status = response . status ;
105
+
106
+ const parsedUrl = new URL ( url ) ;
107
+ const baseUrl = `${ parsedUrl . protocol } //${ parsedUrl . host } ${
108
+ parsedUrl . host === "router.huggingface.co" ? "/fal-ai" : ""
109
+ } `;
110
+
111
+ // extracting the provider model id for status and result urls
112
+ // from the response as it might be different from the mapped model in `url`
113
+ const modelId = new URL ( response . response_url ) . pathname ;
114
+ const queryParams = parsedUrl . search ;
115
+
116
+ const statusUrl = `${ baseUrl } ${ modelId } /status${ queryParams } ` ;
117
+ const resultUrl = `${ baseUrl } ${ modelId } ${ queryParams } ` ;
118
+
119
+ while ( status !== "COMPLETED" ) {
120
+ await delay ( 500 ) ;
121
+ const statusResponse = await fetch ( statusUrl , { headers } ) ;
122
+
123
+ if ( ! statusResponse . ok ) {
124
+ throw new InferenceClientProviderApiError (
125
+ "Failed to fetch response status from fal-ai API" ,
126
+ { url : statusUrl , method : "GET" } ,
127
+ {
128
+ requestId : statusResponse . headers . get ( "x-request-id" ) ?? "" ,
129
+ status : statusResponse . status ,
130
+ body : await statusResponse . text ( ) ,
131
+ }
132
+ ) ;
133
+ }
134
+ try {
135
+ status = ( await statusResponse . json ( ) ) . status ;
136
+ } catch ( error ) {
137
+ throw new InferenceClientProviderOutputError (
138
+ "Failed to parse status response from fal-ai API: received malformed response"
139
+ ) ;
140
+ }
141
+ }
142
+
143
+ const resultResponse = await fetch ( resultUrl , { headers } ) ;
144
+ let result : unknown ;
145
+ try {
146
+ result = await resultResponse . json ( ) ;
147
+ } catch ( error ) {
148
+ throw new InferenceClientProviderOutputError (
149
+ "Failed to parse result response from fal-ai API: received malformed response"
150
+ ) ;
151
+ }
152
+ return result ;
153
+ }
154
+ }
155
+
85
156
function buildLoraPath ( modelId : ModelId , adapterWeightsPath : string ) : string {
86
157
return `${ HF_HUB_URL } /${ modelId } /resolve/main/${ adapterWeightsPath } ` ;
87
158
}
@@ -130,21 +201,29 @@ export class FalAITextToImageTask extends FalAITask implements TextToImageTaskHe
130
201
}
131
202
}
132
203
133
- export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHelper {
204
+ export class FalAIImageToImageTask extends FalAiQueueTask implements ImageToImageTaskHelper {
205
+ task : InferenceTask ;
134
206
constructor ( ) {
135
207
super ( "https://queue.fal.run" ) ;
208
+ this . task = "image-to-image" ;
136
209
}
210
+
137
211
override makeRoute ( params : UrlParams ) : string {
138
212
if ( params . authMethod !== "provider-key" ) {
139
213
return `/${ params . model } ?_subdomain=queue` ;
140
214
}
141
215
return `/${ params . model } ` ;
142
216
}
143
- override preparePayload ( params : BodyParams ) : Record < string , unknown > {
217
+
218
+ async preparePayloadAsync ( args : ImageToImageArgs ) : Promise < RequestArgs > {
219
+ const mimeType = args . inputs instanceof Blob ? args . inputs . type : "image/png" ;
144
220
return {
145
- ...omit ( params . args , [ "inputs" , "parameters" ] ) ,
146
- ...( params . args . parameters as Record < string , unknown > ) ,
147
- prompt : params . args . inputs ,
221
+ ...omit ( args , [ "inputs" , "parameters" ] ) ,
222
+ image_url : `data:${ mimeType } ;base64,${ base64FromBytes (
223
+ new Uint8Array ( args . inputs instanceof ArrayBuffer ? args . inputs : await ( args . inputs as Blob ) . arrayBuffer ( ) )
224
+ ) } `,
225
+ ...args . parameters ,
226
+ ...args ,
148
227
} ;
149
228
}
150
229
@@ -153,63 +232,59 @@ export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHe
153
232
url ?: string ,
154
233
headers ?: Record < string , string >
155
234
) : Promise < Blob > {
156
- if ( ! url || ! headers ) {
157
- throw new InferenceClientInputError ( "URL and headers are required for text-to-video task" ) ;
158
- }
159
- const requestId = response . request_id ;
160
- if ( ! requestId ) {
235
+ const result = await this . getResponseFromQueueApi ( response , url , headers ) ;
236
+
237
+ if (
238
+ typeof result === "object" &&
239
+ ! ! result &&
240
+ "images" in result &&
241
+ Array . isArray ( result . images ) &&
242
+ result . images . length > 0 &&
243
+ typeof result . images [ 0 ] === "object" &&
244
+ ! ! result . images [ 0 ] &&
245
+ "url" in result . images [ 0 ] &&
246
+ typeof result . images [ 0 ] . url === "string" &&
247
+ isUrl ( result . images [ 0 ] . url )
248
+ ) {
249
+ const urlResponse = await fetch ( result . images [ 0 ] . url ) ;
250
+ return await urlResponse . blob ( ) ;
251
+ } else {
161
252
throw new InferenceClientProviderOutputError (
162
- "Received malformed response from Fal.ai text-to-video API: no request ID found in the response"
253
+ `Received malformed response from Fal.ai image-to-image API: expected { images: Array<{ url: string }> } result format, got instead: ${ JSON . stringify (
254
+ result
255
+ ) } `
163
256
) ;
164
257
}
165
- let status = response . status ;
166
-
167
- const parsedUrl = new URL ( url ) ;
168
- const baseUrl = `${ parsedUrl . protocol } //${ parsedUrl . host } ${
169
- parsedUrl . host === "router.huggingface.co" ? "/fal-ai" : ""
170
- } `;
171
-
172
- // extracting the provider model id for status and result urls
173
- // from the response as it might be different from the mapped model in `url`
174
- const modelId = new URL ( response . response_url ) . pathname ;
175
- const queryParams = parsedUrl . search ;
176
-
177
- const statusUrl = `${ baseUrl } ${ modelId } /status${ queryParams } ` ;
178
- const resultUrl = `${ baseUrl } ${ modelId } ${ queryParams } ` ;
179
-
180
- while ( status !== "COMPLETED" ) {
181
- await delay ( 500 ) ;
182
- const statusResponse = await fetch ( statusUrl , { headers } ) ;
258
+ }
259
+ }
183
260
184
- if ( ! statusResponse . ok ) {
185
- throw new InferenceClientProviderApiError (
186
- "Failed to fetch response status from fal-ai API" ,
187
- { url : statusUrl , method : "GET" } ,
188
- {
189
- requestId : statusResponse . headers . get ( "x-request-id" ) ?? "" ,
190
- status : statusResponse . status ,
191
- body : await statusResponse . text ( ) ,
192
- }
193
- ) ;
194
- }
195
- try {
196
- status = ( await statusResponse . json ( ) ) . status ;
197
- } catch ( error ) {
198
- throw new InferenceClientProviderOutputError (
199
- "Failed to parse status response from fal-ai API: received malformed response"
200
- ) ;
201
- }
261
+ export class FalAITextToVideoTask extends FalAiQueueTask implements TextToVideoTaskHelper {
262
+ task : InferenceTask ;
263
+ constructor ( ) {
264
+ super ( "https://queue.fal.run" ) ;
265
+ this . task = "text-to-video" ;
266
+ }
267
+ override makeRoute ( params : UrlParams ) : string {
268
+ if ( params . authMethod !== "provider-key" ) {
269
+ return `/${ params . model } ?_subdomain=queue` ;
202
270
}
271
+ return `/${ params . model } ` ;
272
+ }
273
+ override preparePayload ( params : BodyParams ) : Record < string , unknown > {
274
+ return {
275
+ ...omit ( params . args , [ "inputs" , "parameters" ] ) ,
276
+ ...( params . args . parameters as Record < string , unknown > ) ,
277
+ prompt : params . args . inputs ,
278
+ } ;
279
+ }
280
+
281
+ override async getResponse (
282
+ response : FalAiQueueOutput ,
283
+ url ?: string ,
284
+ headers ?: Record < string , string >
285
+ ) : Promise < Blob > {
286
+ const result = await this . getResponseFromQueueApi ( response , url , headers ) ;
203
287
204
- const resultResponse = await fetch ( resultUrl , { headers } ) ;
205
- let result : unknown ;
206
- try {
207
- result = await resultResponse . json ( ) ;
208
- } catch ( error ) {
209
- throw new InferenceClientProviderOutputError (
210
- "Failed to parse result response from fal-ai API: received malformed response"
211
- ) ;
212
- }
213
288
if (
214
289
typeof result === "object" &&
215
290
! ! result &&
0 commit comments