Skip to content

Commit bd988dc

Browse files
authored
Run make format (#1106)
cleanup --- [//]: # (BEGIN SAPLING FOOTER) * #1112 * #1111 * #1107 * __->__ #1106
1 parent 3d66226 commit bd988dc

File tree

5 files changed

+55
-50
lines changed

5 files changed

+55
-50
lines changed

examples/mcp/prompt_server/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ async def get_instructions_from_prompt(mcp_server: MCPServer, prompt_name: str,
1717
try:
1818
prompt_result = await mcp_server.get_prompt(prompt_name, kwargs)
1919
content = prompt_result.messages[0].content
20-
if hasattr(content, 'text'):
20+
if hasattr(content, "text"):
2121
instructions = content.text
2222
else:
2323
instructions = str(content)

src/agents/model_settings.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,18 @@ def validate_from_none(value: None) -> _Omit:
4242
serialization=core_schema.plain_serializer_function_ser_schema(lambda instance: None),
4343
)
4444

45+
4546
@dataclass
4647
class MCPToolChoice:
4748
server_label: str
4849
name: str
4950

51+
5052
Omit = Annotated[_Omit, _OmitTypeAnnotation]
5153
Headers: TypeAlias = Mapping[str, Union[str, Omit]]
5254
ToolChoice: TypeAlias = Union[Literal["auto", "required", "none"], str, MCPToolChoice, None]
5355

56+
5457
@dataclass
5558
class ModelSettings:
5659
"""Settings to use when calling an LLM.

src/agents/models/openai_responses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def convert_tool_choice(
343343
elif tool_choice == "mcp":
344344
# Note that this is still here for backwards compatibility,
345345
# but migrating to MCPToolChoice is recommended.
346-
return { "type": "mcp" } # type: ignore [typeddict-item]
346+
return {"type": "mcp"} # type: ignore [typeddict-item]
347347
else:
348348
return {
349349
"type": "function",

tests/realtime/test_session.py

Lines changed: 45 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -974,27 +974,30 @@ class TestGuardrailFunctionality:
974974
async def _wait_for_guardrail_tasks(self, session):
975975
"""Wait for all pending guardrail tasks to complete."""
976976
import asyncio
977+
977978
if session._guardrail_tasks:
978979
await asyncio.gather(*session._guardrail_tasks, return_exceptions=True)
979980

980981
@pytest.fixture
981982
def triggered_guardrail(self):
982983
"""Creates a guardrail that always triggers"""
984+
983985
def guardrail_func(context, agent, output):
984986
return GuardrailFunctionOutput(
985-
output_info={"reason": "test trigger"},
986-
tripwire_triggered=True
987+
output_info={"reason": "test trigger"}, tripwire_triggered=True
987988
)
989+
988990
return OutputGuardrail(guardrail_function=guardrail_func, name="triggered_guardrail")
989991

990992
@pytest.fixture
991993
def safe_guardrail(self):
992994
"""Creates a guardrail that never triggers"""
995+
993996
def guardrail_func(context, agent, output):
994997
return GuardrailFunctionOutput(
995-
output_info={"reason": "safe content"},
996-
tripwire_triggered=False
998+
output_info={"reason": "safe content"}, tripwire_triggered=False
997999
)
1000+
9981001
return OutputGuardrail(guardrail_function=guardrail_func, name="safe_guardrail")
9991002

10001003
@pytest.mark.asyncio
@@ -1004,7 +1007,7 @@ async def test_transcript_delta_triggers_guardrail_at_threshold(
10041007
"""Test that guardrails run when transcript delta reaches debounce threshold"""
10051008
run_config: RealtimeRunConfig = {
10061009
"output_guardrails": [triggered_guardrail],
1007-
"guardrails_settings": {"debounce_text_length": 10}
1010+
"guardrails_settings": {"debounce_text_length": 10},
10081011
}
10091012

10101013
session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config)
@@ -1041,20 +1044,20 @@ async def test_transcript_delta_multiple_thresholds_same_item(
10411044
"""Test guardrails run at 1x, 2x, 3x thresholds for same item_id"""
10421045
run_config: RealtimeRunConfig = {
10431046
"output_guardrails": [triggered_guardrail],
1044-
"guardrails_settings": {"debounce_text_length": 5}
1047+
"guardrails_settings": {"debounce_text_length": 5},
10451048
}
10461049

10471050
session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config)
10481051

10491052
# First delta - reaches 1x threshold (5 chars)
1050-
await session.on_event(RealtimeModelTranscriptDeltaEvent(
1051-
item_id="item_1", delta="12345", response_id="resp_1"
1052-
))
1053+
await session.on_event(
1054+
RealtimeModelTranscriptDeltaEvent(item_id="item_1", delta="12345", response_id="resp_1")
1055+
)
10531056

10541057
# Second delta - reaches 2x threshold (10 chars total)
1055-
await session.on_event(RealtimeModelTranscriptDeltaEvent(
1056-
item_id="item_1", delta="67890", response_id="resp_1"
1057-
))
1058+
await session.on_event(
1059+
RealtimeModelTranscriptDeltaEvent(item_id="item_1", delta="67890", response_id="resp_1")
1060+
)
10581061

10591062
# Wait for async guardrail tasks to complete
10601063
await self._wait_for_guardrail_tasks(session)
@@ -1070,28 +1073,32 @@ async def test_transcript_delta_different_items_tracked_separately(
10701073
"""Test that different item_ids are tracked separately for debouncing"""
10711074
run_config: RealtimeRunConfig = {
10721075
"output_guardrails": [safe_guardrail],
1073-
"guardrails_settings": {"debounce_text_length": 10}
1076+
"guardrails_settings": {"debounce_text_length": 10},
10741077
}
10751078

10761079
session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config)
10771080

10781081
# Add text to item_1 (8 chars - below threshold)
1079-
await session.on_event(RealtimeModelTranscriptDeltaEvent(
1080-
item_id="item_1", delta="12345678", response_id="resp_1"
1081-
))
1082+
await session.on_event(
1083+
RealtimeModelTranscriptDeltaEvent(
1084+
item_id="item_1", delta="12345678", response_id="resp_1"
1085+
)
1086+
)
10821087

10831088
# Add text to item_2 (8 chars - below threshold)
1084-
await session.on_event(RealtimeModelTranscriptDeltaEvent(
1085-
item_id="item_2", delta="abcdefgh", response_id="resp_2"
1086-
))
1089+
await session.on_event(
1090+
RealtimeModelTranscriptDeltaEvent(
1091+
item_id="item_2", delta="abcdefgh", response_id="resp_2"
1092+
)
1093+
)
10871094

10881095
# Neither should trigger guardrails yet
10891096
assert mock_model.interrupts_called == 0
10901097

10911098
# Add more text to item_1 (total 12 chars - above threshold)
1092-
await session.on_event(RealtimeModelTranscriptDeltaEvent(
1093-
item_id="item_1", delta="90ab", response_id="resp_1"
1094-
))
1099+
await session.on_event(
1100+
RealtimeModelTranscriptDeltaEvent(item_id="item_1", delta="90ab", response_id="resp_1")
1101+
)
10951102

10961103
# item_1 should have triggered guardrail run (but not interrupted since safe)
10971104
assert session._item_guardrail_run_counts["item_1"] == 1
@@ -1107,15 +1114,17 @@ async def test_turn_ended_clears_guardrail_state(
11071114
"""Test that turn_ended event clears guardrail state for next turn"""
11081115
run_config: RealtimeRunConfig = {
11091116
"output_guardrails": [triggered_guardrail],
1110-
"guardrails_settings": {"debounce_text_length": 5}
1117+
"guardrails_settings": {"debounce_text_length": 5},
11111118
}
11121119

11131120
session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config)
11141121

11151122
# Trigger guardrail
1116-
await session.on_event(RealtimeModelTranscriptDeltaEvent(
1117-
item_id="item_1", delta="trigger", response_id="resp_1"
1118-
))
1123+
await session.on_event(
1124+
RealtimeModelTranscriptDeltaEvent(
1125+
item_id="item_1", delta="trigger", response_id="resp_1"
1126+
)
1127+
)
11191128

11201129
# Wait for async guardrail tasks to complete
11211130
await self._wait_for_guardrail_tasks(session)
@@ -1132,31 +1141,30 @@ async def test_turn_ended_clears_guardrail_state(
11321141
assert len(session._item_guardrail_run_counts) == 0
11331142

11341143
@pytest.mark.asyncio
1135-
async def test_multiple_guardrails_all_triggered(
1136-
self, mock_model, mock_agent
1137-
):
1144+
async def test_multiple_guardrails_all_triggered(self, mock_model, mock_agent):
11381145
"""Test that all triggered guardrails are included in the event"""
1146+
11391147
def create_triggered_guardrail(name):
11401148
def guardrail_func(context, agent, output):
1141-
return GuardrailFunctionOutput(
1142-
output_info={"name": name},
1143-
tripwire_triggered=True
1144-
)
1149+
return GuardrailFunctionOutput(output_info={"name": name}, tripwire_triggered=True)
1150+
11451151
return OutputGuardrail(guardrail_function=guardrail_func, name=name)
11461152

11471153
guardrail1 = create_triggered_guardrail("guardrail_1")
11481154
guardrail2 = create_triggered_guardrail("guardrail_2")
11491155

11501156
run_config: RealtimeRunConfig = {
11511157
"output_guardrails": [guardrail1, guardrail2],
1152-
"guardrails_settings": {"debounce_text_length": 5}
1158+
"guardrails_settings": {"debounce_text_length": 5},
11531159
}
11541160

11551161
session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config)
11561162

1157-
await session.on_event(RealtimeModelTranscriptDeltaEvent(
1158-
item_id="item_1", delta="trigger", response_id="resp_1"
1159-
))
1163+
await session.on_event(
1164+
RealtimeModelTranscriptDeltaEvent(
1165+
item_id="item_1", delta="trigger", response_id="resp_1"
1166+
)
1167+
)
11601168

11611169
# Wait for async guardrail tasks to complete
11621170
await self._wait_for_guardrail_tasks(session)

tests/realtime/test_tracing.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -222,36 +222,30 @@ async def test_tracing_disabled_prevents_tracing(self, mock_websocket):
222222
# Create a test agent and runner with tracing disabled
223223
agent = RealtimeAgent(name="test_agent", instructions="test")
224224

225-
runner = RealtimeRunner(
226-
starting_agent=agent,
227-
config={"tracing_disabled": True}
228-
)
225+
runner = RealtimeRunner(starting_agent=agent, config={"tracing_disabled": True})
229226

230227
# Test the _get_model_settings method directly since that's where the logic is
231228
model_settings = await runner._get_model_settings(
232229
agent=agent,
233230
disable_tracing=True, # This should come from config["tracing_disabled"]
234231
initial_settings=None,
235-
overrides=None
232+
overrides=None,
236233
)
237234

238235
# When tracing is disabled, model settings should have tracing=None
239236
assert model_settings["tracing"] is None
240237

241238
# Also test that the runner passes disable_tracing=True correctly
242-
with patch.object(runner, '_get_model_settings') as mock_get_settings:
239+
with patch.object(runner, "_get_model_settings") as mock_get_settings:
243240
mock_get_settings.return_value = {"tracing": None}
244241

245-
with patch('agents.realtime.session.RealtimeSession') as mock_session_class:
242+
with patch("agents.realtime.session.RealtimeSession") as mock_session_class:
246243
mock_session = AsyncMock()
247244
mock_session_class.return_value = mock_session
248245

249246
await runner.run()
250247

251248
# Verify that _get_model_settings was called with disable_tracing=True
252249
mock_get_settings.assert_called_once_with(
253-
agent=agent,
254-
disable_tracing=True,
255-
initial_settings=None,
256-
overrides=None
250+
agent=agent, disable_tracing=True, initial_settings=None, overrides=None
257251
)

0 commit comments

Comments
 (0)