diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index 8b3b10936..4629f1bb5 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -98,7 +98,7 @@ def __init__( self._stored_exception: Exception | None = None # Guardrails state tracking - self._interrupted_by_guardrail = False + self._interrupted_response_ids: set[str] = set() self._item_transcripts: dict[str, str] = {} # item_id -> accumulated transcript self._item_guardrail_run_counts: dict[str, int] = {} # item_id -> run count self._debounce_text_length = self._run_config.get("guardrails_settings", {}).get( @@ -242,7 +242,8 @@ async def on_event(self, event: RealtimeModelEvent) -> None: if current_length >= next_run_threshold: self._item_guardrail_run_counts[item_id] += 1 - self._enqueue_guardrail_task(self._item_transcripts[item_id]) + # Pass response_id so we can ensure only a single interrupt per response + self._enqueue_guardrail_task(self._item_transcripts[item_id], event.response_id) elif event.type == "item_updated": is_new = not any(item.item_id == event.item.item_id for item in self._history) self._history = self._get_new_history(self._history, event.item) @@ -274,7 +275,6 @@ async def on_event(self, event: RealtimeModelEvent) -> None: # Clear guardrail state for next turn self._item_transcripts.clear() self._item_guardrail_run_counts.clear() - self._interrupted_by_guardrail = False await self._put_event( RealtimeAgentEndEvent( @@ -442,7 +442,7 @@ def _get_new_history( # Otherwise, add it to the end return old_history + [event] - async def _run_output_guardrails(self, text: str) -> bool: + async def _run_output_guardrails(self, text: str, response_id: str) -> bool: """Run output guardrails on the given text. Returns True if any guardrail was triggered.""" combined_guardrails = self._current_agent.output_guardrails + self._run_config.get( "output_guardrails", [] @@ -455,7 +455,8 @@ async def _run_output_guardrails(self, text: str) -> bool: output_guardrails.append(guardrail) seen_ids.add(guardrail_id) - if not output_guardrails or self._interrupted_by_guardrail: + # If we've already interrupted this response, skip + if not output_guardrails or response_id in self._interrupted_response_ids: return False triggered_results = [] @@ -475,8 +476,12 @@ async def _run_output_guardrails(self, text: str) -> bool: continue if triggered_results: - # Mark as interrupted to prevent multiple interrupts - self._interrupted_by_guardrail = True + # Double-check: bail if already interrupted for this response + if response_id in self._interrupted_response_ids: + return False + + # Mark as interrupted immediately (before any awaits) to minimize race window + self._interrupted_response_ids.add(response_id) # Emit guardrail tripped event await self._put_event( @@ -502,10 +507,10 @@ async def _run_output_guardrails(self, text: str) -> bool: return False - def _enqueue_guardrail_task(self, text: str) -> None: + def _enqueue_guardrail_task(self, text: str, response_id: str) -> None: # Runs the guardrails in a separate task to avoid blocking the main loop - task = asyncio.create_task(self._run_output_guardrails(text)) + task = asyncio.create_task(self._run_output_guardrails(text, response_id)) self._guardrail_tasks.add(task) # Add callback to remove completed tasks and handle exceptions diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index e5d2d5d45..3b6c5bac6 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -1050,7 +1050,6 @@ async def test_transcript_delta_triggers_guardrail_at_threshold( await self._wait_for_guardrail_tasks(session) # Should have triggered guardrail and interrupted - assert session._interrupted_by_guardrail is True assert mock_model.interrupts_called == 1 assert len(mock_model.sent_messages) == 1 assert "triggered_guardrail" in mock_model.sent_messages[0] @@ -1187,14 +1186,12 @@ async def test_turn_ended_clears_guardrail_state( # Wait for async guardrail tasks to complete await self._wait_for_guardrail_tasks(session) - assert session._interrupted_by_guardrail is True assert len(session._item_transcripts) == 1 # End turn await session.on_event(RealtimeModelTurnEndedEvent()) # State should be cleared - assert session._interrupted_by_guardrail is False assert len(session._item_transcripts) == 0 assert len(session._item_guardrail_run_counts) == 0 @@ -1259,7 +1256,6 @@ async def test_agent_output_guardrails_triggered(self, mock_model, triggered_gua await session.on_event(transcript_event) await self._wait_for_guardrail_tasks(session) - assert session._interrupted_by_guardrail is True assert mock_model.interrupts_called == 1 assert len(mock_model.sent_messages) == 1 assert "triggered_guardrail" in mock_model.sent_messages[0] @@ -1272,6 +1268,63 @@ async def test_agent_output_guardrails_triggered(self, mock_model, triggered_gua assert len(guardrail_events) == 1 assert guardrail_events[0].message == "this is more than ten characters" + @pytest.mark.asyncio + async def test_concurrent_guardrail_tasks_interrupt_once_per_response(self, mock_model): + """Even if multiple guardrail tasks trigger concurrently for the same response_id, + only the first should interrupt and send a message.""" + import asyncio + + # Barrier to release both guardrail tasks at the same time + start_event = asyncio.Event() + + async def async_trigger_guardrail(context, agent, output): + await start_event.wait() + return GuardrailFunctionOutput( + output_info={"reason": "concurrent"}, tripwire_triggered=True + ) + + concurrent_guardrail = OutputGuardrail( + guardrail_function=async_trigger_guardrail, name="concurrent_trigger" + ) + + run_config: RealtimeRunConfig = { + "output_guardrails": [concurrent_guardrail], + "guardrails_settings": {"debounce_text_length": 5}, + } + + # Use a minimal agent (guardrails from run_config) + agent = RealtimeAgent(name="agent") + session = RealtimeSession(mock_model, agent, None, run_config=run_config) + + # Two deltas for same item and response to enqueue two guardrail tasks + await session.on_event( + RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="12345", response_id="resp_same" + ) + ) + await session.on_event( + RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="67890", response_id="resp_same" + ) + ) + + # Wait until both tasks are enqueued + for _ in range(50): + if len(session._guardrail_tasks) >= 2: + break + await asyncio.sleep(0.01) + + # Release both tasks concurrently + start_event.set() + + # Wait for completion + if session._guardrail_tasks: + await asyncio.gather(*session._guardrail_tasks, return_exceptions=True) + + # Only one interrupt and one message should be sent + assert mock_model.interrupts_called == 1 + assert len(mock_model.sent_messages) == 1 + class TestModelSettingsIntegration: """Test suite for model settings integration in RealtimeSession."""