diff --git a/packages/inference/src/providers/fal-ai.ts b/packages/inference/src/providers/fal-ai.ts
index 83f19d287e..a1b9ff3b37 100644
--- a/packages/inference/src/providers/fal-ai.ts
+++ b/packages/inference/src/providers/fal-ai.ts
@@ -14,10 +14,12 @@
  *
  * Thanks!
  */
+import { base64FromBytes } from "../utils/base64FromBytes";
+
 import type { AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
 import { InferenceOutputError } from "../lib/InferenceOutputError";
 import { isUrl } from "../lib/isUrl";
-import type { BodyParams, HeaderParams, ModelId, UrlParams } from "../types";
+import type { BodyParams, HeaderParams, ModelId, RequestArgs, UrlParams } from "../types";
 import { delay } from "../utils/delay";
 import { omit } from "../utils/omit";
 import {
@@ -27,6 +29,7 @@ import {
 	type TextToVideoTaskHelper,
 } from "./providerHelper";
 import { HF_HUB_URL } from "../config";
+import type { AutomaticSpeechRecognitionArgs } from "../tasks/audio/automaticSpeechRecognition";
 
 export interface FalAiQueueOutput {
 	request_id: string;
@@ -224,6 +227,28 @@ export class FalAIAutomaticSpeechRecognitionTask extends FalAITask implements Au
 		}
 		return { text: res.text };
 	}
+
+	async preparePayloadAsync(args: AutomaticSpeechRecognitionArgs): Promise<RequestArgs> {
+		const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : undefined;
+		const contentType = blob?.type;
+		if (!contentType) {
+			throw new Error(
+				`Unable to determine the input's content-type. Make sure you are passing a Blob when using provider fal-ai.`
+			);
+		}
+		if (!FAL_AI_SUPPORTED_BLOB_TYPES.includes(contentType)) {
+			throw new Error(
+				`Provider fal-ai does not support blob type ${contentType} - supported content types are: ${FAL_AI_SUPPORTED_BLOB_TYPES.join(
+					", "
+				)}`
+			);
+		}
+		const base64audio = base64FromBytes(new Uint8Array(await blob.arrayBuffer()));
+		return {
+			...("data" in args ? omit(args, "data") : omit(args, "inputs")),
+			audio_url: `data:${contentType};base64,${base64audio}`,
+		};
+	}
 }
 
 export class FalAITextToSpeechTask extends FalAITask {
diff --git a/packages/inference/src/providers/hf-inference.ts b/packages/inference/src/providers/hf-inference.ts
index b049c0b54d..54ff1c3d9c 100644
--- a/packages/inference/src/providers/hf-inference.ts
+++ b/packages/inference/src/providers/hf-inference.ts
@@ -36,7 +36,7 @@ import type {
 import { HF_ROUTER_URL } from "../config";
 import { InferenceOutputError } from "../lib/InferenceOutputError";
 import type { TabularClassificationOutput } from "../tasks/tabular/tabularClassification";
-import type { BodyParams, UrlParams } from "../types";
+import type { BodyParams, RequestArgs, UrlParams } from "../types";
 import { toArray } from "../utils/toArray";
 import type {
 	AudioClassificationTaskHelper,
@@ -70,7 +70,10 @@ import type {
 } from "./providerHelper";
 
 import { TaskProviderHelper } from "./providerHelper";
-
+import { base64FromBytes } from "../utils/base64FromBytes";
+import type { ImageToImageArgs } from "../tasks/cv/imageToImage";
+import type { AutomaticSpeechRecognitionArgs } from "../tasks/audio/automaticSpeechRecognition";
+import { omit } from "../utils/omit";
 interface Base64ImageGeneration {
 	data: Array<{
 		b64_json: string;
@@ -221,6 +224,15 @@ export class HFInferenceAutomaticSpeechRecognitionTask
 	override async getResponse(response: AutomaticSpeechRecognitionOutput): Promise<AutomaticSpeechRecognitionOutput> {
 		return response;
 	}
+
+	async preparePayloadAsync(args: AutomaticSpeechRecognitionArgs): Promise<RequestArgs> {
+		return "data" in args
+			? args
+			: {
+					...omit(args, "inputs"),
+					data: args.inputs,
+			  };
+	}
 }
 
 export class HFInferenceAudioToAudioTask extends HFInferenceTask implements AudioToAudioTaskHelper {
@@ -326,6 +338,23 @@ export class HFInferenceImageToTextTask extends HFInferenceTask implements Image
 }
 
 export class HFInferenceImageToImageTask extends HFInferenceTask implements ImageToImageTaskHelper {
+	async preparePayloadAsync(args: ImageToImageArgs): Promise<RequestArgs> {
+		if (!args.parameters) {
+			return {
+				...args,
+				model: args.model,
+				data: args.inputs,
+			};
+		} else {
+			return {
+				...args,
+				inputs: base64FromBytes(
+					new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await (args.inputs as Blob).arrayBuffer())
+				),
+			};
+		}
+	}
+
 	override async getResponse(response: Blob): Promise<Blob> {
 		if (response instanceof Blob) {
 			return response;
diff --git a/packages/inference/src/providers/providerHelper.ts b/packages/inference/src/providers/providerHelper.ts
index 994f3ab0ea..e94b93db07 100644
--- a/packages/inference/src/providers/providerHelper.ts
+++ b/packages/inference/src/providers/providerHelper.ts
@@ -48,8 +48,10 @@ import type {
 import { HF_ROUTER_URL } from "../config";
 import { InferenceOutputError } from "../lib/InferenceOutputError";
 import type { AudioToAudioOutput } from "../tasks/audio/audioToAudio";
-import type { BaseArgs, BodyParams, HeaderParams, InferenceProvider, UrlParams } from "../types";
+import type { BaseArgs, BodyParams, HeaderParams, InferenceProvider, RequestArgs, UrlParams } from "../types";
 import { toArray } from "../utils/toArray";
+import type { ImageToImageArgs } from "../tasks/cv/imageToImage";
+import type { AutomaticSpeechRecognitionArgs } from "../tasks/audio/automaticSpeechRecognition";
 
 /**
  * Base class for task-specific provider helpers
@@ -142,6 +144,7 @@ export interface TextToVideoTaskHelper {
 export interface ImageToImageTaskHelper {
 	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<Blob>;
 	preparePayload(params: BodyParams): Record<string, unknown>;
+	preparePayloadAsync(args: ImageToImageArgs): Promise<RequestArgs>;
 }
 
 export interface ImageSegmentationTaskHelper {
@@ -245,6 +248,7 @@ export interface AudioToAudioTaskHelper {
 export interface AutomaticSpeechRecognitionTaskHelper {
 	getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<AutomaticSpeechRecognitionOutput>;
 	preparePayload(params: BodyParams): Record<string, unknown> | BodyInit;
+	preparePayloadAsync(args: AutomaticSpeechRecognitionArgs): Promise<RequestArgs>;
 }
 
 export interface AudioClassificationTaskHelper {
diff --git a/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts b/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts
index 57128f73c9..617f73dcc4 100644
--- a/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts
+++ b/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts
@@ -2,13 +2,9 @@ import type { AutomaticSpeechRecognitionInput, AutomaticSpeechRecognitionOutput
 import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
-import { FAL_AI_SUPPORTED_BLOB_TYPES } from "../../providers/fal-ai";
-import type { BaseArgs, Options, RequestArgs } from "../../types";
-import { base64FromBytes } from "../../utils/base64FromBytes";
-import { omit } from "../../utils/omit";
+import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
 import type { LegacyAudioInput } from "./utils";
-import { preparePayload } from "./utils";
 
 export type AutomaticSpeechRecognitionArgs = BaseArgs & (AutomaticSpeechRecognitionInput | LegacyAudioInput);
 /**
@@ -21,7 +17,7 @@ export async function automaticSpeechRecognition(
 ): Promise<AutomaticSpeechRecognitionOutput> {
 	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
 	const providerHelper = getProviderHelper(provider, "automatic-speech-recognition");
-	const payload = await buildPayload(args);
+	const payload = await providerHelper.preparePayloadAsync(args);
 	const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, providerHelper, {
 		...options,
 		task: "automatic-speech-recognition",
@@ -32,29 +28,3 @@ export async function automaticSpeechRecognition(
 	}
 	return providerHelper.getResponse(res);
 }
-
-async function buildPayload(args: AutomaticSpeechRecognitionArgs): Promise<RequestArgs> {
-	if (args.provider === "fal-ai") {
-		const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : undefined;
-		const contentType = blob?.type;
-		if (!contentType) {
-			throw new Error(
-				`Unable to determine the input's content-type. Make sure your are passing a Blob when using provider fal-ai.`
-			);
-		}
-		if (!FAL_AI_SUPPORTED_BLOB_TYPES.includes(contentType)) {
-			throw new Error(
-				`Provider fal-ai does not support blob type ${contentType} - supported content types are: ${FAL_AI_SUPPORTED_BLOB_TYPES.join(
-					", "
-				)}`
-			);
-		}
-		const base64audio = base64FromBytes(new Uint8Array(await blob.arrayBuffer()));
-		return {
-			...("data" in args ? omit(args, "data") : omit(args, "inputs")),
-			audio_url: `data:${contentType};base64,${base64audio}`,
-		};
-	} else {
-		return preparePayload(args);
-	}
-}
diff --git a/packages/inference/src/tasks/cv/imageToImage.ts b/packages/inference/src/tasks/cv/imageToImage.ts
index 1b766c7bf1..fb007473d9 100644
--- a/packages/inference/src/tasks/cv/imageToImage.ts
+++ b/packages/inference/src/tasks/cv/imageToImage.ts
@@ -1,8 +1,7 @@
 import type { ImageToImageInput } from "@huggingface/tasks";
 import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
-import type { BaseArgs, Options, RequestArgs } from "../../types";
-import { base64FromBytes } from "../../utils/base64FromBytes";
+import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
 
 export type ImageToImageArgs = BaseArgs & ImageToImageInput;
@@ -14,22 +13,8 @@ export type ImageToImageArgs = BaseArgs & ImageToImageInput;
 export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise<Blob> {
 	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
 	const providerHelper = getProviderHelper(provider, "image-to-image");
-	let reqArgs: RequestArgs;
-	if (!args.parameters) {
-		reqArgs = {
-			accessToken: args.accessToken,
-			model: args.model,
-			data: args.inputs,
-		};
-	} else {
-		reqArgs = {
-			...args,
-			inputs: base64FromBytes(
-				new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await args.inputs.arrayBuffer())
-			),
-		};
-	}
-	const { data: res } = await innerRequest<Blob>(reqArgs, providerHelper, {
+	const payload = await providerHelper.preparePayloadAsync(args);
+	const { data: res } = await innerRequest<Blob>(payload, providerHelper, {
 		...options,
 		task: "image-to-image",
 	});