Skip to content

Commit e716689

Browse files
committed
test: add test for soft cancel after handoff to verify session state
1 parent 4aaa0bb commit e716689

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed

tests/test_soft_cancel.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,65 @@ async def test_soft_cancel_does_not_clear_queues_immediately():
377377
assert len(events) >= 0 # Events may or may not be present depending on timing
378378

379379

380+
@pytest.mark.asyncio
381+
async def test_soft_cancel_with_handoff():
382+
"""Verify soft cancel after handoff saves the handoff turn."""
383+
from agents import Handoff
384+
385+
model = FakeModel()
386+
387+
# Create two agents with handoff
388+
agent2 = Agent(name="Agent2", model=model)
389+
390+
async def on_invoke_handoff(context, data):
391+
return agent2
392+
393+
agent1 = Agent(
394+
name="Agent1",
395+
model=model,
396+
handoffs=[
397+
Handoff(
398+
tool_name=Handoff.default_tool_name(agent2),
399+
tool_description=Handoff.default_tool_description(agent2),
400+
input_json_schema={},
401+
on_invoke_handoff=on_invoke_handoff,
402+
agent_name=agent2.name,
403+
)
404+
],
405+
)
406+
407+
# Setup: Agent1 does handoff, Agent2 responds
408+
model.add_multiple_turn_outputs(
409+
[
410+
# Agent1's turn - triggers handoff
411+
[get_function_tool_call(Handoff.default_tool_name(agent2), "{}")],
412+
# Agent2's turn after handoff
413+
[get_text_message("Agent2 response")],
414+
]
415+
)
416+
417+
session = SQLiteSession("test_soft_cancel_handoff")
418+
await session.clear_session()
419+
420+
result = Runner.run_streamed(agent1, input="Hello", session=session)
421+
422+
handoff_seen = False
423+
async for event in result.stream_events():
424+
if event.type == "run_item_stream_event" and event.name == "handoff_occured":
425+
handoff_seen = True
426+
# Cancel right after handoff
427+
result.cancel(mode="after_turn")
428+
429+
assert handoff_seen, "Handoff should have occurred"
430+
431+
# Verify session has items from the handoff turn
432+
items = await session.get_items()
433+
assert len(items) > 0, "Session should have saved the handoff turn"
434+
435+
# Cleanup
436+
await session.clear_session()
437+
438+
380439
@pytest.mark.asyncio
381440
async def test_soft_cancel_with_session_and_multiple_turns():
382441
"""Verify soft cancel with session across multiple turns."""

0 commit comments

Comments
 (0)