Skip to content

Commit 5534151

Browse files
committed
Fix test to use real asyncio tasks instead of mocks
Changed test_exception_during_guardrail_processing to use real asyncio.Task objects instead of Mock objects. This allows the production code to remain simple and clean, without needing isinstance checks to handle test-specific mocks. The test now creates actual tasks using asyncio.create_task() which better reflects real-world usage and naturally works with the cleanup logic that uses asyncio.gather().
1 parent d64ecfe commit 5534151

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

tests/realtime/test_guardrail_cleanup.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,7 @@ async def failing_guardrail_func(context, agent, output):
165165
# Simulate an error during cleanup
166166
raise RuntimeError("Cleanup error") from e
167167

168-
guardrail = OutputGuardrail(
169-
guardrail_function=failing_guardrail_func, name="failing_guardrail"
170-
)
168+
guardrail = OutputGuardrail(guardrail_function=failing_guardrail_func, name="failing_guardrail")
171169

172170
run_config: RealtimeRunConfig = {
173171
"output_guardrails": [guardrail],

tests/test_session_exceptions.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -249,16 +249,21 @@ async def test_exception_during_guardrail_processing(
249249

250250
session = RealtimeSession(fake_model, fake_agent, None)
251251

252-
# Add some fake guardrail tasks
253-
fake_task1 = Mock()
254-
fake_task1.done.return_value = False
255-
fake_task1.cancel = Mock()
252+
# Create real async tasks for testing cleanup
253+
async def long_running_task():
254+
await asyncio.sleep(10)
256255

257-
fake_task2 = Mock()
258-
fake_task2.done.return_value = True
259-
fake_task2.cancel = Mock()
256+
async def completed_task():
257+
pass
260258

261-
session._guardrail_tasks = {fake_task1, fake_task2}
259+
# Create tasks
260+
task1 = asyncio.create_task(long_running_task())
261+
task2 = asyncio.create_task(completed_task())
262+
263+
# Wait for task2 to complete
264+
await task2
265+
266+
session._guardrail_tasks = {task1, task2}
262267

263268
fake_model.set_next_events([exception_event])
264269

@@ -268,8 +273,8 @@ async def test_exception_during_guardrail_processing(
268273
pass
269274

270275
# Verify guardrail tasks were properly cleaned up
271-
fake_task1.cancel.assert_called_once()
272-
fake_task2.cancel.assert_not_called() # Already done
276+
assert task1.cancelled() # Should be cancelled
277+
assert task2.done() # Was already done
273278
assert len(session._guardrail_tasks) == 0
274279

275280
@pytest.mark.asyncio

0 commit comments

Comments
 (0)