diff --git a/.changeset/hungry-schools-think.md b/.changeset/hungry-schools-think.md new file mode 100644 index 000000000..8f0b1b30a --- /dev/null +++ b/.changeset/hungry-schools-think.md @@ -0,0 +1,5 @@ +--- +'@livekit/agents': patch +--- + +Fix improper resource cleanup inside AgentActivity by not closing global STT / TTS / VAD components diff --git a/agents/src/inference/stt.ts b/agents/src/inference/stt.ts index 0c9236cd1..8b0f26b0f 100644 --- a/agents/src/inference/stt.ts +++ b/agents/src/inference/stt.ts @@ -422,8 +422,7 @@ export class SpeechStream extends BaseSpeechStream { try { ws = await this.stt.connectWs(this.connOptions.timeoutMs); - // Wrap tasks for proper cancellation support using Task signals - const controller = new AbortController(); + const controller = this.abortController; // Use base class abortController for proper cancellation const sendTask = Task.from(({ signal }) => send(ws!, signal), controller); const wsListenerTask = Task.from(({ signal }) => createWsListener(ws!, signal), controller); const recvTask = Task.from(({ signal }) => recv(signal), controller); diff --git a/agents/src/ipc/job_proc_lazy_main.ts b/agents/src/ipc/job_proc_lazy_main.ts index 7c75b4d09..e39b8c398 100644 --- a/agents/src/ipc/job_proc_lazy_main.ts +++ b/agents/src/ipc/job_proc_lazy_main.ts @@ -189,7 +189,7 @@ const startJob = ( let logger = log().child({ pid: proc.pid }); process.on('unhandledRejection', (reason) => { - logger.error(reason); + logger.debug({ error: reason }, 'Unhandled promise rejection'); }); logger.debug('initializing job runner'); diff --git a/agents/src/stt/stream_adapter.ts b/agents/src/stt/stream_adapter.ts index 17f29a510..50c976084 100644 --- a/agents/src/stt/stream_adapter.ts +++ b/agents/src/stt/stream_adapter.ts @@ -23,10 +23,14 @@ export class StreamAdapter extends STT { this.#stt.on('metrics_collected', (metrics) => { this.emit('metrics_collected', metrics); }); + + this.#stt.on('error', (error) => { + this.emit('error', error); + }); 
} - _recognize(frame: AudioFrame): Promise { - return this.#stt.recognize(frame); + _recognize(frame: AudioFrame, abortSignal?: AbortSignal): Promise { + return this.#stt.recognize(frame, abortSignal); } stream(options?: { connOptions?: APIConnectOptions }): StreamAdapterWrapper { @@ -46,6 +50,11 @@ export class StreamAdapterWrapper extends SpeechStream { this.label = `stt.StreamAdapterWrapper<${this.#stt.label}>`; } + close() { + super.close(); + this.#vadStream.close(); + } + async monitorMetrics() { return; // do nothing } @@ -72,7 +81,7 @@ export class StreamAdapterWrapper extends SpeechStream { this.output.put({ type: SpeechEventType.END_OF_SPEECH }); try { - const event = await this.#stt.recognize(ev.frames); + const event = await this.#stt.recognize(ev.frames, this.abortSignal); if (!event.alternatives![0].text) { continue; } @@ -93,6 +102,6 @@ export class StreamAdapterWrapper extends SpeechStream { } }; - Promise.all([forwardInput(), recognize()]); + await Promise.all([forwardInput(), recognize()]); } } diff --git a/agents/src/stt/stt.ts b/agents/src/stt/stt.ts index 039aa4f69..bfc20bf59 100644 --- a/agents/src/stt/stt.ts +++ b/agents/src/stt/stt.ts @@ -113,9 +113,9 @@ export abstract class STT extends (EventEmitter as new () => TypedEmitter { + async recognize(frame: AudioBuffer, abortSignal?: AbortSignal): Promise { const startTime = process.hrtime.bigint(); - const event = await this._recognize(frame); + const event = await this._recognize(frame, abortSignal); const durationMs = Number((process.hrtime.bigint() - startTime) / BigInt(1000000)); this.emit('metrics_collected', { type: 'stt_metrics', @@ -128,7 +128,11 @@ export abstract class STT extends (EventEmitter as new () => TypedEmitter; + + protected abstract _recognize( + frame: AudioBuffer, + abortSignal?: AbortSignal, + ): Promise; /** * Returns a {@link SpeechStream} that can be used to push audio frames and receive @@ -173,6 +177,8 @@ export abstract class SpeechStream implements 
AsyncIterableIterator private logger = log(); private _connOptions: APIConnectOptions; + protected abortController = new AbortController(); + constructor( stt: STT, sampleRate?: number, @@ -290,6 +296,10 @@ export abstract class SpeechStream implements AsyncIterableIterator protected abstract run(): Promise; + protected get abortSignal(): AbortSignal { + return this.abortController.signal; + } + updateInputStream(audioStream: ReadableStream) { this.deferredInputStream.setSource(audioStream); } @@ -354,6 +364,7 @@ export abstract class SpeechStream implements AsyncIterableIterator if (!this.input.closed) this.input.close(); if (!this.queue.closed) this.queue.close(); if (!this.output.closed) this.output.close(); + if (!this.abortController.signal.aborted) this.abortController.abort(); this.closed = true; } diff --git a/agents/src/voice/agent.ts b/agents/src/voice/agent.ts index 10ee8a490..95bfea2a8 100644 --- a/agents/src/voice/agent.ts +++ b/agents/src/voice/agent.ts @@ -260,28 +260,41 @@ export class Agent { let wrapped_stt = activity.stt; if (!wrapped_stt.capabilities.streaming) { - if (!agent.vad) { + const vad = agent.vad || activity.vad; + if (!vad) { throw new Error( 'STT does not support streaming, add a VAD to the AgentTask/VoiceAgent to enable streaming', ); } - wrapped_stt = new STTStreamAdapter(wrapped_stt, agent.vad); + wrapped_stt = new STTStreamAdapter(wrapped_stt, vad); } const connOptions = activity.agentSession.connOptions.sttConnOptions; const stream = wrapped_stt.stream({ connOptions }); stream.updateInputStream(audio); + let cleaned = false; + const cleanup = () => { + if (cleaned) return; + cleaned = true; + stream.detachInputStream(); + stream.close(); + }; + return new ReadableStream({ async start(controller) { - for await (const event of stream) { - controller.enqueue(event); + try { + for await (const event of stream) { + controller.enqueue(event); + } + controller.close(); + } finally { + // Always clean up the STT stream, whether it ends 
naturally or is cancelled + cleanup(); } - controller.close(); }, cancel() { - stream.detachInputStream(); - stream.close(); + cleanup(); }, }); }, @@ -314,15 +327,27 @@ export class Agent { connOptions, parallelToolCalls: true, }); + + let cleaned = false; + const cleanup = () => { + if (cleaned) return; + cleaned = true; + stream.close(); + }; + return new ReadableStream({ async start(controller) { - for await (const chunk of stream) { - controller.enqueue(chunk); + try { + for await (const chunk of stream) { + controller.enqueue(chunk); + } + controller.close(); + } finally { + cleanup(); } - controller.close(); }, cancel() { - stream.close(); + cleanup(); }, }); }, @@ -347,18 +372,29 @@ export class Agent { const stream = wrapped_tts.stream({ connOptions }); stream.updateInputStream(text); + let cleaned = false; + const cleanup = () => { + if (cleaned) return; + cleaned = true; + stream.close(); + }; + return new ReadableStream({ async start(controller) { - for await (const chunk of stream) { - if (chunk === SynthesizeStream.END_OF_STREAM) { - break; + try { + for await (const chunk of stream) { + if (chunk === SynthesizeStream.END_OF_STREAM) { + break; + } + controller.enqueue(chunk.frame); } - controller.enqueue(chunk.frame); + controller.close(); + } finally { + cleanup(); } - controller.close(); }, cancel() { - stream.close(); + cleanup(); }, }); }, diff --git a/agents/src/voice/agent_activity.ts b/agents/src/voice/agent_activity.ts index 98e377b07..a1e3bf1d6 100644 --- a/agents/src/voice/agent_activity.ts +++ b/agents/src/voice/agent_activity.ts @@ -2259,15 +2259,12 @@ export class AgentActivity implements RecognitionHooks { } if (this.stt instanceof STT) { this.stt.off('metrics_collected', this.onMetricsCollected); - await this.stt.close(); } if (this.tts instanceof TTS) { this.tts.off('metrics_collected', this.onMetricsCollected); - await this.tts.close(); } if (this.vad instanceof VAD) { this.vad.off('metrics_collected', this.onMetricsCollected); - 
await this.vad.close(); } this.detachAudioInput(); diff --git a/agents/src/voice/agent_session.ts b/agents/src/voice/agent_session.ts index b88f6db68..312d4b049 100644 --- a/agents/src/voice/agent_session.ts +++ b/agents/src/voice/agent_session.ts @@ -328,15 +328,6 @@ export class AgentSession< ); } - this.logger.info( - { - input: this.input.audio, - output: this.output.audio, - enableRecording: this._enableRecording, - }, - 'Recording audio input and output', - ); - if (this.input.audio && this.output.audio && this._enableRecording) { this._recorderIO = new RecorderIO({ agentSession: this }); this.input.audio = this._recorderIO.recordInput(this.input.audio); diff --git a/agents/src/worker.ts b/agents/src/worker.ts index a8a64b826..044cf7d0d 100644 --- a/agents/src/worker.ts +++ b/agents/src/worker.ts @@ -729,7 +729,7 @@ export class AgentServer { const req = new JobRequest(msg.job!, onReject, onAccept); this.#logger - .child({ job: msg.job, resuming: msg.resuming, agentName: this.#opts.agentName }) + .child({ jobId: msg.job?.id, resuming: msg.resuming, agentName: this.#opts.agentName }) .info('received job request'); const jobRequestTask = async () => { diff --git a/examples/src/basic_tool_call_agent.ts b/examples/src/basic_tool_call_agent.ts index 992c61379..3c3dfd5b3 100644 --- a/examples/src/basic_tool_call_agent.ts +++ b/examples/src/basic_tool_call_agent.ts @@ -4,7 +4,7 @@ import { type JobContext, type JobProcess, - WorkerOptions, + ServerOptions, cli, defineAgent, llm, @@ -12,6 +12,7 @@ import { } from '@livekit/agents'; import * as livekit from '@livekit/agents-plugin-livekit'; import * as silero from '@livekit/agents-plugin-silero'; +import { BackgroundVoiceCancellation } from '@livekit/noise-cancellation-node'; import { fileURLToPath } from 'node:url'; import { z } from 'zod'; @@ -148,8 +149,11 @@ export default defineAgent({ await session.start({ agent: routerAgent, room: ctx.room, + inputOptions: { + noiseCancellation: BackgroundVoiceCancellation(), + 
}, }); }, }); -cli.runApp(new WorkerOptions({ agent: fileURLToPath(import.meta.url) })); +cli.runApp(new ServerOptions({ agent: fileURLToPath(import.meta.url) })); diff --git a/plugins/deepgram/src/stt.ts b/plugins/deepgram/src/stt.ts index 640f29d0a..bdda68818 100644 --- a/plugins/deepgram/src/stt.ts +++ b/plugins/deepgram/src/stt.ts @@ -2,6 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 import { + type APIConnectOptions, type AudioBuffer, AudioByteStream, AudioEnergyFilter, @@ -115,8 +116,8 @@ export class STT extends stt.STT { this.#opts = { ...this.#opts, ...opts }; } - stream(): SpeechStream { - return new SpeechStream(this, this.#opts, this.abortController); + stream(options?: { connOptions?: APIConnectOptions }): SpeechStream { + return new SpeechStream(this, this.#opts, options?.connOptions); } async close() { @@ -134,12 +135,8 @@ export class SpeechStream extends stt.SpeechStream { #audioDurationCollector: PeriodicCollector; label = 'deepgram.SpeechStream'; - constructor( - stt: STT, - opts: STTOptions, - private abortController: AbortController, - ) { - super(stt, opts.sampleRate); + constructor(stt: STT, opts: STTOptions, connOptions?: APIConnectOptions) { + super(stt, opts.sampleRate, connOptions); this.#opts = opts; this.closed = false; this.#audioEnergyFilter = new AudioEnergyFilter(); @@ -263,12 +260,13 @@ export class SpeechStream extends stt.SpeechStream { samples100Ms, ); + // waitForAbort internally sets up an abort listener on the abort signal + // we need to put it outside loop to avoid constant re-registration of the listener + const abortPromise = waitForAbort(this.abortSignal); + try { while (!this.closed) { - const result = await Promise.race([ - this.input.next(), - waitForAbort(this.abortController.signal), - ]); + const result = await Promise.race([this.input.next(), abortPromise]); if (result === undefined) return; // aborted if (result.done) { @@ -306,6 +304,16 @@ export class SpeechStream extends stt.SpeechStream { }; const 
listenTask = Task.from(async (controller) => { + const putMessage = (message: stt.SpeechEvent) => { + if (!this.queue.closed) { + try { + this.queue.put(message); + } catch (e) { + // ignore + } + } + }; + const listenMessage = new Promise((resolve, reject) => { ws.on('message', (msg) => { try { @@ -318,13 +326,7 @@ export class SpeechStream extends stt.SpeechStream { // It's also possible we receive a transcript without a SpeechStarted event. if (this.#speaking) return; this.#speaking = true; - if (!this.queue.closed) { - try { - this.queue.put({ type: stt.SpeechEventType.START_OF_SPEECH }); - } catch (e) { - // ignore - } - } + putMessage({ type: stt.SpeechEventType.START_OF_SPEECH }); break; } // see this page: @@ -345,18 +347,18 @@ export class SpeechStream extends stt.SpeechStream { if (alternatives[0] && alternatives[0].text) { if (!this.#speaking) { this.#speaking = true; - this.queue.put({ + putMessage({ type: stt.SpeechEventType.START_OF_SPEECH, }); } if (isFinal) { - this.queue.put({ + putMessage({ type: stt.SpeechEventType.FINAL_TRANSCRIPT, alternatives: [alternatives[0], ...alternatives.slice(1)], }); } else { - this.queue.put({ + putMessage({ type: stt.SpeechEventType.INTERIM_TRANSCRIPT, alternatives: [alternatives[0], ...alternatives.slice(1)], }); @@ -368,7 +370,7 @@ export class SpeechStream extends stt.SpeechStream { // a non-empty transcript (deepgram doesn't have a SpeechEnded event) if (isEndpoint && this.#speaking) { this.#speaking = false; - this.queue.put({ type: stt.SpeechEventType.END_OF_SPEECH }); + putMessage({ type: stt.SpeechEventType.END_OF_SPEECH }); } break; diff --git a/plugins/openai/src/stt.ts b/plugins/openai/src/stt.ts index 5a4b3c5bb..6f93a0770 100644 --- a/plugins/openai/src/stt.ts +++ b/plugins/openai/src/stt.ts @@ -27,7 +27,6 @@ export class STT extends stt.STT { #opts: STTOptions; #client: OpenAI; label = 'openai.STT'; - private abortController = new AbortController(); /** * Create a new instance of OpenAI STT. 
@@ -142,10 +141,11 @@ export class STT extends stt.STT { return Buffer.concat([header, Buffer.from(frame.data.buffer)]); } - async _recognize(buffer: AudioBuffer, language?: string): Promise { - const config = this.#sanitizeOptions(language); + async _recognize(buffer: AudioBuffer, abortSignal?: AbortSignal): Promise { + const config = this.#sanitizeOptions(); buffer = mergeFrames(buffer); - const file = new File([this.#createWav(buffer)], 'audio.wav', { type: 'audio/wav' }); + const wavBuffer = this.#createWav(buffer); + const file = new File([new Uint8Array(wavBuffer)], 'audio.wav', { type: 'audio/wav' }); const resp = await this.#client.audio.transcriptions.create( { @@ -156,7 +156,7 @@ export class STT extends stt.STT { response_format: 'json', }, { - signal: this.abortController.signal, + signal: abortSignal, }, ); @@ -165,7 +165,7 @@ export class STT extends stt.STT { alternatives: [ { text: resp.text || '', - language: language || '', + language: config.language || '', startTime: 0, endTime: 0, confidence: 0, @@ -178,8 +178,4 @@ export class STT extends stt.STT { stream(): stt.SpeechStream { throw new Error('Streaming is not supported on OpenAI STT'); } - - async close(): Promise { - this.abortController.abort(); - } }