Skip to content

Commit 904b6da

Browse files
committed
factor
1 parent e7a27f3 commit 904b6da

File tree

1 file changed

+78
-114
lines changed

1 file changed

+78
-114
lines changed

packages/inference/src/providers/fal-ai.ts

Lines changed: 78 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import { base64FromBytes } from "../utils/base64FromBytes.js";
1818

1919
import type { AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
2020
import { isUrl } from "../lib/isUrl.js";
21-
import type { BodyParams, HeaderParams, ModelId, RequestArgs, UrlParams } from "../types.js";
21+
import type { BodyParams, HeaderParams, InferenceTask, ModelId, RequestArgs, UrlParams } from "../types.js";
2222
import { delay } from "../utils/delay.js";
2323
import { omit } from "../utils/omit.js";
2424
import type { ImageToImageTaskHelper } from "./providerHelper.js";
@@ -84,6 +84,75 @@ abstract class FalAITask extends TaskProviderHelper {
8484
}
8585
}
8686

87+
abstract class FalAiQueueTask extends FalAITask {
88+
abstract task: InferenceTask;
89+
90+
async getResponseFromQueueApi(
91+
response: FalAiQueueOutput,
92+
url?: string,
93+
headers?: Record<string, string>
94+
): Promise<unknown> {
95+
if (!url || !headers) {
96+
throw new InferenceClientInputError(`URL and headers are required for ${this.task} task`);
97+
}
98+
const requestId = response.request_id;
99+
if (!requestId) {
100+
throw new InferenceClientProviderOutputError(
101+
`Received malformed response from Fal.ai ${this.task} API: no request ID found in the response`
102+
);
103+
}
104+
let status = response.status;
105+
106+
const parsedUrl = new URL(url);
107+
const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${
108+
parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""
109+
}`;
110+
111+
// extracting the provider model id for status and result urls
112+
// from the response as it might be different from the mapped model in `url`
113+
const modelId = new URL(response.response_url).pathname;
114+
const queryParams = parsedUrl.search;
115+
116+
const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
117+
const resultUrl = `${baseUrl}${modelId}${queryParams}`;
118+
119+
while (status !== "COMPLETED") {
120+
await delay(500);
121+
const statusResponse = await fetch(statusUrl, { headers });
122+
123+
if (!statusResponse.ok) {
124+
throw new InferenceClientProviderApiError(
125+
"Failed to fetch response status from fal-ai API",
126+
{ url: statusUrl, method: "GET" },
127+
{
128+
requestId: statusResponse.headers.get("x-request-id") ?? "",
129+
status: statusResponse.status,
130+
body: await statusResponse.text(),
131+
}
132+
);
133+
}
134+
try {
135+
status = (await statusResponse.json()).status;
136+
} catch (error) {
137+
throw new InferenceClientProviderOutputError(
138+
"Failed to parse status response from fal-ai API: received malformed response"
139+
);
140+
}
141+
}
142+
143+
const resultResponse = await fetch(resultUrl, { headers });
144+
let result: unknown;
145+
try {
146+
result = await resultResponse.json();
147+
} catch (error) {
148+
throw new InferenceClientProviderOutputError(
149+
"Failed to parse result response from fal-ai API: received malformed response"
150+
);
151+
}
152+
return result;
153+
}
154+
}
155+
87156
function buildLoraPath(modelId: ModelId, adapterWeightsPath: string): string {
88157
return `${HF_HUB_URL}/${modelId}/resolve/main/${adapterWeightsPath}`;
89158
}
@@ -132,9 +201,11 @@ export class FalAITextToImageTask extends FalAITask implements TextToImageTaskHe
132201
}
133202
}
134203

135-
export class FalAIImageToImageTask extends FalAITask implements ImageToImageTaskHelper {
204+
export class FalAIImageToImageTask extends FalAiQueueTask implements ImageToImageTaskHelper {
205+
task: InferenceTask;
136206
constructor() {
137207
super("https://queue.fal.run");
208+
this.task = "image-to-image";
138209
}
139210

140211
override makeRoute(params: UrlParams): string {
@@ -161,63 +232,8 @@ export class FalAIImageToImageTask extends FalAITask implements ImageToImageTask
161232
url?: string,
162233
headers?: Record<string, string>
163234
): Promise<Blob> {
164-
if (!url || !headers) {
165-
throw new InferenceClientInputError("URL and headers are required for image-to-image task");
166-
}
167-
const requestId = response.request_id;
168-
if (!requestId) {
169-
throw new InferenceClientProviderOutputError(
170-
"Received malformed response from Fal.ai text-to-video API: no request ID found in the response"
171-
);
172-
}
173-
let status = response.status;
174-
175-
const parsedUrl = new URL(url);
176-
const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""
177-
}`;
178-
179-
// extracting the provider model id for status and result urls
180-
// from the response as it might be different from the mapped model in `url`
181-
const modelId = new URL(response.response_url).pathname;
182-
const queryParams = parsedUrl.search;
183-
184-
const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
185-
const resultUrl = `${baseUrl}${modelId}${queryParams}`;
186-
187-
while (status !== "COMPLETED") {
188-
await delay(500);
189-
const statusResponse = await fetch(statusUrl, { headers });
235+
const result = await this.getResponseFromQueueApi(response, url, headers);
190236

191-
if (!statusResponse.ok) {
192-
throw new InferenceClientProviderApiError(
193-
"Failed to fetch response status from fal-ai API",
194-
{ url: statusUrl, method: "GET" },
195-
{
196-
requestId: statusResponse.headers.get("x-request-id") ?? "",
197-
status: statusResponse.status,
198-
body: await statusResponse.text(),
199-
}
200-
);
201-
}
202-
try {
203-
status = (await statusResponse.json()).status;
204-
} catch (error) {
205-
throw new InferenceClientProviderOutputError(
206-
"Failed to parse status response from fal-ai API: received malformed response"
207-
);
208-
}
209-
}
210-
211-
const resultResponse = await fetch(resultUrl, { headers });
212-
let result: unknown;
213-
try {
214-
result = await resultResponse.json();
215-
} catch (error) {
216-
throw new InferenceClientProviderOutputError(
217-
"Failed to parse result response from fal-ai API: received malformed response"
218-
);
219-
}
220-
console.log("result", result);
221237
if (
222238
typeof result === "object" &&
223239
!!result &&
@@ -242,9 +258,11 @@ export class FalAIImageToImageTask extends FalAITask implements ImageToImageTask
242258
}
243259
}
244260

245-
export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHelper {
261+
export class FalAITextToVideoTask extends FalAiQueueTask implements TextToVideoTaskHelper {
262+
task: InferenceTask;
246263
constructor() {
247264
super("https://queue.fal.run");
265+
this.task = "text-to-video";
248266
}
249267
override makeRoute(params: UrlParams): string {
250268
if (params.authMethod !== "provider-key") {
@@ -265,62 +283,8 @@ export class FalAITextToVideoTask extends FalAITask implements TextToVideoTaskHe
265283
url?: string,
266284
headers?: Record<string, string>
267285
): Promise<Blob> {
268-
if (!url || !headers) {
269-
throw new InferenceClientInputError("URL and headers are required for text-to-video task");
270-
}
271-
const requestId = response.request_id;
272-
if (!requestId) {
273-
throw new InferenceClientProviderOutputError(
274-
"Received malformed response from Fal.ai text-to-video API: no request ID found in the response"
275-
);
276-
}
277-
let status = response.status;
278-
279-
const parsedUrl = new URL(url);
280-
const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${parsedUrl.host === "router.huggingface.co" ? "/fal-ai" : ""
281-
}`;
282-
283-
// extracting the provider model id for status and result urls
284-
// from the response as it might be different from the mapped model in `url`
285-
const modelId = new URL(response.response_url).pathname;
286-
const queryParams = parsedUrl.search;
287-
288-
const statusUrl = `${baseUrl}${modelId}/status${queryParams}`;
289-
const resultUrl = `${baseUrl}${modelId}${queryParams}`;
290-
291-
while (status !== "COMPLETED") {
292-
await delay(500);
293-
const statusResponse = await fetch(statusUrl, { headers });
294-
295-
if (!statusResponse.ok) {
296-
throw new InferenceClientProviderApiError(
297-
"Failed to fetch response status from fal-ai API",
298-
{ url: statusUrl, method: "GET" },
299-
{
300-
requestId: statusResponse.headers.get("x-request-id") ?? "",
301-
status: statusResponse.status,
302-
body: await statusResponse.text(),
303-
}
304-
);
305-
}
306-
try {
307-
status = (await statusResponse.json()).status;
308-
} catch (error) {
309-
throw new InferenceClientProviderOutputError(
310-
"Failed to parse status response from fal-ai API: received malformed response"
311-
);
312-
}
313-
}
286+
const result = await this.getResponseFromQueueApi(response, url, headers);
314287

315-
const resultResponse = await fetch(resultUrl, { headers });
316-
let result: unknown;
317-
try {
318-
result = await resultResponse.json();
319-
} catch (error) {
320-
throw new InferenceClientProviderOutputError(
321-
"Failed to parse result response from fal-ai API: received malformed response"
322-
);
323-
}
324288
if (
325289
typeof result === "object" &&
326290
!!result &&

0 commit comments

Comments
 (0)