Skip to content

Commit 68ff0ba

Browse files
fix(realtime): guardrail interruption for realtime session (#17)
* fix(realtime): send cancel before audio starts * fix guardrails * remove logging
1 parent 71399e7 commit 68ff0ba

File tree

9 files changed

+79
-35
lines changed

9 files changed

+79
-35
lines changed

.changeset/wise-results-mate.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'@openai/agents-realtime': patch
3+
---
4+
5+
fix: avoid realtime guardrail race condition and detect ongoing response

examples/realtime-next/src/app/websocket/page.tsx

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import {
77
TransportEvent,
88
RealtimeItem,
99
OutputGuardrailTripwireTriggered,
10+
RealtimeOutputGuardrail,
1011
} from '@openai/agents/realtime';
1112
import { useEffect, useRef, useState } from 'react';
1213
import { z } from 'zod';
@@ -26,6 +27,21 @@ const refundBackchannel = tool({
2627
},
2728
});
2829

30+
const guardrails: RealtimeOutputGuardrail[] = [
31+
{
32+
name: 'No mention of Dom',
33+
execute: async ({ agentOutput }) => {
34+
const domInOutput = agentOutput.includes('Dom');
35+
return {
36+
tripwireTriggered: domInOutput,
37+
outputInfo: {
38+
domInOutput,
39+
},
40+
};
41+
},
42+
},
43+
];
44+
2945
const agent = new RealtimeAgent({
3046
name: 'Greeter',
3147
instructions:
@@ -48,6 +64,7 @@ export default function Home() {
4864
useEffect(() => {
4965
session.current = new RealtimeSession(agent, {
5066
transport: 'websocket',
67+
outputGuardrails: guardrails,
5168
});
5269
recorder.current = new WavRecorder({ sampleRate: 24000 });
5370
player.current = new WavStreamPlayer({ sampleRate: 24000 });
@@ -87,6 +104,7 @@ export default function Home() {
87104
async function connect() {
88105
if (isConnected) {
89106
await session.current?.close();
107+
await player.current?.interrupt();
90108
await recorder.current?.end();
91109
setIsConnected(false);
92110
} else {

packages/agents-realtime/src/openaiRealtimeBase.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ export abstract class OpenAIRealtimeBase
256256
type: 'transcript_delta',
257257
delta: parsed.delta,
258258
itemId: parsed.item_id,
259+
responseId: parsed.response_id,
259260
});
260261
}
261262
// no support for partial transcripts yet.

packages/agents-realtime/src/openaiRealtimeEvents.ts

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,18 @@ import type { MessageEvent as WebSocketMessageEvent } from 'ws';
66
// provide better runtime validation when parsing events from the server.
77

88
export const realtimeResponse = z.object({
9-
id: z.string().optional(),
10-
conversation_id: z.string().optional(),
11-
max_output_tokens: z.number().or(z.literal('inf')).optional(),
9+
id: z.string().optional().nullable(),
10+
conversation_id: z.string().optional().nullable(),
11+
max_output_tokens: z.number().or(z.literal('inf')).optional().nullable(),
1212
metadata: z.record(z.string(), z.any()).optional().nullable(),
13-
modalities: z.array(z.string()).optional(),
14-
object: z.literal('realtime.response').optional(),
15-
output: z.array(z.any()).optional(),
16-
output_audio_format: z.string().optional(),
17-
status: z.enum(['completed', 'incomplete', 'failed', 'cancelled']).optional(),
13+
modalities: z.array(z.string()).optional().nullable(),
14+
object: z.literal('realtime.response').optional().nullable(),
15+
output: z.array(z.any()).optional().nullable(),
16+
output_audio_format: z.string().optional().nullable(),
17+
status: z
18+
.enum(['completed', 'incomplete', 'failed', 'cancelled', 'in_progress'])
19+
.optional()
20+
.nullable(),
1821
status_details: z.record(z.string(), z.any()).optional().nullable(),
1922
usage: z
2023
.object({
@@ -26,8 +29,9 @@ export const realtimeResponse = z.object({
2629
.optional()
2730
.nullable(),
2831
})
29-
.optional(),
30-
voice: z.string().optional(),
32+
.optional()
33+
.nullable(),
34+
voice: z.string().optional().nullable(),
3135
});
3236

3337
// Basic content schema used by ConversationItem.
@@ -315,7 +319,6 @@ export const responseDoneEventSchema = z.object({
315319
type: z.literal('response.done'),
316320
event_id: z.string(),
317321
response: realtimeResponse,
318-
test: z.boolean(),
319322
});
320323

321324
export const responseFunctionCallArgumentsDeltaEventSchema = z.object({

packages/agents-realtime/src/openaiRealtimeWebRtc.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ export class OpenAIRealtimeWebRTC
196196
if (!parsed || isGeneric) {
197197
return;
198198
}
199+
199200
if (parsed.type === 'response.created') {
200201
this.#ongoingResponse = true;
201202
} else if (parsed.type === 'response.done') {
@@ -334,6 +335,7 @@ export class OpenAIRealtimeWebRTC
334335
this.sendEvent({
335336
type: 'response.cancel',
336337
});
338+
this.#ongoingResponse = false;
337339
}
338340

339341
this.sendEvent({

packages/agents-realtime/src/openaiRealtimeWebsocket.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ export class OpenAIRealtimeWebSocket
328328
this.sendEvent({
329329
type: 'response.cancel',
330330
});
331+
this.#ongoingResponse = false;
331332
}
332333
}
333334

@@ -367,8 +368,6 @@ export class OpenAIRealtimeWebSocket
367368
this._cancelResponse();
368369

369370
const elapsedTime = Date.now() - this._firstAudioTimestamp;
370-
console.log(`Interrupting response after ${elapsedTime}ms`);
371-
console.log(`Audio length: ${this._audioLengthMs}ms`);
372371
if (elapsedTime >= 0 && elapsedTime < this._audioLengthMs) {
373372
this._interrupt(elapsedTime);
374373
}

packages/agents-realtime/src/realtimeSession.ts

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ export class RealtimeSession<
180180
#transcribedTextDeltas: Record<string, string> = {};
181181
#history: RealtimeItem[] = [];
182182
#shouldIncludeAudioData: boolean;
183+
#interruptedByGuardrail: Record<string, boolean> = {};
183184

184185
constructor(
185186
public readonly initialAgent:
@@ -446,7 +447,7 @@ export class RealtimeSession<
446447
}
447448
}
448449

449-
async #runOutputGuardrails(output: string) {
450+
async #runOutputGuardrails(output: string, responseId: string) {
450451
if (this.#outputGuardrails.length === 0) {
451452
return;
452453
}
@@ -460,24 +461,28 @@ export class RealtimeSession<
460461
this.#outputGuardrails.map((guardrail) => guardrail.run(guardrailArgs)),
461462
);
462463

463-
for (const result of results) {
464-
if (result.output.tripwireTriggered) {
465-
const error = new OutputGuardrailTripwireTriggered(
466-
`Output guardrail triggered: ${JSON.stringify(result.output.outputInfo)}`,
467-
result,
468-
);
469-
this.emit(
470-
'guardrail_tripped',
471-
this.#context,
472-
this.#currentAgent,
473-
error,
474-
);
475-
this.interrupt();
476-
477-
const feedbackText = getRealtimeGuardrailFeedbackMessage(result);
478-
this.sendMessage(feedbackText);
479-
break;
464+
const firstTripwireTriggered = results.find(
465+
(result) => result.output.tripwireTriggered,
466+
);
467+
if (firstTripwireTriggered) {
468+
// this ensures that if one guardrail already trips and we are in the middle of another
469+
// guardrail run, we don't trip again
470+
if (this.#interruptedByGuardrail[responseId]) {
471+
return;
480472
}
473+
this.#interruptedByGuardrail[responseId] = true;
474+
const error = new OutputGuardrailTripwireTriggered(
475+
`Output guardrail triggered: ${JSON.stringify(firstTripwireTriggered.output.outputInfo)}`,
476+
firstTripwireTriggered,
477+
);
478+
this.emit('guardrail_tripped', this.#context, this.#currentAgent, error);
479+
this.interrupt();
480+
481+
const feedbackText = getRealtimeGuardrailFeedbackMessage(
482+
firstTripwireTriggered,
483+
);
484+
this.sendMessage(feedbackText);
485+
return;
481486
}
482487
}
483488

@@ -498,7 +503,7 @@ export class RealtimeSession<
498503
this.emit('agent_end', this.#context, this.#currentAgent, textOutput);
499504
this.#currentAgent.emit('agent_end', this.#context, textOutput);
500505

501-
this.#runOutputGuardrails(textOutput);
506+
this.#runOutputGuardrails(textOutput, event.response.id);
502507
});
503508

504509
this.#transport.on('audio_done', () => {
@@ -511,6 +516,7 @@ export class RealtimeSession<
511516
try {
512517
const delta = event.delta;
513518
const itemId = event.itemId;
519+
const responseId = event.responseId;
514520
if (lastItemId !== itemId) {
515521
lastItemId = itemId;
516522
lastRunIndex = 0;
@@ -531,7 +537,7 @@ export class RealtimeSession<
531537
// We don't cancel existing runs because we want the first one to fail to fail
532538
// The transport layer should upon failure handle the interruption and stop the model
533539
// from generating further
534-
this.#runOutputGuardrails(newText);
540+
this.#runOutputGuardrails(newText, responseId);
535541
}
536542
} catch (err) {
537543
this.emit('error', {
@@ -672,6 +678,7 @@ export class RealtimeSession<
672678
* Disconnect from the session.
673679
*/
674680
close() {
681+
this.#interruptedByGuardrail = {};
675682
this.#transport.close();
676683
}
677684

packages/agents-realtime/src/transportLayerEvents.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ export type TransportLayerTranscriptDelta = {
3434
type: 'transcript_delta';
3535
itemId: string;
3636
delta: string;
37+
responseId: string;
3738
};
3839

3940
export type TransportLayerResponseCompleted =

packages/agents-realtime/test/realtimeSession.test.ts

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,16 @@ describe('RealtimeSession', () => {
171171
outputGuardrailSettings: { debounceTextLength: 1 },
172172
});
173173
await s.connect({ apiKey: 'test' });
174-
t.emit('audio_transcript_delta', { delta: 'a', itemId: '1' } as any);
175-
t.emit('audio_transcript_delta', { delta: 'a', itemId: '2' } as any);
174+
t.emit('audio_transcript_delta', {
175+
delta: 'a',
176+
itemId: '1',
177+
responseId: 'z',
178+
} as any);
179+
t.emit('audio_transcript_delta', {
180+
delta: 'a',
181+
itemId: '2',
182+
responseId: 'z',
183+
} as any);
176184
await vi.waitFor(() => expect(runMock).toHaveBeenCalledTimes(2));
177185
vi.restoreAllMocks();
178186
});

0 commit comments

Comments
 (0)