Skip to content

Commit e3ff966

Browse files
add image-segmentation support for fal (#1602)
adds `image-segmentation` task for fal, following internal discussion [here](https://huggingface.slack.com/archives/C0664PDFGSJ/p1751531970106289?thread_ts=1748529796.374799&cid=C0664PDFGSJ) cc @Vaibhavs10 --------- Co-authored-by: Celina Hanouti <[email protected]>
1 parent eb72014 commit e3ff966

File tree

5 files changed

+105
-6
lines changed

5 files changed

+105
-6
lines changed

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
6767
"text-to-video": new FalAI.FalAITextToVideoTask(),
6868
"image-to-image": new FalAI.FalAIImageToImageTask(),
6969
"automatic-speech-recognition": new FalAI.FalAIAutomaticSpeechRecognitionTask(),
70+
"image-segmentation": new FalAI.FalAIImageSegmentationTask(),
7071
},
7172
"featherless-ai": {
7273
conversational: new FeatherlessAI.FeatherlessAIConversationalTask(),

packages/inference/src/providers/fal-ai.ts

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616
*/
1717
import { base64FromBytes } from "../utils/base64FromBytes.js";
1818

19-
import type { AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
19+
import type { AutomaticSpeechRecognitionOutput, ImageSegmentationOutput } from "@huggingface/tasks";
2020
import { isUrl } from "../lib/isUrl.js";
2121
import type { BodyParams, HeaderParams, InferenceTask, ModelId, RequestArgs, UrlParams } from "../types.js";
2222
import { delay } from "../utils/delay.js";
2323
import { omit } from "../utils/omit.js";
24-
import type { ImageToImageTaskHelper } from "./providerHelper.js";
24+
import type { ImageSegmentationTaskHelper, ImageToImageTaskHelper } from "./providerHelper.js";
2525
import {
2626
type AutomaticSpeechRecognitionTaskHelper,
2727
TaskProviderHelper,
@@ -36,6 +36,7 @@ import {
3636
InferenceClientProviderOutputError,
3737
} from "../errors.js";
3838
import type { ImageToImageArgs } from "../tasks/index.js";
39+
import type { ImageSegmentationArgs } from "../tasks/cv/imageSegmentation.js";
3940

4041
export interface FalAiQueueOutput {
4142
request_id: string;
@@ -406,3 +407,87 @@ export class FalAITextToSpeechTask extends FalAITask {
406407
}
407408
}
408409
}
410+
export class FalAIImageSegmentationTask extends FalAiQueueTask implements ImageSegmentationTaskHelper {
411+
task: InferenceTask;
412+
constructor() {
413+
super("https://queue.fal.run");
414+
this.task = "image-segmentation";
415+
}
416+
417+
override makeRoute(params: UrlParams): string {
418+
if (params.authMethod !== "provider-key") {
419+
return `/${params.model}?_subdomain=queue`;
420+
}
421+
return `/${params.model}`;
422+
}
423+
424+
override preparePayload(params: BodyParams): Record<string, unknown> {
425+
return {
426+
...omit(params.args, ["inputs", "parameters"]),
427+
...(params.args.parameters as Record<string, unknown>),
428+
sync_mode: true,
429+
};
430+
}
431+
432+
async preparePayloadAsync(args: ImageSegmentationArgs): Promise<RequestArgs> {
433+
const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : undefined;
434+
const mimeType = blob instanceof Blob ? blob.type : "image/png";
435+
const base64Image = base64FromBytes(
436+
new Uint8Array(blob instanceof ArrayBuffer ? blob : await (blob as Blob).arrayBuffer())
437+
);
438+
return {
439+
...omit(args, ["inputs", "parameters", "data"]),
440+
...args.parameters,
441+
...args,
442+
image_url: `data:${mimeType};base64,${base64Image}`,
443+
sync_mode: true,
444+
};
445+
}
446+
447+
override async getResponse(
448+
response: FalAiQueueOutput,
449+
url?: string,
450+
headers?: Record<string, string>
451+
): Promise<ImageSegmentationOutput> {
452+
const result = await this.getResponseFromQueueApi(response, url, headers);
453+
if (
454+
typeof result === "object" &&
455+
result !== null &&
456+
"image" in result &&
457+
typeof result.image === "object" &&
458+
result.image !== null &&
459+
"url" in result.image &&
460+
typeof result.image.url === "string"
461+
) {
462+
const maskResponse = await fetch(result.image.url);
463+
if (!maskResponse.ok) {
464+
throw new InferenceClientProviderApiError(
465+
`Failed to fetch segmentation mask from ${result.image.url}`,
466+
{ url: result.image.url, method: "GET" },
467+
{
468+
requestId: maskResponse.headers.get("x-request-id") ?? "",
469+
status: maskResponse.status,
470+
body: await maskResponse.text(),
471+
}
472+
);
473+
}
474+
const maskBlob = await maskResponse.blob();
475+
const maskArrayBuffer = await maskBlob.arrayBuffer();
476+
const maskBase64 = base64FromBytes(new Uint8Array(maskArrayBuffer));
477+
478+
return [
479+
{
480+
label: "mask", // placeholder label, as Fal does not provide labels in the response(?)
481+
score: 1.0, // placeholder score, as Fal does not provide scores in the response(?)
482+
mask: maskBase64,
483+
},
484+
];
485+
}
486+
487+
throw new InferenceClientProviderOutputError(
488+
`Received malformed response from Fal.ai image-segmentation API: expected { image: { url: string } } format, got instead: ${JSON.stringify(
489+
response
490+
)}`
491+
);
492+
}
493+
}

packages/inference/src/providers/hf-inference.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ import { base64FromBytes } from "../utils/base64FromBytes.js";
7676
import type { ImageToImageArgs } from "../tasks/cv/imageToImage.js";
7777
import type { AutomaticSpeechRecognitionArgs } from "../tasks/audio/automaticSpeechRecognition.js";
7878
import { omit } from "../utils/omit.js";
79+
import { ImageSegmentationArgs } from "../tasks/cv/imageSegmentation.js";
7980
interface Base64ImageGeneration {
8081
data: Array<{
8182
b64_json: string;
@@ -345,6 +346,15 @@ export class HFInferenceImageSegmentationTask extends HFInferenceTask implements
345346
"Received malformed response from HF-Inference image-segmentation API: expected Array<{label: string, mask: string, score: number}>"
346347
);
347348
}
349+
350+
async preparePayloadAsync(args: ImageSegmentationArgs): Promise<RequestArgs> {
351+
return {
352+
...args,
353+
inputs: base64FromBytes(
354+
new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await (args.inputs as Blob).arrayBuffer())
355+
),
356+
};
357+
}
348358
}
349359

350360
export class HFInferenceImageToTextTask extends HFInferenceTask implements ImageToTextTaskHelper {

packages/inference/src/providers/providerHelper.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ import { toArray } from "../utils/toArray.js";
5454
import type { ImageToImageArgs } from "../tasks/cv/imageToImage.js";
5555
import type { AutomaticSpeechRecognitionArgs } from "../tasks/audio/automaticSpeechRecognition.js";
5656
import type { ImageToVideoArgs } from "../tasks/cv/imageToVideo.js";
57+
import { ImageSegmentationArgs } from "../tasks/cv/imageSegmentation.js";
5758

5859
/**
5960
* Base class for task-specific provider helpers
@@ -161,6 +162,7 @@ export interface ImageToVideoTaskHelper {
161162
/**
 * Contract a provider must implement to support the image-segmentation task.
 */
export interface ImageSegmentationTaskHelper {
	/** Converts the provider's raw response into the standard `ImageSegmentationOutput`; `url`/`headers` allow follow-up requests (e.g. polling or asset download). */
	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ImageSegmentationOutput>;
	/** Synchronous payload preparation from already-prepared request args. */
	preparePayload(params: BodyParams<ImageSegmentationInput & BaseArgs>): Record<string, unknown> | BodyInit;
	/** Async payload preparation from user-facing args (e.g. to base64-encode a Blob input). */
	preparePayloadAsync(args: ImageSegmentationArgs): Promise<RequestArgs>;
}
165167

166168
export interface ImageClassificationTaskHelper {

packages/inference/src/tasks/cv/imageSegmentation.ts

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ import { resolveProvider } from "../../lib/getInferenceProviderMapping.js";
33
import { getProviderHelper } from "../../lib/getProviderHelper.js";
44
import type { BaseArgs, Options } from "../../types.js";
55
import { innerRequest } from "../../utils/request.js";
6-
import { preparePayload, type LegacyImageInput } from "./utils.js";
6+
import { makeRequestOptions } from "../../lib/makeRequestOptions.js";
77

8-
export type ImageSegmentationArgs = BaseArgs & (ImageSegmentationInput | LegacyImageInput);
8+
export type ImageSegmentationArgs = BaseArgs & ImageSegmentationInput;
99

1010
/**
1111
* This task takes an image input and segments it into regions, returning a label, a confidence score, and a mask for each detected segment.
@@ -17,10 +17,11 @@ export async function imageSegmentation(
1717
): Promise<ImageSegmentationOutput> {
1818
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
1919
const providerHelper = getProviderHelper(provider, "image-segmentation");
20-
const payload = preparePayload(args);
20+
const payload = await providerHelper.preparePayloadAsync(args);
2121
const { data: res } = await innerRequest<ImageSegmentationOutput>(payload, providerHelper, {
2222
...options,
2323
task: "image-segmentation",
2424
});
25-
return providerHelper.getResponse(res);
25+
const { url, info } = await makeRequestOptions(args, providerHelper, { ...options, task: "image-segmentation" });
26+
return providerHelper.getResponse(res, url, info.headers as Record<string, string>);
2627
}

0 commit comments

Comments
 (0)