Skip to content

Commit a25dac7

Browse files
authored
Realtime: only update model settings from session (#1169)
### Summary: Was running into bugs. Because the model settings were being set from both runner and session, and that was causing issues. Among other things, handoffs were broken because the runner wasn't reading them, and the session wasn't setting them in the connect() call. ### Test Plan: Unit tests.
1 parent 2f8ea0a commit a25dac7

File tree

6 files changed

+244
-138
lines changed

6 files changed

+244
-138
lines changed

src/agents/realtime/runner.py

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,10 @@
22

33
from __future__ import annotations
44

5-
import asyncio
6-
7-
from ..run_context import RunContextWrapper, TContext
5+
from ..run_context import TContext
86
from .agent import RealtimeAgent
97
from .config import (
108
RealtimeRunConfig,
11-
RealtimeSessionModelSettings,
129
)
1310
from .model import (
1411
RealtimeModel,
@@ -67,16 +64,6 @@ async def run(
6764
print(event)
6865
```
6966
"""
70-
model_settings = await self._get_model_settings(
71-
agent=self._starting_agent,
72-
disable_tracing=self._config.get("tracing_disabled", False) if self._config else False,
73-
initial_settings=model_config.get("initial_model_settings") if model_config else None,
74-
overrides=self._config.get("model_settings") if self._config else None,
75-
)
76-
77-
model_config = model_config.copy() if model_config else {}
78-
model_config["initial_model_settings"] = model_settings
79-
8067
# Create and return the connection
8168
session = RealtimeSession(
8269
model=self._model,
@@ -87,32 +74,3 @@ async def run(
8774
)
8875

8976
return session
90-
91-
async def _get_model_settings(
92-
self,
93-
agent: RealtimeAgent,
94-
disable_tracing: bool,
95-
context: TContext | None = None,
96-
initial_settings: RealtimeSessionModelSettings | None = None,
97-
overrides: RealtimeSessionModelSettings | None = None,
98-
) -> RealtimeSessionModelSettings:
99-
context_wrapper = RunContextWrapper(context)
100-
model_settings = initial_settings.copy() if initial_settings else {}
101-
102-
instructions, tools = await asyncio.gather(
103-
agent.get_system_prompt(context_wrapper),
104-
agent.get_all_tools(context_wrapper),
105-
)
106-
107-
if instructions is not None:
108-
model_settings["instructions"] = instructions
109-
if tools is not None:
110-
model_settings["tools"] = tools
111-
112-
if overrides:
113-
model_settings.update(overrides)
114-
115-
if disable_tracing:
116-
model_settings["tracing"] = None
117-
118-
return model_settings

src/agents/realtime/session.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,13 @@ async def __aenter__(self) -> RealtimeSession:
114114
# Add ourselves as a listener
115115
self._model.add_listener(self)
116116

117+
model_config = self._model_config.copy()
118+
model_config["initial_model_settings"] = await self._get_updated_model_settings_from_agent(
119+
self._current_agent
120+
)
121+
117122
# Connect to the model
118-
await self._model.connect(self._model_config)
123+
await self._model.connect(model_config)
119124

120125
# Emit initial history update
121126
await self._put_event(
@@ -319,7 +324,9 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
319324
self._current_agent = result
320325

321326
# Get updated model settings from new agent
322-
updated_settings = await self._get__updated_model_settings(self._current_agent)
327+
updated_settings = await self._get_updated_model_settings_from_agent(
328+
self._current_agent
329+
)
323330

324331
# Send handoff event
325332
await self._put_event(
@@ -495,19 +502,28 @@ async def _cleanup(self) -> None:
495502
# Mark as closed
496503
self._closed = True
497504

498-
async def _get__updated_model_settings(
499-
self, new_agent: RealtimeAgent
505+
async def _get_updated_model_settings_from_agent(
506+
self,
507+
agent: RealtimeAgent,
500508
) -> RealtimeSessionModelSettings:
501509
updated_settings: RealtimeSessionModelSettings = {}
502510
instructions, tools, handoffs = await asyncio.gather(
503-
new_agent.get_system_prompt(self._context_wrapper),
504-
new_agent.get_all_tools(self._context_wrapper),
505-
self._get_handoffs(new_agent, self._context_wrapper),
511+
agent.get_system_prompt(self._context_wrapper),
512+
agent.get_all_tools(self._context_wrapper),
513+
self._get_handoffs(agent, self._context_wrapper),
506514
)
507515
updated_settings["instructions"] = instructions or ""
508516
updated_settings["tools"] = tools or []
509517
updated_settings["handoffs"] = handoffs or []
510518

519+
# Override with initial settings
520+
initial_settings = self._model_config.get("initial_model_settings", {})
521+
updated_settings.update(initial_settings)
522+
523+
disable_tracing = self._run_config.get("tracing_disabled", False)
524+
if disable_tracing:
525+
updated_settings["tracing"] = None
526+
511527
return updated_settings
512528

513529
@classmethod

tests/realtime/test_runner.py

Lines changed: 87 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
from unittest.mock import AsyncMock, Mock, patch
22

33
import pytest
4-
from inline_snapshot import snapshot
54

65
from agents.realtime.agent import RealtimeAgent
76
from agents.realtime.config import RealtimeRunConfig, RealtimeSessionModelSettings
87
from agents.realtime.model import RealtimeModel, RealtimeModelConfig
98
from agents.realtime.runner import RealtimeRunner
109
from agents.realtime.session import RealtimeSession
10+
from agents.tool import function_tool
1111

1212

1313
class MockRealtimeModel(RealtimeModel):
14+
def __init__(self):
15+
self.connect_args = None
16+
1417
async def connect(self, options=None):
15-
pass
18+
self.connect_args = options
1619

1720
def add_listener(self, listener):
1821
pass
@@ -53,7 +56,9 @@ def mock_model():
5356

5457

5558
@pytest.mark.asyncio
56-
async def test_run_creates_session_with_no_settings(mock_agent, mock_model):
59+
async def test_run_creates_session_with_no_settings(
60+
mock_agent: Mock, mock_model: MockRealtimeModel
61+
):
5762
"""Test that run() creates a session correctly if no settings are provided"""
5863
runner = RealtimeRunner(mock_agent, model=mock_model)
5964

@@ -71,22 +76,17 @@ async def test_run_creates_session_with_no_settings(mock_agent, mock_model):
7176
assert call_args[1]["agent"] == mock_agent
7277
assert call_args[1]["context"] is None
7378

74-
# Verify model_config contains expected settings from agent
79+
# With no settings provided, model_config should be None
7580
model_config = call_args[1]["model_config"]
76-
assert model_config == snapshot(
77-
{
78-
"initial_model_settings": {
79-
"instructions": "Test instructions",
80-
"tools": [{"type": "function", "name": "test_tool"}],
81-
}
82-
}
83-
)
81+
assert model_config is None
8482

8583
assert session == mock_session
8684

8785

8886
@pytest.mark.asyncio
89-
async def test_run_creates_session_with_settings_only_in_init(mock_agent, mock_model):
87+
async def test_run_creates_session_with_settings_only_in_init(
88+
mock_agent: Mock, mock_model: MockRealtimeModel
89+
):
9090
"""Test that it creates a session with the right settings if they are provided only in init"""
9191
config = RealtimeRunConfig(
9292
model_settings=RealtimeSessionModelSettings(model_name="gpt-4o-realtime", voice="nova")
@@ -99,28 +99,19 @@ async def test_run_creates_session_with_settings_only_in_init(mock_agent, mock_m
9999

100100
_ = await runner.run()
101101

102-
# Verify session was created with config overrides
102+
# Verify session was created - runner no longer processes settings
103103
call_args = mock_session_class.call_args
104104
model_config = call_args[1]["model_config"]
105105

106-
# Should have agent settings plus config overrides
107-
assert model_config == snapshot(
108-
{
109-
"initial_model_settings": {
110-
"instructions": "Test instructions",
111-
"tools": [{"type": "function", "name": "test_tool"}],
112-
"model_name": "gpt-4o-realtime",
113-
"voice": "nova",
114-
}
115-
}
116-
)
106+
# Runner should pass None for model_config when none provided to run()
107+
assert model_config is None
117108

118109

119110
@pytest.mark.asyncio
120111
async def test_run_creates_session_with_settings_in_both_init_and_run_overrides(
121-
mock_agent, mock_model
112+
mock_agent: Mock, mock_model: MockRealtimeModel
122113
):
123-
"""Test settings in both init and run() - init should override run()"""
114+
"""Test settings provided in run() parameter are passed through"""
124115
init_config = RealtimeRunConfig(
125116
model_settings=RealtimeSessionModelSettings(model_name="gpt-4o-realtime", voice="nova")
126117
)
@@ -138,26 +129,18 @@ async def test_run_creates_session_with_settings_in_both_init_and_run_overrides(
138129

139130
_ = await runner.run(model_config=run_model_config)
140131

141-
# Verify run() settings override init settings
132+
# Verify run() model_config is passed through as-is
142133
call_args = mock_session_class.call_args
143134
model_config = call_args[1]["model_config"]
144135

145-
# Should have agent settings, then init config, then run config overrides
146-
assert model_config == snapshot(
147-
{
148-
"initial_model_settings": {
149-
"voice": "nova",
150-
"input_audio_format": "pcm16",
151-
"instructions": "Test instructions",
152-
"tools": [{"type": "function", "name": "test_tool"}],
153-
"model_name": "gpt-4o-realtime",
154-
}
155-
}
156-
)
136+
# Runner should pass the model_config from run() parameter directly
137+
assert model_config == run_model_config
157138

158139

159140
@pytest.mark.asyncio
160-
async def test_run_creates_session_with_settings_only_in_run(mock_agent, mock_model):
141+
async def test_run_creates_session_with_settings_only_in_run(
142+
mock_agent: Mock, mock_model: MockRealtimeModel
143+
):
161144
"""Test settings provided only in run()"""
162145
runner = RealtimeRunner(mock_agent, model=mock_model)
163146

@@ -173,26 +156,16 @@ async def test_run_creates_session_with_settings_only_in_run(mock_agent, mock_mo
173156

174157
_ = await runner.run(model_config=run_model_config)
175158

176-
# Verify run() settings are applied
159+
# Verify run() model_config is passed through as-is
177160
call_args = mock_session_class.call_args
178161
model_config = call_args[1]["model_config"]
179162

180-
# Should have agent settings plus run() settings
181-
assert model_config == snapshot(
182-
{
183-
"initial_model_settings": {
184-
"model_name": "gpt-4o-realtime-preview",
185-
"voice": "shimmer",
186-
"modalities": ["text", "audio"],
187-
"instructions": "Test instructions",
188-
"tools": [{"type": "function", "name": "test_tool"}],
189-
}
190-
}
191-
)
163+
# Runner should pass the model_config from run() parameter directly
164+
assert model_config == run_model_config
192165

193166

194167
@pytest.mark.asyncio
195-
async def test_run_with_context_parameter(mock_agent, mock_model):
168+
async def test_run_with_context_parameter(mock_agent: Mock, mock_model: MockRealtimeModel):
196169
"""Test that context parameter is passed through to session"""
197170
runner = RealtimeRunner(mock_agent, model=mock_model)
198171
test_context = {"user_id": "test123"}
@@ -208,17 +181,69 @@ async def test_run_with_context_parameter(mock_agent, mock_model):
208181

209182

210183
@pytest.mark.asyncio
211-
async def test_get_model_settings_with_none_values(mock_model):
212-
"""Test _get_model_settings handles None values from agent properly"""
184+
async def test_run_with_none_values_from_agent_does_not_crash(mock_model: MockRealtimeModel):
185+
"""Test that runner handles agents with None values without crashing"""
213186
agent = Mock(spec=RealtimeAgent)
214187
agent.get_system_prompt = AsyncMock(return_value=None)
215188
agent.get_all_tools = AsyncMock(return_value=None)
216189

217190
runner = RealtimeRunner(agent, model=mock_model)
218191

219-
with patch("agents.realtime.runner.RealtimeSession"):
220-
await runner.run()
192+
with patch("agents.realtime.runner.RealtimeSession") as mock_session_class:
193+
mock_session = Mock(spec=RealtimeSession)
194+
mock_session_class.return_value = mock_session
195+
196+
session = await runner.run()
197+
198+
# Should not crash and return session
199+
assert session == mock_session
200+
# Runner no longer calls agent methods directly - session does that
201+
agent.get_system_prompt.assert_not_called()
202+
agent.get_all_tools.assert_not_called()
203+
204+
205+
@pytest.mark.asyncio
206+
async def test_tool_and_handoffs_are_correct(mock_model: MockRealtimeModel):
207+
@function_tool
208+
def tool_one():
209+
return "result_one"
210+
211+
agent_1 = RealtimeAgent(
212+
name="one",
213+
instructions="instr_one",
214+
)
215+
agent_2 = RealtimeAgent(
216+
name="two",
217+
instructions="instr_two",
218+
tools=[tool_one],
219+
handoffs=[agent_1],
220+
)
221+
222+
session = RealtimeSession(
223+
model=mock_model,
224+
agent=agent_2,
225+
context=None,
226+
model_config=None,
227+
run_config=None,
228+
)
229+
230+
async with session:
231+
pass
221232

222-
# Should not crash and agent methods should be called
223-
agent.get_system_prompt.assert_called_once()
224-
agent.get_all_tools.assert_called_once()
233+
# Assert that the model.connect() was called with the correct settings
234+
connect_args = mock_model.connect_args
235+
assert connect_args is not None
236+
assert isinstance(connect_args, dict)
237+
initial_model_settings = connect_args["initial_model_settings"]
238+
assert initial_model_settings is not None
239+
assert isinstance(initial_model_settings, dict)
240+
assert initial_model_settings["instructions"] == "instr_two"
241+
assert len(initial_model_settings["tools"]) == 1
242+
tool = initial_model_settings["tools"][0]
243+
assert tool.name == "tool_one"
244+
245+
handoffs = initial_model_settings["handoffs"]
246+
assert len(handoffs) == 1
247+
handoff = handoffs[0]
248+
assert handoff.tool_name == "transfer_to_one"
249+
assert handoff.agent_name == "one"

0 commit comments

Comments
 (0)