diff --git a/.changeset/sour-mugs-lay.md b/.changeset/sour-mugs-lay.md
new file mode 100644
index 000000000..45896253a
--- /dev/null
+++ b/.changeset/sour-mugs-lay.md
@@ -0,0 +1,7 @@
+---
+'@livekit/agents-plugin-google': patch
+'@livekit/agents-plugin-openai': patch
+'@livekit/agents': patch
+---
+
+Support openai half-duplex mode (audio in -> text out -> custom TTS model)
diff --git a/agents/src/llm/realtime.ts b/agents/src/llm/realtime.ts
index 12cafb7b6..2a5860e2e 100644
--- a/agents/src/llm/realtime.ts
+++ b/agents/src/llm/realtime.ts
@@ -19,6 +19,7 @@ export interface MessageGeneration {
   messageId: string;
   textStream: ReadableStream<string>;
   audioStream: ReadableStream<AudioFrame>;
+  modalities?: Promise<('text' | 'audio')[]>;
 }
 
 export interface GenerationCreatedEvent {
@@ -40,6 +41,7 @@ export interface RealtimeCapabilities {
   turnDetection: boolean;
   userTranscription: boolean;
   autoToolReplyGeneration: boolean;
+  audioOutput: boolean;
 }
 
 export interface InputTranscriptionCompleted {
@@ -121,7 +123,12 @@ export abstract class RealtimeSession extends EventEmitter {
   /**
    * Truncate the message at the given audio end time
   */
-  abstract truncate(options: { messageId: string; audioEndMs: number }): Promise<void>;
+  abstract truncate(options: {
+    messageId: string;
+    audioEndMs: number;
+    modalities?: ('text' | 'audio')[];
+    audioTranscript?: string;
+  }): Promise<void>;
 
   async close(): Promise<void> {
     this._mainTask.cancel();
diff --git a/agents/src/voice/agent_activity.ts b/agents/src/voice/agent_activity.ts
index 137b38dc7..0ee5bedd7 100644
--- a/agents/src/voice/agent_activity.ts
+++ b/agents/src/voice/agent_activity.ts
@@ -235,6 +235,14 @@ export class AgentActivity implements RecognitionHooks {
       } catch (error) {
         this.logger.error(error, 'failed to update the tools');
       }
+
+      if (!this.llm.capabilities.audioOutput && !this.tts && this.agentSession.output.audio) {
+        this.logger.error(
+          'audio output is enabled but RealtimeModel has no audio modality ' +
+            'and no TTS is set. Either enable audio modality in the RealtimeModel ' +
+            'or set a TTS model.',
+        );
+      }
     } else if (this.llm instanceof LLM) {
       try {
         updateInstructions({
@@ -1612,7 +1620,7 @@ export class AgentActivity implements RecognitionHooks {
 
     const readMessages = async (
       abortController: AbortController,
-      outputs: Array<[string, _TextOut | null, _AudioOut | null]>,
+      outputs: Array<[string, _TextOut | null, _AudioOut | null, ('text' | 'audio')[] | undefined]>,
     ) => {
       replyAbortController.signal.addEventListener('abort', () => abortController.abort(), {
         once: true,
@@ -1627,7 +1635,25 @@
          );
          break;
        }
-        const trNodeResult = await this.agent.transcriptionNode(msg.textStream, modelSettings);
+
+        const msgModalities = msg.modalities ? await msg.modalities : undefined;
+        let ttsTextInput: ReadableStream<string> | null = null;
+        let trTextInput: ReadableStream<string>;
+
+        if (msgModalities && !msgModalities.includes('audio') && this.tts) {
+          if (this.llm instanceof RealtimeModel && this.llm.capabilities.audioOutput) {
+            this.logger.warn(
+              'text response received from realtime API, falling back to use a TTS model.',
+            );
+          }
+          const [_ttsTextInput, _trTextInput] = msg.textStream.tee();
+          ttsTextInput = _ttsTextInput;
+          trTextInput = _trTextInput;
+        } else {
+          trTextInput = msg.textStream;
+        }
+
+        const trNodeResult = await this.agent.transcriptionNode(trTextInput, modelSettings);
        let textOut: _TextOut | null = null;
        if (trNodeResult) {
          const [textForwardTask, _textOut] = performTextForwarding(
@@ -1638,30 +1664,51 @@
          forwardTasks.push(textForwardTask);
          textOut = _textOut;
        }
+
        let audioOut: _AudioOut | null = null;
        if (audioOutput) {
-          const realtimeAudio = await this.agent.realtimeAudioOutputNode(
-            msg.audioStream,
-            modelSettings,
-          );
-          if (realtimeAudio) {
+          let realtimeAudioResult: ReadableStream<AudioFrame> | null = null;
+
+          if (ttsTextInput) {
+            const [ttsTask, ttsStream] = performTTSInference(
+              (...args) => this.agent.ttsNode(...args),
+              ttsTextInput,
+              modelSettings,
+              abortController,
+            );
+            tasks.push(ttsTask);
+            realtimeAudioResult = ttsStream;
+          } else if (msgModalities && msgModalities.includes('audio')) {
+            realtimeAudioResult = await this.agent.realtimeAudioOutputNode(
+              msg.audioStream,
+              modelSettings,
+            );
+          } else if (this.llm instanceof RealtimeModel && this.llm.capabilities.audioOutput) {
+            this.logger.error(
+              'Text message received from Realtime API with audio modality. ' +
+                'This usually happens when text chat context is synced to the API. ' +
+                'Try to add a TTS model as fallback or use text modality with TTS instead.',
+            );
+          } else {
+            this.logger.warn(
+              'audio output is enabled but neither tts nor realtime audio is available',
+            );
+          }
+
+          if (realtimeAudioResult) {
            const [forwardTask, _audioOut] = performAudioForwarding(
-              realtimeAudio,
+              realtimeAudioResult,
              audioOutput,
              abortController,
            );
            forwardTasks.push(forwardTask);
            audioOut = _audioOut;
            audioOut.firstFrameFut.await.finally(onFirstFrame);
-          } else {
-            this.logger.warn(
-              'audio output is enabled but neither tts nor realtime audio is available',
-            );
          }
        } else if (textOut) {
          textOut.firstTextFut.await.finally(onFirstFrame);
        }
-        outputs.push([msg.messageId, textOut, audioOut]);
+        outputs.push([msg.messageId, textOut, audioOut, msgModalities]);
      }
      await waitFor(forwardTasks);
    } catch (error) {
@@ -1671,7 +1718,9 @@
      }
    };
 
-    const messageOutputs: Array<[string, _TextOut | null, _AudioOut | null]> = [];
+    const messageOutputs: Array<
+      [string, _TextOut | null, _AudioOut | null, ('text' | 'audio')[] | undefined]
+    > = [];
    const tasks = [
      Task.from(
        (controller) => readMessages(controller, messageOutputs),
@@ -1750,7 +1799,7 @@
 
    if (messageOutputs.length > 0) {
      // there should be only one message
-      const [msgId, textOut, audioOut] = messageOutputs[0]!;
+      const [msgId, textOut, audioOut, msgModalities] = messageOutputs[0]!;
      let forwardedText = textOut?.text || '';
 
      if (audioOutput) {
@@ -1775,6 +1824,8 @@
        this.realtimeSession.truncate({
          messageId: msgId,
          audioEndMs: Math.floor(playbackPosition),
+          modalities: msgModalities,
+          audioTranscript: forwardedText,
        });
      }
 
@@ -1805,7 +1856,7 @@
 
    if (messageOutputs.length > 0) {
      // there should be only one message
-      const [msgId, textOut, _] = messageOutputs[0]!;
+      const [msgId, textOut, _, __] = messageOutputs[0]!;
      const message = ChatMessage.create({
        role: 'assistant',
        content: textOut?.text || '',
diff --git a/examples/src/realtime_with_tts.ts b/examples/src/realtime_with_tts.ts
new file mode 100644
index 000000000..c8207c8c5
--- /dev/null
+++ b/examples/src/realtime_with_tts.ts
@@ -0,0 +1,77 @@
+// SPDX-FileCopyrightText: 2025 LiveKit, Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+import {
+  type JobContext,
+  type JobProcess,
+  ServerOptions,
+  cli,
+  defineAgent,
+  llm,
+  log,
+  voice,
+} from '@livekit/agents';
+import * as elevenlabs from '@livekit/agents-plugin-elevenlabs';
+import * as openai from '@livekit/agents-plugin-openai';
+import * as silero from '@livekit/agents-plugin-silero';
+import { BackgroundVoiceCancellation } from '@livekit/noise-cancellation-node';
+import { fileURLToPath } from 'node:url';
+import { z } from 'zod';
+
+export default defineAgent({
+  prewarm: async (proc: JobProcess) => {
+    proc.userData.vad = await silero.VAD.load();
+  },
+  entry: async (ctx: JobContext) => {
+    const logger = log();
+
+    const getWeather = llm.tool({
+      description: 'Called when the user asks about the weather.',
+      parameters: z.object({
+        location: z.string().describe('The location to get the weather for'),
+      }),
+      execute: async ({ location }) => {
+        logger.info(`getting weather for ${location}`);
+        return `The weather in ${location} is sunny, and the temperature is 20 degrees Celsius.`;
+      },
+    });
+
+    const agent = new voice.Agent({
+      instructions: 'You are a helpful assistant. Always speak in English.',
+      tools: {
+        getWeather,
+      },
+    });
+
+    const session = new voice.AgentSession({
+      // Use RealtimeModel with text-only modality + separate TTS
+      llm: new openai.realtime.RealtimeModel({
+        modalities: ['text'],
+      }),
+      tts: new elevenlabs.TTS(),
+      voiceOptions: {
+        maxToolSteps: 5,
+      },
+    });
+
+    await session.start({
+      agent,
+      room: ctx.room,
+      inputOptions: {
+        noiseCancellation: BackgroundVoiceCancellation(),
+      },
+      outputOptions: {
+        transcriptionEnabled: true,
+        audioEnabled: true, // You can also disable audio output to use text modality only
+      },
+    });
+
+    session.say('Hello, how can I help you today?');
+
+    session.on(voice.AgentSessionEventTypes.MetricsCollected, (ev) => {
+      logger.debug('metrics_collected', ev);
+    });
+  },
+});
+
+cli.runApp(new ServerOptions({ agent: fileURLToPath(import.meta.url) }));
diff --git a/plugins/google/src/beta/realtime/realtime_api.ts b/plugins/google/src/beta/realtime/realtime_api.ts
index e1857c0da..a1d888bd0 100644
--- a/plugins/google/src/beta/realtime/realtime_api.ts
+++ b/plugins/google/src/beta/realtime/realtime_api.ts
@@ -290,6 +290,7 @@ export class RealtimeModel extends llm.RealtimeModel {
      turnDetection: serverTurnDetection,
      userTranscription: inputAudioTranscription !== null,
      autoToolReplyGeneration: true,
+      audioOutput: options.modalities?.includes(Modality.AUDIO) ?? true,
    });
 
    // Environment variable fallbacks
@@ -600,7 +601,7 @@ export class RealtimeSession extends llm.RealtimeSession {
    this.hasReceivedAudioInput = true;
 
    for (const f of this.resampleAudio(frame)) {
-      for (const nf of this.bstream.write(f.data.buffer)) {
+      for (const nf of this.bstream.write(f.data.buffer as ArrayBuffer)) {
        const realtimeInput: types.LiveClientRealtimeInput = {
          mediaChunks: [
            {
diff --git a/plugins/openai/src/realtime/api_proto.ts b/plugins/openai/src/realtime/api_proto.ts
index 75e66c3e6..4bc026f77 100644
--- a/plugins/openai/src/realtime/api_proto.ts
+++ b/plugins/openai/src/realtime/api_proto.ts
@@ -190,7 +190,7 @@ export interface SessionResource {
   id: string;
   object: 'realtime.session';
   model: string;
-  modalities: ['text', 'audio'] | ['text']; // default: ["text", "audio"]
+  modalities: Modality[]; // default: ["text", "audio"]
   instructions: string;
   voice: Voice; // default: "alloy"
   input_audio_format: AudioFormat; // default: "pcm16"
@@ -267,7 +267,7 @@ export interface SessionUpdateEvent extends BaseClientEvent {
   type: 'session.update';
   session: Partial<{
     model: Model;
-    modalities: ['text', 'audio'] | ['text'];
+    modalities: Modality[];
     instructions: string;
     voice: Voice;
     input_audio_format: AudioFormat;
@@ -350,7 +350,7 @@ export interface ConversationItemDeleteEvent extends BaseClientEvent {
 export interface ResponseCreateEvent extends BaseClientEvent {
   type: 'response.create';
   response?: Partial<{
-    modalities: ['text', 'audio'] | ['text'];
+    modalities: Modality[];
     instructions: string;
     voice: Voice;
     output_audio_format: AudioFormat;
@@ -511,6 +511,7 @@ export interface ResponseContentPartDoneEvent extends BaseServerEvent {
 export interface ResponseTextDeltaEvent extends BaseServerEvent {
   type: 'response.text.delta';
   response_id: string;
+  item_id: string;
   output_index: number;
   content_index: number;
   delta: string;
@@ -519,6 +520,7 @@
 export interface ResponseTextDoneEvent extends BaseServerEvent {
   type: 'response.text.done';
   response_id: string;
+  item_id: string;
   output_index: number;
   content_index: number;
   text: string;
diff --git a/plugins/openai/src/realtime/realtime_model.ts b/plugins/openai/src/realtime/realtime_model.ts
index 05777d6ae..a67d3bf58 100644
--- a/plugins/openai/src/realtime/realtime_model.ts
+++ b/plugins/openai/src/realtime/realtime_model.ts
@@ -34,6 +34,8 @@ const BASE_URL = 'https://api.openai.com/v1';
 
 const MOCK_AUDIO_ID_PREFIX = 'lk_mock_audio_item_';
 
+type Modality = 'text' | 'audio';
+
 interface RealtimeOptions {
   model: api_proto.Model;
   voice: api_proto.Voice;
@@ -54,6 +56,7 @@
   maxSessionDuration: number;
   // reset the connection after this many seconds if provided
   connOptions: APIConnectOptions;
+  modalities: Modality[];
 }
 
 interface MessageGeneration {
@@ -61,6 +64,7 @@
   messageId: string;
   textChannel: stream.StreamChannel<string>;
   audioChannel: stream.StreamChannel<AudioFrame>;
   audioTranscript: string;
+  modalities: Future<('text' | 'audio')[]>;
 }
 
 interface ResponseGeneration {
@@ -125,6 +129,7 @@ const DEFAULT_REALTIME_MODEL_OPTIONS = {
   maxResponseOutputTokens: DEFAULT_MAX_RESPONSE_OUTPUT_TOKENS,
   maxSessionDuration: DEFAULT_MAX_SESSION_DURATION,
   connOptions: DEFAULT_API_CONNECT_OPTIONS,
+  modalities: ['text', 'audio'] as Modality[],
 };
 
 export class RealtimeModel extends llm.RealtimeModel {
   sampleRate = api_proto.SAMPLE_RATE;
@@ -142,6 +147,7 @@
      temperature?: number;
      toolChoice?: llm.ToolChoice;
      baseURL?: string;
+      modalities?: Modality[];
      inputAudioTranscription?: api_proto.InputAudioTranscription | null;
      // TODO(shubhra): add inputAudioNoiseReduction
      turnDetection?: api_proto.TurnDetectionType | null;
@@ -155,11 +161,15 @@
      connOptions?: APIConnectOptions;
    } = {},
  ) {
+    const modalities = (options.modalities ||
+      DEFAULT_REALTIME_MODEL_OPTIONS.modalities) as Modality[];
+
    super({
      messageTruncation: true,
      turnDetection: options.turnDetection !== null,
      userTranscription: options.inputAudioTranscription !== null,
      autoToolReplyGeneration: false,
+      audioOutput: modalities.includes('audio'),
    });
 
    const isAzure = !!(options.apiVersion || options.entraToken || options.azureDeployment);
@@ -188,13 +198,15 @@
      options.baseURL = `${azureEndpoint.replace(/\/$/, '')}/openai`;
    }
 
+    const { modalities: _, ...optionsWithoutModalities } = options;
    this._options = {
      ...DEFAULT_REALTIME_MODEL_OPTIONS,
-      ...options,
+      ...optionsWithoutModalities,
      baseURL: options.baseURL || BASE_URL,
      apiKey,
      isAzure,
      model: options.model || DEFAULT_REALTIME_MODEL_OPTIONS.model,
+      modalities,
    };
  }
 
@@ -389,6 +401,12 @@
  }
 
  private createSessionUpdateEvent(): api_proto.SessionUpdateEvent {
+    // OpenAI supports ['text'] or ['text', 'audio'] (audio always includes text transcript)
+    // We normalize to ensure 'text' is always present when using audio
+    const modalities: Modality[] = this.oaiRealtimeModel._options.modalities.includes('audio')
+      ? ['text', 'audio']
+      : ['text'];
+
    return {
      type: 'session.update',
      session: {
@@ -396,7 +414,7 @@
        voice: this.oaiRealtimeModel._options.voice,
        input_audio_format: 'pcm16',
        output_audio_format: 'pcm16',
-        modalities: ['text', 'audio'],
+        modalities: modalities,
        turn_detection: this.oaiRealtimeModel._options.turnDetection,
        input_audio_transcription: this.oaiRealtimeModel._options.inputAudioTranscription,
        // TODO(shubhra): add inputAudioNoiseReduction
@@ -592,7 +610,7 @@
 
  pushAudio(frame: AudioFrame): void {
    for (const f of this.resampleAudio(frame)) {
-      for (const nf of this.bstream.write(f.data.buffer)) {
+      for (const nf of this.bstream.write(f.data.buffer as ArrayBuffer)) {
        this.sendEvent({
          type: 'input_audio_buffer.append',
          audio: Buffer.from(nf.data.buffer).toString('base64'),
@@ -632,13 +650,38 @@
    } as api_proto.ResponseCancelEvent);
  }
 
-  async truncate(_options: { messageId: string; audioEndMs: number }): Promise<void> {
-    this.sendEvent({
-      type: 'conversation.item.truncate',
-      content_index: 0,
-      item_id: _options.messageId,
-      audio_end_ms: _options.audioEndMs,
-    } as api_proto.ConversationItemTruncateEvent);
+  async truncate(_options: {
+    messageId: string;
+    audioEndMs: number;
+    modalities?: Modality[];
+    audioTranscript?: string;
+  }): Promise<void> {
+    if (!_options.modalities || _options.modalities.includes('audio')) {
+      this.sendEvent({
+        type: 'conversation.item.truncate',
+        content_index: 0,
+        item_id: _options.messageId,
+        audio_end_ms: _options.audioEndMs,
+      } as api_proto.ConversationItemTruncateEvent);
+    } else if (_options.audioTranscript !== undefined) {
+      // sync it to the remote chat context
+      const chatCtx = this.chatCtx.copy();
+      const idx = chatCtx.indexById(_options.messageId);
+      if (idx !== undefined) {
+        const item = chatCtx.items[idx];
+        if (item && item.type === 'message') {
+          const newItem = llm.ChatMessage.create({
+            ...item,
+            content: [_options.audioTranscript],
+          });
+          chatCtx.items[idx] = newItem;
+          const events = this.createChatCtxUpdateEvents(chatCtx);
+          for (const ev of events) {
+            this.sendEvent(ev);
+          }
+        }
+      }
+    }
  }
 
  private loggableEvent(
@@ -907,6 +950,12 @@
      case 'response.content_part.done':
        this.handleResponseContentPartDone(event);
        break;
+      case 'response.text.delta':
+        this.handleResponseTextDelta(event);
+        break;
+      case 'response.text.done':
+        this.handleResponseTextDone(event);
+        break;
      case 'response.audio_transcript.delta':
        this.handleResponseAudioTranscriptDelta(event);
        break;
@@ -1049,6 +1098,35 @@
      this.textModeRecoveryRetries = 0;
      return;
    }
+
+    const itemId = event.item.id;
+    if (!itemId) {
+      throw new Error('item.id is not set');
+    }
+
+    const modalitiesFut = new Future();
+    const itemGeneration: MessageGeneration = {
+      messageId: itemId,
+      textChannel: stream.createStreamChannel(),
+      audioChannel: stream.createStreamChannel(),
+      audioTranscript: '',
+      modalities: modalitiesFut,
+    };
+
+    // If audioOutput is not supported, close audio channel immediately
+    if (!this.oaiRealtimeModel.capabilities.audioOutput) {
+      itemGeneration.audioChannel.close();
+      modalitiesFut.resolve(['text']);
+    }
+
+    this.currentGeneration.messageChannel.write({
+      messageId: itemId,
+      textStream: itemGeneration.textChannel.stream(),
+      audioStream: itemGeneration.audioChannel.stream(),
+      modalities: modalitiesFut.await,
+    });
+
+    this.currentGeneration.messages.set(itemId, itemGeneration);
  }
 
  private handleConversationItemCreated(event: api_proto.ConversationItemCreatedEvent): void {
@@ -1125,39 +1203,24 @@
    const itemId = event.item_id;
    const itemType = event.part.type;
-    const responseId = event.response_id;
 
-    if (itemType === 'audio') {
-      this.resolveGeneration(responseId);
-      if (this.textModeRecoveryRetries > 0) {
-        this.#logger.info(
-          { retries: this.textModeRecoveryRetries },
-          'recovered from text-only response',
-        );
-        this.textModeRecoveryRetries = 0;
-      }
+    const itemGeneration = this.currentGeneration.messages.get(itemId);
+    if (!itemGeneration) {
+      this.#logger.warn(`itemGeneration not found for itemId=${itemId}`);
+      return;
+    }
 
-      const itemGeneration: MessageGeneration = {
-        messageId: itemId,
-        textChannel: stream.createStreamChannel(),
-        audioChannel: stream.createStreamChannel(),
-        audioTranscript: '',
-      };
-
-      this.currentGeneration.messageChannel.write({
-        messageId: itemId,
-        textStream: itemGeneration.textChannel.stream(),
-        audioStream: itemGeneration.audioChannel.stream(),
-      });
+    if (itemType === 'text' && this.oaiRealtimeModel.capabilities.audioOutput) {
+      this.#logger.warn('Text response received from OpenAI Realtime API in audio modality.');
+    }
 
-      this.currentGeneration.messages.set(itemId, itemGeneration);
+    if (!itemGeneration.modalities.done) {
+      const modalityResult: Modality[] = itemType === 'text' ? ['text'] : ['audio', 'text'];
+      itemGeneration.modalities.resolve(modalityResult);
+    }
+
+    if (this.currentGeneration._firstTokenTimestamp === undefined) {
      this.currentGeneration._firstTokenTimestamp = Date.now();
-      return;
-    } else {
-      this.interrupt();
-      if (this.textModeRecoveryRetries === 0) {
-        this.#logger.warn({ responseId }, 'received text-only response from OpenAI Realtime API');
-      }
    }
  }
 
@@ -1173,6 +1236,33 @@
    // TODO(shubhra): handle text mode recovery
  }
 
+  private handleResponseTextDelta(event: api_proto.ResponseTextDeltaEvent): void {
+    if (!this.currentGeneration) {
+      throw new Error('currentGeneration is not set');
+    }
+
+    const itemGeneration = this.currentGeneration.messages.get(event.item_id);
+    if (!itemGeneration) {
+      throw new Error('itemGeneration is not set');
+    }
+
+    if (
+      !this.oaiRealtimeModel.capabilities.audioOutput &&
+      !this.currentGeneration._firstTokenTimestamp
+    ) {
+      this.currentGeneration._firstTokenTimestamp = Date.now();
+    }
+
+    itemGeneration.textChannel.write(event.delta);
+    itemGeneration.audioTranscript += event.delta;
+  }
+
+  private handleResponseTextDone(_event: api_proto.ResponseTextDoneEvent): void {
+    if (!this.currentGeneration) {
+      throw new Error('currentGeneration is not set');
+    }
+  }
+
  private handleResponseAudioTranscriptDelta(
    event: api_proto.ResponseAudioTranscriptDeltaEvent,
  ): void {
@@ -1204,6 +1294,14 @@
      throw new Error('itemGeneration is not set');
    }
 
+    if (this.currentGeneration._firstTokenTimestamp === undefined) {
+      this.currentGeneration._firstTokenTimestamp = Date.now();
+    }
+
+    if (!itemGeneration.modalities.done) {
+      itemGeneration.modalities.resolve(['audio', 'text']);
+    }
+
    const binaryString = atob(event.delta);
    const len = binaryString.length;
    const bytes = new Uint8Array(len);
@@ -1261,6 +1359,10 @@
      // text response doesn't have itemGeneration
      itemGeneration.textChannel.close();
      itemGeneration.audioChannel.close();
+      if (!itemGeneration.modalities.done) {
+        // In case message modalities is not set, this shouldn't happen
+        itemGeneration.modalities.resolve(this.oaiRealtimeModel._options.modalities);
+      }
    }
  }
 
@@ -1284,6 +1386,9 @@
    for (const generation of this.currentGeneration.messages.values()) {
      generation.textChannel.close();
      generation.audioChannel.close();
+      if (!generation.modalities.done) {
+        generation.modalities.resolve(this.oaiRealtimeModel._options.modalities);
+      }
    }
 
    this.currentGeneration.functionChannel.close();
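
Note (reviewer sketch, not part of the patch): the included example covers the OpenAI realtime plugin, while the Google plugin only gains the audioOutput capability flag here. A minimal sketch of the equivalent half-duplex setup with the Google realtime model is shown below; it assumes google.beta.realtime.RealtimeModel accepts the `modalities` option referenced in realtime_api.ts and that the Modality enum comes from @google/genai. Names outside this diff are illustrative, and the session wiring mirrors examples/src/realtime_with_tts.ts.

import { voice } from '@livekit/agents';
import * as elevenlabs from '@livekit/agents-plugin-elevenlabs';
import * as google from '@livekit/agents-plugin-google';
import { Modality } from '@google/genai';

// Text-only modality: capabilities.audioOutput becomes false (see realtime_api.ts above),
// so agent_activity.ts routes the realtime text stream through the TTS model instead.
const session = new voice.AgentSession({
  llm: new google.beta.realtime.RealtimeModel({
    modalities: [Modality.TEXT],
  }),
  tts: new elevenlabs.TTS(),
});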