@@ -1386,6 +1386,126 @@ def create_task_spy(coro):
13861386 assert task_id not in request_handler ._running_agents
13871387
13881388
1389+ async def wait_until (predicate , timeout : float = 0.2 , interval : float = 0.0 ):
1390+ """Await until predicate() is True or timeout elapses."""
1391+ loop = asyncio .get_running_loop ()
1392+ end = loop .time () + timeout
1393+ while True :
1394+ if predicate ():
1395+ return
1396+ if loop .time () >= end :
1397+ raise AssertionError ('condition not met within timeout' )
1398+ await asyncio .sleep (interval )
1399+
1400+
1401+ @pytest .mark .asyncio
1402+ async def test_background_cleanup_task_is_tracked_and_cleared ():
1403+ """Ensure background cleanup task is tracked while pending and removed when done."""
1404+ # Arrange
1405+ mock_task_store = AsyncMock (spec = TaskStore )
1406+ mock_queue_manager = AsyncMock (spec = QueueManager )
1407+ mock_agent_executor = AsyncMock (spec = AgentExecutor )
1408+ mock_request_context_builder = AsyncMock (spec = RequestContextBuilder )
1409+
1410+ task_id = 'track_task_1'
1411+ context_id = 'track_ctx_1'
1412+
1413+ # RequestContext with IDs
1414+ mock_request_context = MagicMock (spec = RequestContext )
1415+ mock_request_context .task_id = task_id
1416+ mock_request_context .context_id = context_id
1417+ mock_request_context_builder .build .return_value = mock_request_context
1418+
1419+ mock_queue = AsyncMock (spec = EventQueue )
1420+ mock_queue_manager .create_or_tap .return_value = mock_queue
1421+
1422+ request_handler = DefaultRequestHandler (
1423+ agent_executor = mock_agent_executor ,
1424+ task_store = mock_task_store ,
1425+ queue_manager = mock_queue_manager ,
1426+ request_context_builder = mock_request_context_builder ,
1427+ )
1428+
1429+ params = MessageSendParams (
1430+ message = Message (
1431+ role = Role .user ,
1432+ message_id = 'mid_track' ,
1433+ parts = [],
1434+ task_id = task_id ,
1435+ context_id = context_id ,
1436+ )
1437+ )
1438+
1439+ # Agent executor runs in background until we allow it to finish
1440+ execute_started = asyncio .Event ()
1441+ execute_finish = asyncio .Event ()
1442+
1443+ async def exec_side_effect (* _args , ** _kwargs ):
1444+ execute_started .set ()
1445+ await execute_finish .wait ()
1446+
1447+ mock_agent_executor .execute .side_effect = exec_side_effect
1448+
1449+ # ResultAggregator emits one Task event (so the stream yields once)
1450+ first_event = create_sample_task (task_id = task_id , context_id = context_id )
1451+
1452+ async def single_event_stream ():
1453+ yield first_event
1454+
1455+ mock_result_aggregator_instance = MagicMock (spec = ResultAggregator )
1456+ mock_result_aggregator_instance .consume_and_emit .return_value = (
1457+ single_event_stream ()
1458+ )
1459+
1460+ produced_task : asyncio .Task | None = None
1461+ cleanup_task : asyncio .Task | None = None
1462+
1463+ orig_create_task = asyncio .create_task
1464+
1465+ def create_task_spy (coro ):
1466+ nonlocal produced_task , cleanup_task
1467+ task = orig_create_task (coro )
1468+ if coro .__name__ == '_run_event_stream' :
1469+ produced_task = task
1470+ elif coro .__name__ == '_cleanup_producer' :
1471+ cleanup_task = task
1472+ return task
1473+
1474+ with (
1475+ patch (
1476+ 'a2a.server.request_handlers.default_request_handler.ResultAggregator' ,
1477+ return_value = mock_result_aggregator_instance ,
1478+ ),
1479+ patch ('asyncio.create_task' , side_effect = create_task_spy ),
1480+ ):
1481+ # Act: start stream and consume only the first event, then disconnect
1482+ agen = request_handler .on_message_send_stream (
1483+ params , create_server_call_context ()
1484+ )
1485+ first = await agen .__anext__ ()
1486+ assert first == first_event
1487+ # Simulate client disconnect
1488+ await asyncio .wait_for (agen .aclose (), timeout = 0.1 )
1489+
1490+ assert produced_task is not None
1491+ assert cleanup_task is not None
1492+
1493+ # Background cleanup task should be tracked while producer is still running
1494+ await asyncio .wait_for (execute_started .wait (), timeout = 0.1 )
1495+ assert cleanup_task in request_handler ._background_tasks
1496+
1497+ # Allow executor to finish; this should complete producer, then cleanup
1498+ execute_finish .set ()
1499+ await asyncio .wait_for (produced_task , timeout = 0.1 )
1500+ await asyncio .wait_for (cleanup_task , timeout = 0.1 )
1501+
1502+ # Wait for callback to remove task from tracking
1503+ await wait_until (
1504+ lambda : cleanup_task not in request_handler ._background_tasks ,
1505+ timeout = 0.1 ,
1506+ )
1507+
1508+
13891509@pytest .mark .asyncio
13901510async def test_on_message_send_stream_task_id_mismatch ():
13911511 """Test on_message_send_stream raises error if yielded task ID mismatches."""
0 commit comments