@@ -38,24 +38,26 @@ def get_len(data: HandoffInputData) -> int:
3838 return input_len + pre_handoff_len + new_items_len
3939
4040
41- def test_single_handoff_setup ():
41+ @pytest .mark .asyncio
42+ async def test_single_handoff_setup ():
4243 agent_1 = Agent (name = "test_1" )
4344 agent_2 = Agent (name = "test_2" , handoffs = [agent_1 ])
4445
4546 assert not agent_1 .handoffs
4647 assert agent_2 .handoffs == [agent_1 ]
4748
48- assert not AgentRunner ._get_handoffs (agent_1 )
49+ assert not ( await AgentRunner ._get_handoffs (agent_1 , RunContextWrapper ( agent_1 )) )
4950
50- handoff_objects = AgentRunner ._get_handoffs (agent_2 )
51+ handoff_objects = await AgentRunner ._get_handoffs (agent_2 , RunContextWrapper ( agent_2 ) )
5152 assert len (handoff_objects ) == 1
5253 obj = handoff_objects [0 ]
5354 assert obj .tool_name == Handoff .default_tool_name (agent_1 )
5455 assert obj .tool_description == Handoff .default_tool_description (agent_1 )
5556 assert obj .agent_name == agent_1 .name
5657
5758
58- def test_multiple_handoffs_setup ():
59+ @pytest .mark .asyncio
60+ async def test_multiple_handoffs_setup ():
5961 agent_1 = Agent (name = "test_1" )
6062 agent_2 = Agent (name = "test_2" )
6163 agent_3 = Agent (name = "test_3" , handoffs = [agent_1 , agent_2 ])
@@ -64,7 +66,7 @@ def test_multiple_handoffs_setup():
6466 assert not agent_1 .handoffs
6567 assert not agent_2 .handoffs
6668
67- handoff_objects = AgentRunner ._get_handoffs (agent_3 )
69+ handoff_objects = await AgentRunner ._get_handoffs (agent_3 , RunContextWrapper ( agent_3 ) )
6870 assert len (handoff_objects ) == 2
6971 assert handoff_objects [0 ].tool_name == Handoff .default_tool_name (agent_1 )
7072 assert handoff_objects [1 ].tool_name == Handoff .default_tool_name (agent_2 )
@@ -76,7 +78,8 @@ def test_multiple_handoffs_setup():
7678 assert handoff_objects [1 ].agent_name == agent_2 .name
7779
7880
79- def test_custom_handoff_setup ():
81+ @pytest .mark .asyncio
82+ async def test_custom_handoff_setup ():
8083 agent_1 = Agent (name = "test_1" )
8184 agent_2 = Agent (name = "test_2" )
8285 agent_3 = Agent (
@@ -95,7 +98,7 @@ def test_custom_handoff_setup():
9598 assert not agent_1 .handoffs
9699 assert not agent_2 .handoffs
97100
98- handoff_objects = AgentRunner ._get_handoffs (agent_3 )
101+ handoff_objects = await AgentRunner ._get_handoffs (agent_3 , RunContextWrapper ( agent_3 ) )
99102 assert len (handoff_objects ) == 2
100103
101104 first_handoff = handoff_objects [0 ]
@@ -284,3 +287,86 @@ def test_get_transfer_message_is_valid_json() -> None:
284287 obj = handoff (agent )
285288 transfer = obj .get_transfer_message (agent )
286289 assert json .loads (transfer ) == {"assistant" : agent .name }
290+
291+
292+ def test_handoff_is_enabled_bool ():
293+ """Test that handoff respects is_enabled boolean parameter."""
294+ agent = Agent (name = "test" )
295+
296+ # Test enabled handoff (default)
297+ handoff_enabled = handoff (agent )
298+ assert handoff_enabled .is_enabled is True
299+
300+ # Test explicitly enabled handoff
301+ handoff_explicit_enabled = handoff (agent , is_enabled = True )
302+ assert handoff_explicit_enabled .is_enabled is True
303+
304+ # Test disabled handoff
305+ handoff_disabled = handoff (agent , is_enabled = False )
306+ assert handoff_disabled .is_enabled is False
307+
308+
309+ @pytest .mark .asyncio
310+ async def test_handoff_is_enabled_callable ():
311+ """Test that handoff respects is_enabled callable parameter."""
312+ agent = Agent (name = "test" )
313+
314+ # Test callable that returns True
315+ def always_enabled (ctx : RunContextWrapper [Any ], agent : Agent [Any ]) -> bool :
316+ return True
317+
318+ handoff_callable_enabled = handoff (agent , is_enabled = always_enabled )
319+ assert callable (handoff_callable_enabled .is_enabled )
320+ result = handoff_callable_enabled .is_enabled (RunContextWrapper (agent ), agent )
321+ assert result is True
322+
323+ # Test callable that returns False
324+ def always_disabled (ctx : RunContextWrapper [Any ], agent : Agent [Any ]) -> bool :
325+ return False
326+
327+ handoff_callable_disabled = handoff (agent , is_enabled = always_disabled )
328+ assert callable (handoff_callable_disabled .is_enabled )
329+ result = handoff_callable_disabled .is_enabled (RunContextWrapper (agent ), agent )
330+ assert result is False
331+
332+ # Test async callable
333+ async def async_enabled (ctx : RunContextWrapper [Any ], agent : Agent [Any ]) -> bool :
334+ return True
335+
336+ handoff_async_enabled = handoff (agent , is_enabled = async_enabled )
337+ assert callable (handoff_async_enabled .is_enabled )
338+ result = await handoff_async_enabled .is_enabled (RunContextWrapper (agent ), agent ) # type: ignore
339+ assert result is True
340+
341+
342+ @pytest .mark .asyncio
343+ async def test_handoff_is_enabled_filtering_integration ():
344+ """Integration test that disabled handoffs are filtered out by the runner."""
345+
346+ # Set up agents
347+ agent_1 = Agent (name = "agent_1" )
348+ agent_2 = Agent (name = "agent_2" )
349+ agent_3 = Agent (name = "agent_3" )
350+
351+ # Create main agent with mixed enabled/disabled handoffs
352+ main_agent = Agent (
353+ name = "main_agent" ,
354+ handoffs = [
355+ handoff (agent_1 , is_enabled = True ), # enabled
356+ handoff (agent_2 , is_enabled = False ), # disabled
357+ handoff (agent_3 , is_enabled = lambda ctx , agent : True ), # enabled callable
358+ ],
359+ )
360+
361+ context_wrapper = RunContextWrapper (main_agent )
362+
363+ # Get filtered handoffs using the runner's method
364+ filtered_handoffs = await AgentRunner ._get_handoffs (main_agent , context_wrapper )
365+
366+ # Should only have 2 handoffs (agent_1 and agent_3), agent_2 should be filtered out
367+ assert len (filtered_handoffs ) == 2
368+
369+ # Check that the correct agents are present
370+ agent_names = {h .agent_name for h in filtered_handoffs }
371+ assert agent_names == {"agent_1" , "agent_3" }
372+ assert "agent_2" not in agent_names
0 commit comments