1616 */
1717import { base64FromBytes } from "../utils/base64FromBytes.js" ;
1818
19- import type { AutomaticSpeechRecognitionOutput } from "@huggingface/tasks" ;
19+ import type { AutomaticSpeechRecognitionOutput , ImageSegmentationOutput } from "@huggingface/tasks" ;
2020import { isUrl } from "../lib/isUrl.js" ;
2121import type { BodyParams , HeaderParams , InferenceTask , ModelId , RequestArgs , UrlParams } from "../types.js" ;
2222import { delay } from "../utils/delay.js" ;
2323import { omit } from "../utils/omit.js" ;
24- import type { ImageToImageTaskHelper } from "./providerHelper.js" ;
24+ import type { ImageSegmentationTaskHelper , ImageToImageTaskHelper } from "./providerHelper.js" ;
2525import {
2626 type AutomaticSpeechRecognitionTaskHelper ,
2727 TaskProviderHelper ,
@@ -36,6 +36,7 @@ import {
3636 InferenceClientProviderOutputError ,
3737} from "../errors.js" ;
3838import type { ImageToImageArgs } from "../tasks/index.js" ;
39+ import type { ImageSegmentationArgs } from "../tasks/cv/imageSegmentation.js" ;
3940
4041export interface FalAiQueueOutput {
4142 request_id : string ;
@@ -406,3 +407,87 @@ export class FalAITextToSpeechTask extends FalAITask {
406407 }
407408 }
408409}
/**
 * Provider helper for the "image-segmentation" task on Fal.ai.
 *
 * Requests are sent to the Fal.ai queue endpoint; the final result is obtained
 * through `getResponseFromQueueApi` and converted into an
 * `ImageSegmentationOutput` containing a single base64-encoded mask.
 */
export class FalAIImageSegmentationTask extends FalAiQueueTask implements ImageSegmentationTaskHelper {
	task: InferenceTask;

	constructor() {
		super("https://queue.fal.run");
		this.task = "image-segmentation";
	}

	/**
	 * Builds the route for the given model.
	 *
	 * @param params - URL parameters; only `authMethod` and `model` are read.
	 * @returns `/<model>` when authenticating with a provider key, otherwise
	 * `/<model>?_subdomain=queue`.
	 */
	override makeRoute(params: UrlParams): string {
		if (params.authMethod !== "provider-key") {
			return `/${params.model}?_subdomain=queue`;
		}
		return `/${params.model}`;
	}

	/**
	 * Flattens the request arguments into the JSON body: every argument except
	 * `inputs` and `parameters` is kept, the entries of `parameters` are hoisted
	 * to the top level, and `sync_mode` is forced to `true`.
	 *
	 * @param params - Body parameters; only `params.args` is read.
	 * @returns The payload object.
	 */
	override preparePayload(params: BodyParams): Record<string, unknown> {
		return {
			...omit(params.args, ["inputs", "parameters"]),
			...(params.args.parameters as Record<string, unknown>),
			sync_mode: true,
		};
	}

	/**
	 * Encodes the input image as a base64 data URL and exposes it as `image_url`.
	 *
	 * The image is taken from `args.data` when that is a Blob, otherwise from
	 * `args.inputs`.
	 *
	 * @param args - Image segmentation arguments.
	 * @returns The request arguments extended with `image_url` and `sync_mode: true`.
	 * @throws TypeError when no image is present in either `data` or `inputs`.
	 */
	async preparePayloadAsync(args: ImageSegmentationArgs): Promise<RequestArgs> {
		const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : undefined;
		if (blob === undefined || blob === null) {
			// Fail with a clear message instead of an opaque "cannot read properties of undefined".
			throw new TypeError("Fal.ai image-segmentation requires an image in `inputs` (or `data`)");
		}
		// A Blob may report an empty type; an empty media type in a data URL is
		// interpreted as text/plain, so fall back to PNG in that case as well.
		const mimeType = blob instanceof Blob && blob.type !== "" ? blob.type : "image/png";
		const base64Image = base64FromBytes(
			new Uint8Array(blob instanceof ArrayBuffer ? blob : await (blob as Blob).arrayBuffer())
		);
		return {
			...omit(args, ["inputs", "parameters", "data"]),
			...args.parameters,
			// NOTE(review): spreading `args` re-adds `inputs`, `parameters` and `data`,
			// undoing the `omit` above. Presumably kept so the result satisfies
			// `RequestArgs`, with `preparePayload` stripping `inputs`/`parameters`
			// afterwards — confirm, and check whether `data` should be dropped there too.
			...args,
			image_url: `data:${mimeType};base64,${base64Image}`,
			sync_mode: true,
		};
	}

	/**
	 * Resolves the queued request and converts the result into segmentation output.
	 *
	 * Expects the queue result to have the shape `{ image: { url: string } }`; the
	 * mask is downloaded from that URL and returned base64-encoded.
	 *
	 * @param response - Queue submission response.
	 * @param url - Forwarded to `getResponseFromQueueApi`.
	 * @param headers - Forwarded to `getResponseFromQueueApi`.
	 * @returns A single-element array holding the mask with placeholder label and score.
	 * @throws InferenceClientProviderApiError when downloading the mask fails.
	 * @throws InferenceClientProviderOutputError when the result has an unexpected shape.
	 */
	override async getResponse(
		response: FalAiQueueOutput,
		url?: string,
		headers?: Record<string, string>
	): Promise<ImageSegmentationOutput> {
		const result = await this.getResponseFromQueueApi(response, url, headers);
		if (
			typeof result === "object" &&
			result !== null &&
			"image" in result &&
			typeof result.image === "object" &&
			result.image !== null &&
			"url" in result.image &&
			typeof result.image.url === "string"
		) {
			const maskUrl = result.image.url;
			const maskResponse = await fetch(maskUrl);
			if (!maskResponse.ok) {
				throw new InferenceClientProviderApiError(
					`Failed to fetch segmentation mask from ${maskUrl}`,
					{ url: maskUrl, method: "GET" },
					{
						requestId: maskResponse.headers.get("x-request-id") ?? "",
						status: maskResponse.status,
						body: await maskResponse.text(),
					}
				);
			}
			// Read the bytes directly; no intermediate Blob is needed.
			const maskBase64 = base64FromBytes(new Uint8Array(await maskResponse.arrayBuffer()));

			return [
				{
					label: "mask", // placeholder label, as Fal does not provide labels in the response(?)
					score: 1.0, // placeholder score, as Fal does not provide scores in the response(?)
					mask: maskBase64,
				},
			];
		}

		// Report the queue result that failed validation, not the initial queue
		// submission response, so the message shows the actual malformed payload.
		throw new InferenceClientProviderOutputError(
			`Received malformed response from Fal.ai image-segmentation API: expected { image: { url: string } } format, got instead: ${JSON.stringify(
				result
			)}`
		);
	}
}
0 commit comments