diff --git a/.changeset/clever-games-float.md b/.changeset/clever-games-float.md new file mode 100644 index 000000000..49d27159b --- /dev/null +++ b/.changeset/clever-games-float.md @@ -0,0 +1,6 @@ +--- +'@livekit/agents-plugin-google': patch +'@livekit/agents-plugin-rime': patch +--- + +Implemented close() functionality in gemini and rime tts. diff --git a/plugins/google/src/beta/gemini_tts.ts b/plugins/google/src/beta/gemini_tts.ts index 0d46914d7..100018bc2 100644 --- a/plugins/google/src/beta/gemini_tts.ts +++ b/plugins/google/src/beta/gemini_tts.ts @@ -65,6 +65,7 @@ export interface TTSOptions { export class TTS extends tts.TTS { #opts: TTSOptions; #client: GoogleGenAI; + #abortController = new AbortController(); label = 'google.gemini.TTS'; /** @@ -164,6 +165,14 @@ export class TTS extends tts.TTS { get client(): GoogleGenAI { return this.#client; } + + get signal(): AbortSignal { + return this.#abortController.signal; + } + + async close(): Promise { + this.#abortController.abort(); + } } export class ChunkedStream extends tts.ChunkedStream { @@ -205,7 +214,10 @@ export class ChunkedStream extends tts.ChunkedStream { const responseStream = await this.#tts.client.models.generateContentStream({ model: this.#tts.opts.model, contents, - config, + config: { + ...config, + abortSignal: this.#tts.signal, + }, }); try { @@ -213,6 +225,9 @@ export class ChunkedStream extends tts.ChunkedStream { await this.#processResponse(response, bstream, requestId); } } catch (error: unknown) { + if (error instanceof Error && error.name === 'AbortError') { + return; + } if (isAPIError(error)) throw error; const err = error as { diff --git a/plugins/rime/src/tts.ts b/plugins/rime/src/tts.ts index 190cefdc2..4d584fc7f 100644 --- a/plugins/rime/src/tts.ts +++ b/plugins/rime/src/tts.ts @@ -62,6 +62,7 @@ const defaultTTSOptions: TTSOptions = { export class TTS extends tts.TTS { private opts: TTSOptions; + private abortController = new AbortController(); label = 'rime.TTS'; /** @@ -102,18 +103,23 @@ export class TTS extends tts.TTS { * @returns A chunked stream of synthesized audio */ synthesize(text: string): ChunkedStream { - return new ChunkedStream(this, text, this.opts); + return new ChunkedStream(this, text, this.opts, this.abortController.signal); } stream(): tts.SynthesizeStream { throw new Error('Streaming is not supported on RimeTTS'); } + + async close(): Promise { + this.abortController.abort(); + } } export class ChunkedStream extends tts.ChunkedStream { label = 'rime-tts.ChunkedStream'; private opts: TTSOptions; private text: string; + private signal: AbortSignal; /** * Create a new ChunkedStream instance. @@ -121,52 +127,63 @@ export class ChunkedStream extends tts.ChunkedStream { * @param tts - The parent TTS instance * @param text - Text to synthesize * @param opts - TTS configuration options + * @param signal - AbortSignal for cancellation */ - constructor(tts: TTS, text: string, opts: TTSOptions) { + constructor(tts: TTS, text: string, opts: TTSOptions, signal: AbortSignal) { super(text, tts); this.text = text; this.opts = opts; + this.signal = signal; } protected async run() { - const requestId = shortuuid(); - const response = await fetch(`${this.opts.baseURL}`, { - method: 'POST', - headers: { - Accept: 'audio/pcm', - Authorization: `Bearer ${this.opts.apiKey}`, - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ - ...Object.fromEntries( - Object.entries(this.opts).filter(([k]) => !['apiKey', 'baseURL'].includes(k)), - ), - text: this.text, - }), - }); - - if (!response.ok) { - throw new Error(`Rime AI TTS request failed: ${response.status} ${response.statusText}`); - } + try { + const requestId = shortuuid(); + const response = await fetch(`${this.opts.baseURL}`, { + method: 'POST', + headers: { + Accept: 'audio/pcm', + Authorization: `Bearer ${this.opts.apiKey}`, + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + ...Object.fromEntries( + Object.entries(this.opts).filter(([k]) => !['apiKey', 'baseURL'].includes(k)), + ), + text: this.text, + }), + signal: this.signal, + }); + + if (!response.ok) { + throw new Error(`Rime AI TTS request failed: ${response.status} ${response.statusText}`); + } - const buffer = await response.arrayBuffer(); - const sampleRate = getSampleRate(this.opts); - const audioByteStream = new AudioByteStream(sampleRate, RIME_TTS_CHANNELS); - const frames = audioByteStream.write(buffer); - let lastFrame: AudioFrame | undefined; - const sendLastFrame = (segmentId: string, final: boolean) => { - if (lastFrame) { - this.queue.put({ requestId, segmentId, frame: lastFrame, final }); - lastFrame = undefined; + const buffer = await response.arrayBuffer(); + const sampleRate = getSampleRate(this.opts); + const audioByteStream = new AudioByteStream(sampleRate, RIME_TTS_CHANNELS); + const frames = audioByteStream.write(buffer); + let lastFrame: AudioFrame | undefined; + const sendLastFrame = (segmentId: string, final: boolean) => { + if (lastFrame) { + this.queue.put({ requestId, segmentId, frame: lastFrame, final }); + lastFrame = undefined; + } + }; + + for (const frame of frames) { + sendLastFrame(requestId, false); + lastFrame = frame; } - }; + sendLastFrame(requestId, true); - for (const frame of frames) { - sendLastFrame(requestId, false); - lastFrame = frame; + this.queue.close(); + } catch (error) { + if (error instanceof Error && error.name === 'AbortError') { + this.queue.close(); + return; + } + throw error; } - sendLastFrame(requestId, true); - - this.queue.close(); } }