Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions src/agents/realtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __init__(
self._stored_exception: Exception | None = None

# Guardrails state tracking
self._interrupted_by_guardrail = False
self._interrupted_response_ids: set[str] = set()
self._item_transcripts: dict[str, str] = {} # item_id -> accumulated transcript
self._item_guardrail_run_counts: dict[str, int] = {} # item_id -> run count
self._debounce_text_length = self._run_config.get("guardrails_settings", {}).get(
Expand Down Expand Up @@ -242,7 +242,8 @@ async def on_event(self, event: RealtimeModelEvent) -> None:

if current_length >= next_run_threshold:
self._item_guardrail_run_counts[item_id] += 1
self._enqueue_guardrail_task(self._item_transcripts[item_id])
# Pass response_id so we can ensure only a single interrupt per response
self._enqueue_guardrail_task(self._item_transcripts[item_id], event.response_id)
elif event.type == "item_updated":
is_new = not any(item.item_id == event.item.item_id for item in self._history)
self._history = self._get_new_history(self._history, event.item)
Expand Down Expand Up @@ -274,7 +275,6 @@ async def on_event(self, event: RealtimeModelEvent) -> None:
# Clear guardrail state for next turn
self._item_transcripts.clear()
self._item_guardrail_run_counts.clear()
self._interrupted_by_guardrail = False

await self._put_event(
RealtimeAgentEndEvent(
Expand Down Expand Up @@ -442,7 +442,7 @@ def _get_new_history(
# Otherwise, add it to the end
return old_history + [event]

async def _run_output_guardrails(self, text: str) -> bool:
async def _run_output_guardrails(self, text: str, response_id: str) -> bool:
"""Run output guardrails on the given text. Returns True if any guardrail was triggered."""
combined_guardrails = self._current_agent.output_guardrails + self._run_config.get(
"output_guardrails", []
Expand All @@ -455,7 +455,8 @@ async def _run_output_guardrails(self, text: str) -> bool:
output_guardrails.append(guardrail)
seen_ids.add(guardrail_id)

if not output_guardrails or self._interrupted_by_guardrail:
# If we've already interrupted this response, skip
if not output_guardrails or response_id in self._interrupted_response_ids:
return False

triggered_results = []
Expand All @@ -475,8 +476,12 @@ async def _run_output_guardrails(self, text: str) -> bool:
continue

if triggered_results:
# Mark as interrupted to prevent multiple interrupts
self._interrupted_by_guardrail = True
# Double-check: bail if already interrupted for this response
if response_id in self._interrupted_response_ids:
return False

# Mark as interrupted immediately (before any awaits) to minimize race window
self._interrupted_response_ids.add(response_id)

# Emit guardrail tripped event
await self._put_event(
Expand All @@ -502,10 +507,10 @@ async def _run_output_guardrails(self, text: str) -> bool:

return False

def _enqueue_guardrail_task(self, text: str) -> None:
def _enqueue_guardrail_task(self, text: str, response_id: str) -> None:
# Runs the guardrails in a separate task to avoid blocking the main loop

task = asyncio.create_task(self._run_output_guardrails(text))
task = asyncio.create_task(self._run_output_guardrails(text, response_id))
self._guardrail_tasks.add(task)

# Add callback to remove completed tasks and handle exceptions
Expand Down
61 changes: 57 additions & 4 deletions tests/realtime/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,7 +1050,6 @@ async def test_transcript_delta_triggers_guardrail_at_threshold(
await self._wait_for_guardrail_tasks(session)

# Should have triggered guardrail and interrupted
assert session._interrupted_by_guardrail is True
assert mock_model.interrupts_called == 1
assert len(mock_model.sent_messages) == 1
assert "triggered_guardrail" in mock_model.sent_messages[0]
Expand Down Expand Up @@ -1187,14 +1186,12 @@ async def test_turn_ended_clears_guardrail_state(
# Wait for async guardrail tasks to complete
await self._wait_for_guardrail_tasks(session)

assert session._interrupted_by_guardrail is True
assert len(session._item_transcripts) == 1

# End turn
await session.on_event(RealtimeModelTurnEndedEvent())

# State should be cleared
assert session._interrupted_by_guardrail is False
assert len(session._item_transcripts) == 0
assert len(session._item_guardrail_run_counts) == 0

Expand Down Expand Up @@ -1259,7 +1256,6 @@ async def test_agent_output_guardrails_triggered(self, mock_model, triggered_gua
await session.on_event(transcript_event)
await self._wait_for_guardrail_tasks(session)

assert session._interrupted_by_guardrail is True
assert mock_model.interrupts_called == 1
assert len(mock_model.sent_messages) == 1
assert "triggered_guardrail" in mock_model.sent_messages[0]
Expand All @@ -1272,6 +1268,63 @@ async def test_agent_output_guardrails_triggered(self, mock_model, triggered_gua
assert len(guardrail_events) == 1
assert guardrail_events[0].message == "this is more than ten characters"

@pytest.mark.asyncio
async def test_concurrent_guardrail_tasks_interrupt_once_per_response(self, mock_model):
    """Concurrent guardrail trips for one response_id must interrupt only once.

    Two transcript deltas sharing ``response_id="resp_same"`` are fed to the
    session so that two guardrail tasks are in flight at the same time. Both
    tasks block on a shared ``asyncio.Event`` and are released together; the
    model must then have received exactly one interrupt and one message.
    """
    import asyncio

    # Gate that parks every guardrail invocation until the test releases it,
    # so all in-flight tasks resume together.
    start_event = asyncio.Event()

    async def async_trigger_guardrail(context, agent, output):
        # Always trips, but only after the gate has been opened.
        await start_event.wait()
        return GuardrailFunctionOutput(
            output_info={"reason": "concurrent"}, tripwire_triggered=True
        )

    concurrent_guardrail = OutputGuardrail(
        guardrail_function=async_trigger_guardrail, name="concurrent_trigger"
    )

    run_config: RealtimeRunConfig = {
        "output_guardrails": [concurrent_guardrail],
        "guardrails_settings": {"debounce_text_length": 5},
    }

    # Use a minimal agent; the guardrail is supplied through run_config only.
    agent = RealtimeAgent(name="agent")
    session = RealtimeSession(mock_model, agent, None, run_config=run_config)

    # Two five-character deltas for the same item and response; each one is
    # intended to cross the debounce threshold and enqueue a guardrail task.
    for delta in ("12345", "67890"):
        await session.on_event(
            RealtimeModelTranscriptDeltaEvent(
                item_id="item_1", delta=delta, response_id="resp_same"
            )
        )

    # Give the session a bounded amount of time to enqueue both tasks.
    for _ in range(50):
        if len(session._guardrail_tasks) >= 2:
            break
        await asyncio.sleep(0.01)

    # Snapshot the task set before releasing the gate. Without this check the
    # test could pass vacuously: a single task trivially yields one interrupt.
    pending_tasks = list(session._guardrail_tasks)
    assert len(pending_tasks) >= 2, (
        "expected two concurrent guardrail tasks before releasing the gate"
    )

    # Release both tasks concurrently.
    start_event.set()

    # Wait for completion; gather over the snapshot so the wait does not
    # depend on the live set while tasks are finishing.
    await asyncio.gather(*pending_tasks, return_exceptions=True)

    # Only one interrupt and one message should be sent.
    assert mock_model.interrupts_called == 1
    assert len(mock_model.sent_messages) == 1


class TestModelSettingsIntegration:
"""Test suite for model settings integration in RealtimeSession."""
Expand Down