diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts
index 72cc35bf62..4409927e47 100644
--- a/packages/inference/src/lib/getProviderHelper.ts
+++ b/packages/inference/src/lib/getProviderHelper.ts
@@ -143,6 +143,7 @@ export const PROVIDERS: Record
+		"automatic-speech-recognition": new Replicate.ReplicateAutomaticSpeechRecognitionTask(),
diff --git a/packages/inference/src/providers/replicate.ts b/packages/inference/src/providers/replicate.ts
--- a/packages/inference/src/providers/replicate.ts
+++ b/packages/inference/src/providers/replicate.ts
+export class ReplicateAutomaticSpeechRecognitionTask extends ReplicateTask implements AutomaticSpeechRecognitionTaskHelper {
+	override preparePayload(params: BodyParams): Record<string, unknown> {
+		return {
+			input: {
+				...omit(params.args, ["inputs", "parameters"]),
+				...(params.args.parameters as Record<string, unknown>),
+				audio: params.args.inputs, // This will be processed in preparePayloadAsync
+			},
+			version: params.model.includes(":") ? params.model.split(":")[1] : undefined,
+		};
+	}
+
+	async preparePayloadAsync(args: AutomaticSpeechRecognitionArgs): Promise<RequestArgs> {
+		const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : undefined;
+
+		if (!blob || !(blob instanceof Blob)) {
+			throw new Error("Audio input must be a Blob");
+		}
+
+		// Convert Blob to base64 data URL
+		const bytes = new Uint8Array(await blob.arrayBuffer());
+		const base64 = base64FromBytes(bytes);
+		const audioInput = `data:${blob.type || "audio/wav"};base64,${base64}`;
+
+		return {
+			...("data" in args ? omit(args, "data") : omit(args, "inputs")),
+			inputs: audioInput,
+		};
+	}
+
+	override async getResponse(response: ReplicateOutput): Promise<AutomaticSpeechRecognitionOutput> {
+		if (typeof response?.output === "string") return { text: response.output };
+		if (Array.isArray(response?.output) && typeof response.output[0] === "string") return { text: response.output[0] };
+
+		const out = response?.output as
+			| undefined
+			| {
+					transcription?: string;
+					translation?: string;
+					txt_file?: string;
+			  };
+		if (out && typeof out === "object") {
+			if (typeof out.transcription === "string") return { text: out.transcription };
+			if (typeof out.translation === "string") return { text: out.translation };
+			if (typeof out.txt_file === "string") {
+				const r = await fetch(out.txt_file);
+				return { text: await r.text() };
+			}
+		}
+		throw new InferenceClientProviderOutputError(
+			"Received malformed response from Replicate automatic-speech-recognition API"
+		);
+	}
+}
+
 export class ReplicateImageToImageTask extends ReplicateTask implements ImageToImageTaskHelper {
 	override preparePayload(params: BodyParams): Record<string, unknown> {
 		return {
diff --git a/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts b/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts
index a8ce6ebed6..c5a716d5bc 100644
--- a/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts
+++ b/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts
@@ -4,7 +4,6 @@ import { getProviderHelper } from "../../lib/getProviderHelper.js";
 import type { BaseArgs, Options } from "../../types.js";
 import { innerRequest } from "../../utils/request.js";
 import type { LegacyAudioInput } from "./utils.js";
-import { InferenceClientProviderOutputError } from "../../errors.js";
 
 export type AutomaticSpeechRecognitionArgs = BaseArgs & (AutomaticSpeechRecognitionInput | LegacyAudioInput);
 /**
@@ -22,9 +21,5 @@ export async function automaticSpeechRecognition(
 		...options,
 		task: "automatic-speech-recognition",
 	});
-	const isValidOutput = typeof res?.text === "string";
-	if (!isValidOutput) {
-		throw new InferenceClientProviderOutputError("Received malformed response from automatic-speech-recognition API");
-	}
 	return providerHelper.getResponse(res);
 }
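
For reference, here is a minimal usage sketch of the new code path. The model id, audio URL, and environment variable are illustrative assumptions, not part of this diff: the audio Blob is converted to a base64 data URL by `preparePayloadAsync`, mapped to Replicate's `input.audio` by `preparePayload`, and the provider response is normalized to `{ text }` by `getResponse`.

```ts
import { InferenceClient } from "@huggingface/inference";

// Assumption: HF_TOKEN holds a valid Hugging Face access token.
const client = new InferenceClient(process.env.HF_TOKEN);

// Assumption: illustrative audio source; any Blob works.
const audio = await (await fetch("https://example.com/sample.wav")).blob();

const { text } = await client.automaticSpeechRecognition({
	provider: "replicate",
	// Assumption: illustrative model id; an optional ":<version>" suffix is split off
	// by preparePayload and sent as the Replicate prediction `version`.
	model: "openai/whisper-large-v3",
	data: audio, // legacy `data` input; `inputs` is also accepted by preparePayloadAsync
});

console.log(text);
```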