@@ -1386,6 +1386,111 @@ def create_task_spy(coro):
13861386 assert task_id not in request_handler ._running_agents
13871387
13881388
1389+ @pytest .mark .asyncio
1390+ async def test_background_cleanup_task_is_tracked_and_cleared ():
1391+ """Ensure background cleanup task is tracked while pending and removed when done."""
1392+ # Arrange
1393+ mock_task_store = AsyncMock (spec = TaskStore )
1394+ mock_queue_manager = AsyncMock (spec = QueueManager )
1395+ mock_agent_executor = AsyncMock (spec = AgentExecutor )
1396+ mock_request_context_builder = AsyncMock (spec = RequestContextBuilder )
1397+
1398+ task_id = 'track_task_1'
1399+ context_id = 'track_ctx_1'
1400+
1401+ # RequestContext with IDs
1402+ mock_request_context = MagicMock (spec = RequestContext )
1403+ mock_request_context .task_id = task_id
1404+ mock_request_context .context_id = context_id
1405+ mock_request_context_builder .build .return_value = mock_request_context
1406+
1407+ mock_queue = AsyncMock (spec = EventQueue )
1408+ mock_queue_manager .create_or_tap .return_value = mock_queue
1409+
1410+ request_handler = DefaultRequestHandler (
1411+ agent_executor = mock_agent_executor ,
1412+ task_store = mock_task_store ,
1413+ queue_manager = mock_queue_manager ,
1414+ request_context_builder = mock_request_context_builder ,
1415+ )
1416+
1417+ params = MessageSendParams (
1418+ message = Message (
1419+ role = Role .user ,
1420+ message_id = 'mid_track' ,
1421+ parts = [],
1422+ task_id = task_id ,
1423+ context_id = context_id ,
1424+ )
1425+ )
1426+
1427+ # Agent executor runs in background until we allow it to finish
1428+ execute_started = asyncio .Event ()
1429+ execute_finish = asyncio .Event ()
1430+
1431+ async def exec_side_effect (* _args , ** _kwargs ):
1432+ execute_started .set ()
1433+ await execute_finish .wait ()
1434+
1435+ mock_agent_executor .execute .side_effect = exec_side_effect
1436+
1437+ # ResultAggregator emits one Task event (so the stream yields once)
1438+ first_event = create_sample_task (task_id = task_id , context_id = context_id )
1439+
1440+ async def single_event_stream ():
1441+ yield first_event
1442+
1443+ mock_result_aggregator_instance = MagicMock (spec = ResultAggregator )
1444+ mock_result_aggregator_instance .consume_and_emit .return_value = (
1445+ single_event_stream ()
1446+ )
1447+
1448+ produced_task : asyncio .Task | None = None
1449+ cleanup_task : asyncio .Task | None = None
1450+
1451+ orig_create_task = asyncio .create_task
1452+
1453+ def create_task_spy (coro ):
1454+ nonlocal produced_task , cleanup_task
1455+ task = orig_create_task (coro )
1456+ if coro .__name__ == '_run_event_stream' :
1457+ produced_task = task
1458+ elif coro .__name__ == '_cleanup_producer' :
1459+ cleanup_task = task
1460+ return task
1461+
1462+ with (
1463+ patch (
1464+ 'a2a.server.request_handlers.default_request_handler.ResultAggregator' ,
1465+ return_value = mock_result_aggregator_instance ,
1466+ ),
1467+ patch ('asyncio.create_task' , side_effect = create_task_spy ),
1468+ ):
1469+ # Act: start stream and consume only the first event, then disconnect
1470+ agen = request_handler .on_message_send_stream (
1471+ params , create_server_call_context ()
1472+ )
1473+ first = await agen .__anext__ ()
1474+ assert first == first_event
1475+ # Simulate client disconnect
1476+ await asyncio .wait_for (agen .aclose (), timeout = 0.1 )
1477+
1478+ assert produced_task is not None
1479+ assert cleanup_task is not None
1480+
1481+ # Background cleanup task should be tracked while producer is still running
1482+ await asyncio .wait_for (execute_started .wait (), timeout = 0.1 )
1483+ assert cleanup_task in request_handler ._background_tasks
1484+
1485+ # Allow executor to finish; this should complete producer, then cleanup
1486+ execute_finish .set ()
1487+ await asyncio .wait_for (produced_task , timeout = 0.2 )
1488+ await asyncio .wait_for (cleanup_task , timeout = 0.2 )
1489+
1490+ # Cleanup task should be removed from tracking after completion
1491+ assert cleanup_task not in request_handler ._background_tasks
1492+
1493+
13891494@pytest .mark .asyncio
13901495async def test_on_message_send_stream_task_id_mismatch ():
13911496 """Test on_message_send_stream raises error if yielded task ID mismatches."""
0 commit comments