diff --git a/examples/mcp/prompt_server/main.py b/examples/mcp/prompt_server/main.py index 8f2991fc0..4caa95d88 100644 --- a/examples/mcp/prompt_server/main.py +++ b/examples/mcp/prompt_server/main.py @@ -17,7 +17,7 @@ async def get_instructions_from_prompt(mcp_server: MCPServer, prompt_name: str, try: prompt_result = await mcp_server.get_prompt(prompt_name, kwargs) content = prompt_result.messages[0].content - if hasattr(content, 'text'): + if hasattr(content, "text"): instructions = content.text else: instructions = str(content) diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py index 1e9edcbc6..edb692960 100644 --- a/src/agents/model_settings.py +++ b/src/agents/model_settings.py @@ -42,15 +42,18 @@ def validate_from_none(value: None) -> _Omit: serialization=core_schema.plain_serializer_function_ser_schema(lambda instance: None), ) + @dataclass class MCPToolChoice: server_label: str name: str + Omit = Annotated[_Omit, _OmitTypeAnnotation] Headers: TypeAlias = Mapping[str, Union[str, Omit]] ToolChoice: TypeAlias = Union[Literal["auto", "required", "none"], str, MCPToolChoice, None] + @dataclass class ModelSettings: """Settings to use when calling an LLM. diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index d25613aee..76c67903c 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -343,7 +343,7 @@ def convert_tool_choice( elif tool_choice == "mcp": # Note that this is still here for backwards compatibility, # but migrating to MCPToolChoice is recommended. - return { "type": "mcp" } # type: ignore [typeddict-item] + return {"type": "mcp"} # type: ignore [typeddict-item] else: return { "type": "function", diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index 2cd71b023..f6bd60064 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -974,27 +974,30 @@ class TestGuardrailFunctionality: async def _wait_for_guardrail_tasks(self, session): """Wait for all pending guardrail tasks to complete.""" import asyncio + if session._guardrail_tasks: await asyncio.gather(*session._guardrail_tasks, return_exceptions=True) @pytest.fixture def triggered_guardrail(self): """Creates a guardrail that always triggers""" + def guardrail_func(context, agent, output): return GuardrailFunctionOutput( - output_info={"reason": "test trigger"}, - tripwire_triggered=True + output_info={"reason": "test trigger"}, tripwire_triggered=True ) + return OutputGuardrail(guardrail_function=guardrail_func, name="triggered_guardrail") @pytest.fixture def safe_guardrail(self): """Creates a guardrail that never triggers""" + def guardrail_func(context, agent, output): return GuardrailFunctionOutput( - output_info={"reason": "safe content"}, - tripwire_triggered=False + output_info={"reason": "safe content"}, tripwire_triggered=False ) + return OutputGuardrail(guardrail_function=guardrail_func, name="safe_guardrail") @pytest.mark.asyncio @@ -1004,7 +1007,7 @@ async def test_transcript_delta_triggers_guardrail_at_threshold( """Test that guardrails run when transcript delta reaches debounce threshold""" run_config: RealtimeRunConfig = { "output_guardrails": [triggered_guardrail], - "guardrails_settings": {"debounce_text_length": 10} + "guardrails_settings": {"debounce_text_length": 10}, } session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) @@ -1041,20 +1044,20 @@ async def test_transcript_delta_multiple_thresholds_same_item( """Test guardrails run at 1x, 2x, 3x thresholds for same item_id""" run_config: RealtimeRunConfig = { "output_guardrails": [triggered_guardrail], - "guardrails_settings": {"debounce_text_length": 5} + "guardrails_settings": {"debounce_text_length": 5}, } session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) # First delta - reaches 1x threshold (5 chars) - await session.on_event(RealtimeModelTranscriptDeltaEvent( - item_id="item_1", delta="12345", response_id="resp_1" - )) + await session.on_event( + RealtimeModelTranscriptDeltaEvent(item_id="item_1", delta="12345", response_id="resp_1") + ) # Second delta - reaches 2x threshold (10 chars total) - await session.on_event(RealtimeModelTranscriptDeltaEvent( - item_id="item_1", delta="67890", response_id="resp_1" - )) + await session.on_event( + RealtimeModelTranscriptDeltaEvent(item_id="item_1", delta="67890", response_id="resp_1") + ) # Wait for async guardrail tasks to complete await self._wait_for_guardrail_tasks(session) @@ -1070,28 +1073,32 @@ async def test_transcript_delta_different_items_tracked_separately( """Test that different item_ids are tracked separately for debouncing""" run_config: RealtimeRunConfig = { "output_guardrails": [safe_guardrail], - "guardrails_settings": {"debounce_text_length": 10} + "guardrails_settings": {"debounce_text_length": 10}, } session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) # Add text to item_1 (8 chars - below threshold) - await session.on_event(RealtimeModelTranscriptDeltaEvent( - item_id="item_1", delta="12345678", response_id="resp_1" - )) + await session.on_event( + RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="12345678", response_id="resp_1" + ) + ) # Add text to item_2 (8 chars - below threshold) - await session.on_event(RealtimeModelTranscriptDeltaEvent( - item_id="item_2", delta="abcdefgh", response_id="resp_2" - )) + await session.on_event( + RealtimeModelTranscriptDeltaEvent( + item_id="item_2", delta="abcdefgh", response_id="resp_2" + ) + ) # Neither should trigger guardrails yet assert mock_model.interrupts_called == 0 # Add more text to item_1 (total 12 chars - above threshold) - await session.on_event(RealtimeModelTranscriptDeltaEvent( - item_id="item_1", delta="90ab", response_id="resp_1" - )) + await session.on_event( + RealtimeModelTranscriptDeltaEvent(item_id="item_1", delta="90ab", response_id="resp_1") + ) # item_1 should have triggered guardrail run (but not interrupted since safe) assert session._item_guardrail_run_counts["item_1"] == 1 @@ -1107,15 +1114,17 @@ async def test_turn_ended_clears_guardrail_state( """Test that turn_ended event clears guardrail state for next turn""" run_config: RealtimeRunConfig = { "output_guardrails": [triggered_guardrail], - "guardrails_settings": {"debounce_text_length": 5} + "guardrails_settings": {"debounce_text_length": 5}, } session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) # Trigger guardrail - await session.on_event(RealtimeModelTranscriptDeltaEvent( - item_id="item_1", delta="trigger", response_id="resp_1" - )) + await session.on_event( + RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="trigger", response_id="resp_1" + ) + ) # Wait for async guardrail tasks to complete await self._wait_for_guardrail_tasks(session) @@ -1132,16 +1141,13 @@ async def test_turn_ended_clears_guardrail_state( assert len(session._item_guardrail_run_counts) == 0 @pytest.mark.asyncio - async def test_multiple_guardrails_all_triggered( - self, mock_model, mock_agent - ): + async def test_multiple_guardrails_all_triggered(self, mock_model, mock_agent): """Test that all triggered guardrails are included in the event""" + def create_triggered_guardrail(name): def guardrail_func(context, agent, output): - return GuardrailFunctionOutput( - output_info={"name": name}, - tripwire_triggered=True - ) + return GuardrailFunctionOutput(output_info={"name": name}, tripwire_triggered=True) + return OutputGuardrail(guardrail_function=guardrail_func, name=name) guardrail1 = create_triggered_guardrail("guardrail_1") @@ -1149,14 +1155,16 @@ def guardrail_func(context, agent, output): run_config: RealtimeRunConfig = { "output_guardrails": [guardrail1, guardrail2], - "guardrails_settings": {"debounce_text_length": 5} + "guardrails_settings": {"debounce_text_length": 5}, } session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) - await session.on_event(RealtimeModelTranscriptDeltaEvent( - item_id="item_1", delta="trigger", response_id="resp_1" - )) + await session.on_event( + RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="trigger", response_id="resp_1" + ) + ) # Wait for async guardrail tasks to complete await self._wait_for_guardrail_tasks(session) diff --git a/tests/realtime/test_tracing.py b/tests/realtime/test_tracing.py index 456ae125f..7fb2594a9 100644 --- a/tests/realtime/test_tracing.py +++ b/tests/realtime/test_tracing.py @@ -222,27 +222,24 @@ async def test_tracing_disabled_prevents_tracing(self, mock_websocket): # Create a test agent and runner with tracing disabled agent = RealtimeAgent(name="test_agent", instructions="test") - runner = RealtimeRunner( - starting_agent=agent, - config={"tracing_disabled": True} - ) + runner = RealtimeRunner(starting_agent=agent, config={"tracing_disabled": True}) # Test the _get_model_settings method directly since that's where the logic is model_settings = await runner._get_model_settings( agent=agent, disable_tracing=True, # This should come from config["tracing_disabled"] initial_settings=None, - overrides=None + overrides=None, ) # When tracing is disabled, model settings should have tracing=None assert model_settings["tracing"] is None # Also test that the runner passes disable_tracing=True correctly - with patch.object(runner, '_get_model_settings') as mock_get_settings: + with patch.object(runner, "_get_model_settings") as mock_get_settings: mock_get_settings.return_value = {"tracing": None} - with patch('agents.realtime.session.RealtimeSession') as mock_session_class: + with patch("agents.realtime.session.RealtimeSession") as mock_session_class: mock_session = AsyncMock() mock_session_class.return_value = mock_session @@ -250,8 +247,5 @@ async def test_tracing_disabled_prevents_tracing(self, mock_websocket): # Verify that _get_model_settings was called with disable_tracing=True mock_get_settings.assert_called_once_with( - agent=agent, - disable_tracing=True, - initial_settings=None, - overrides=None + agent=agent, disable_tracing=True, initial_settings=None, overrides=None )