Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changeset/strong-lobsters-repair.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@openai/agents-extensions": patch
"@openai/agents-realtime": patch
---

fix(realtime-ws): stop accidental cancellation error
4 changes: 2 additions & 2 deletions packages/agents-extensions/src/TwilioRealtimeTransport.ts
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ export class TwilioRealtimeTransportLayer extends OpenAIRealtimeWebSocket {
super.updateSessionConfig(newConfig);
}

_interrupt(_elapsedTime: number) {
_interrupt(_elapsedTime: number, cancelOngoingResponse: boolean = true) {
const elapsedTime = this.#lastPlayedChunkCount + 50; /* 50ms buffer */
this.#logger.debug(
`Interruption detected, clearing Twilio audio and truncating OpenAI audio after ${elapsedTime}ms`,
Expand All @@ -192,7 +192,7 @@ export class TwilioRealtimeTransportLayer extends OpenAIRealtimeWebSocket {
streamSid: this.#streamSid,
}),
);
super._interrupt(elapsedTime);
super._interrupt(elapsedTime, cancelOngoingResponse);
}

protected _onAudio(audioEvent: TransportLayerAudio) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ describe('TwilioRealtimeTransportLayer', () => {
toString: () => JSON.stringify({ event: 'mark', mark: { name: 'u:5' } }),
});
transport._interrupt(0);
expect(interruptSpy).toHaveBeenCalledWith(55);
expect(interruptSpy).toHaveBeenCalledWith(55, true);
expect(twilio.send).toHaveBeenCalledWith(
JSON.stringify({ event: 'clear', streamSid: 'sid' }),
);
Expand Down
5 changes: 4 additions & 1 deletion packages/agents-extensions/test/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ describe('TwilioRealtimeTransportLayer', () => {
const payload = { event: 'mark', mark: { name: 'badmark' } };
twilio.emit('message', { toString: () => JSON.stringify(payload) });

transport._interrupt(0);
transport._interrupt(0, false);
// @ts-expect-error - we're testing protected fields
transport._audioLengthMs = 500;
transport._interrupt(0, true);

const call = sendEventSpy.mock.calls.find(
(c) => c[0]?.type === 'conversation.item.truncate',
Expand Down
9 changes: 9 additions & 0 deletions packages/agents-realtime/src/openaiRealtimeBase.ts
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ export abstract class OpenAIRealtimeBase
#model: string;
#apiKey: ApiKey | undefined;
#tracingConfig: RealtimeTracingConfig | null = null;
#rawSessionConfig: Record<string, any> | null = null;

protected eventEmitter: RuntimeEventEmitter<OpenAIRealtimeEventTypes> =
new RuntimeEventEmitter<OpenAIRealtimeEventTypes>();
Expand Down Expand Up @@ -149,6 +150,10 @@ export abstract class OpenAIRealtimeBase

abstract readonly muted: boolean | null;

/**
 * Read-only view of the most recently stored raw session configuration.
 *
 * @returns The stored session object, or `null` when nothing has been
 * stored yet (an `undefined` value is normalised to `null` as well).
 */
protected get _rawSessionConfig(): Record<string, any> | null {
  const storedConfig = this.#rawSessionConfig;
  return storedConfig ?? null;
}

protected async _getApiKey(options: RealtimeTransportLayerConnectOptions) {
const apiKey = options.apiKey ?? this.#apiKey;

Expand Down Expand Up @@ -186,6 +191,10 @@ export abstract class OpenAIRealtimeBase
return;
}

if (parsed.type === 'session.updated') {
this.#rawSessionConfig = parsed.session;
}

if (parsed.type === 'response.done') {
const response = responseDoneEventSchema.safeParse(parsed);
if (!response.success) {
Expand Down
33 changes: 25 additions & 8 deletions packages/agents-realtime/src/openaiRealtimeWebsocket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,15 @@ export class OpenAIRealtimeWebSocket

const buff = base64ToArrayBuffer(parsed.delta);
// calculate the audio length in milliseconds based on the session's output audio format (defaults to 24kHz pcm16)
this._audioLengthMs += buff.byteLength / 24 / 2; // 24kHz * 2 bytes per sample
const audioFormat =
this._rawSessionConfig?.output_audio_format ?? 'pcm16';
if (audioFormat.startsWith('g711_')) {
// 8kHz * 1 byte per sample
this._audioLengthMs += buff.byteLength / 8;
} else {
// 24kHz * 2 bytes per sample
this._audioLengthMs += buff.byteLength / 24 / 2;
}

const audioEvent: TransportLayerAudio = {
type: 'audio',
Expand All @@ -224,7 +232,9 @@ export class OpenAIRealtimeWebSocket
};
this._onAudio(audioEvent);
} else if (parsed.type === 'input_audio_buffer.speech_started') {
this.interrupt();
const automaticResponseCancellationEnabled =
this._rawSessionConfig?.turn_detection?.interrupt_response ?? false;
this.interrupt(!automaticResponseCancellationEnabled);
} else if (parsed.type === 'response.created') {
this.#ongoingResponse = true;
} else if (parsed.type === 'response.done') {
Expand Down Expand Up @@ -343,8 +353,16 @@ export class OpenAIRealtimeWebSocket
*
* @param elapsedTime - The elapsed time since the response started.
*/
_interrupt(elapsedTime: number) {
_interrupt(elapsedTime: number, cancelOngoingResponse: boolean = true) {
if (elapsedTime < 0 || elapsedTime > this._audioLengthMs) {
return;
}

// cancel the ongoing response first (when requested), then immediately emit the interruption event so the client can stop playing audio
if (cancelOngoingResponse) {
this._cancelResponse();
}

this.emit('audio_interrupted');
this.sendEvent({
type: 'conversation.item.truncate',
Expand All @@ -362,16 +380,15 @@ export class OpenAIRealtimeWebSocket
* You can also call this method directly if you want to interrupt the conversation for example
* based on an event in the client.
*/
interrupt() {
interrupt(cancelOngoingResponse: boolean = true) {
if (!this.#currentItemId || typeof this._firstAudioTimestamp !== 'number') {
return;
}

this._cancelResponse();

const elapsedTime = Date.now() - this._firstAudioTimestamp;
if (elapsedTime >= 0 && elapsedTime < this._audioLengthMs) {
this._interrupt(elapsedTime);

if (elapsedTime >= 0) {
this._interrupt(elapsedTime, cancelOngoingResponse);
}

this.#currentItemId = undefined;
Expand Down