Skip to content

Commit 5daecb7

Browse files
add tracking for cleanup background tasks
1 parent b426524 commit 5daecb7

File tree

2 files changed

+138
-4
lines changed

2 files changed

+138
-4
lines changed

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class DefaultRequestHandler(RequestHandler):
6767
"""
6868

6969
_running_agents: dict[str, asyncio.Task]
70+
_background_tasks: set[asyncio.Task]
7071

7172
def __init__( # noqa: PLR0913
7273
self,
@@ -102,6 +103,9 @@ def __init__( # noqa: PLR0913
102103
# TODO: Likely want an interface for managing this, like AgentExecutionManager.
103104
self._running_agents = {}
104105
self._running_agents_lock = asyncio.Lock()
106+
# Tracks background tasks (e.g., deferred cleanups) to avoid orphaning
107+
# asyncio tasks and to surface unexpected exceptions.
108+
self._background_tasks = set()
105109

106110
async def on_get_task(
107111
self,
@@ -355,10 +359,11 @@ async def push_notification_callback() -> None:
355359
raise
356360
finally:
357361
if interrupted_or_non_blocking:
358-
# TODO: Track this disconnected cleanup task.
359-
asyncio.create_task( # noqa: RUF006
362+
cleanup_task = asyncio.create_task(
360363
self._cleanup_producer(producer_task, task_id)
361364
)
365+
cleanup_task.set_name(f'cleanup_producer:{task_id}')
366+
self._track_background_task(cleanup_task)
362367
else:
363368
await self._cleanup_producer(producer_task, task_id)
364369

@@ -394,10 +399,11 @@ async def on_message_send_stream(
394399
)
395400
yield event
396401
finally:
397-
# TODO: Track this disconnected cleanup task.
398-
asyncio.create_task( # noqa: RUF006
402+
cleanup_task = asyncio.create_task(
399403
self._cleanup_producer(producer_task, task_id)
400404
)
405+
cleanup_task.set_name(f'cleanup_producer:{task_id}')
406+
self._track_background_task(cleanup_task)
401407

402408
async def _register_producer(
403409
self, task_id: str, producer_task: asyncio.Task
@@ -406,6 +412,29 @@ async def _register_producer(
406412
async with self._running_agents_lock:
407413
self._running_agents[task_id] = producer_task
408414

415+
def _track_background_task(self, task: asyncio.Task) -> None:
416+
"""Tracks a background task and logs exceptions on completion.
417+
418+
This avoids unreferenced tasks (and associated lint warnings) while
419+
ensuring any exceptions are surfaced in logs.
420+
"""
421+
self._background_tasks.add(task)
422+
423+
def _on_done(completed: asyncio.Task) -> None:
424+
try:
425+
# Retrieve result to raise exceptions, if any
426+
completed.result()
427+
except asyncio.CancelledError:
428+
name = completed.get_name()
429+
logger.debug('Background task %s cancelled', name)
430+
except Exception:
431+
name = completed.get_name()
432+
logger.exception('Background task %s failed', name)
433+
finally:
434+
self._background_tasks.discard(completed)
435+
436+
task.add_done_callback(_on_done)
437+
409438
async def _cleanup_producer(
410439
self,
411440
producer_task: asyncio.Task,

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,6 +1386,111 @@ def create_task_spy(coro):
13861386
assert task_id not in request_handler._running_agents
13871387

13881388

1389+
@pytest.mark.asyncio
1390+
async def test_background_cleanup_task_is_tracked_and_cleared():
1391+
"""Ensure background cleanup task is tracked while pending and removed when done."""
1392+
# Arrange
1393+
mock_task_store = AsyncMock(spec=TaskStore)
1394+
mock_queue_manager = AsyncMock(spec=QueueManager)
1395+
mock_agent_executor = AsyncMock(spec=AgentExecutor)
1396+
mock_request_context_builder = AsyncMock(spec=RequestContextBuilder)
1397+
1398+
task_id = 'track_task_1'
1399+
context_id = 'track_ctx_1'
1400+
1401+
# RequestContext with IDs
1402+
mock_request_context = MagicMock(spec=RequestContext)
1403+
mock_request_context.task_id = task_id
1404+
mock_request_context.context_id = context_id
1405+
mock_request_context_builder.build.return_value = mock_request_context
1406+
1407+
mock_queue = AsyncMock(spec=EventQueue)
1408+
mock_queue_manager.create_or_tap.return_value = mock_queue
1409+
1410+
request_handler = DefaultRequestHandler(
1411+
agent_executor=mock_agent_executor,
1412+
task_store=mock_task_store,
1413+
queue_manager=mock_queue_manager,
1414+
request_context_builder=mock_request_context_builder,
1415+
)
1416+
1417+
params = MessageSendParams(
1418+
message=Message(
1419+
role=Role.user,
1420+
message_id='mid_track',
1421+
parts=[],
1422+
task_id=task_id,
1423+
context_id=context_id,
1424+
)
1425+
)
1426+
1427+
# Agent executor runs in background until we allow it to finish
1428+
execute_started = asyncio.Event()
1429+
execute_finish = asyncio.Event()
1430+
1431+
async def exec_side_effect(*_args, **_kwargs):
1432+
execute_started.set()
1433+
await execute_finish.wait()
1434+
1435+
mock_agent_executor.execute.side_effect = exec_side_effect
1436+
1437+
# ResultAggregator emits one Task event (so the stream yields once)
1438+
first_event = create_sample_task(task_id=task_id, context_id=context_id)
1439+
1440+
async def single_event_stream():
1441+
yield first_event
1442+
1443+
mock_result_aggregator_instance = MagicMock(spec=ResultAggregator)
1444+
mock_result_aggregator_instance.consume_and_emit.return_value = (
1445+
single_event_stream()
1446+
)
1447+
1448+
produced_task: asyncio.Task | None = None
1449+
cleanup_task: asyncio.Task | None = None
1450+
1451+
orig_create_task = asyncio.create_task
1452+
1453+
def create_task_spy(coro):
1454+
nonlocal produced_task, cleanup_task
1455+
task = orig_create_task(coro)
1456+
if coro.__name__ == '_run_event_stream':
1457+
produced_task = task
1458+
elif coro.__name__ == '_cleanup_producer':
1459+
cleanup_task = task
1460+
return task
1461+
1462+
with (
1463+
patch(
1464+
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
1465+
return_value=mock_result_aggregator_instance,
1466+
),
1467+
patch('asyncio.create_task', side_effect=create_task_spy),
1468+
):
1469+
# Act: start stream and consume only the first event, then disconnect
1470+
agen = request_handler.on_message_send_stream(
1471+
params, create_server_call_context()
1472+
)
1473+
first = await agen.__anext__()
1474+
assert first == first_event
1475+
# Simulate client disconnect
1476+
await asyncio.wait_for(agen.aclose(), timeout=0.1)
1477+
1478+
assert produced_task is not None
1479+
assert cleanup_task is not None
1480+
1481+
# Background cleanup task should be tracked while producer is still running
1482+
await asyncio.wait_for(execute_started.wait(), timeout=0.1)
1483+
assert cleanup_task in request_handler._background_tasks
1484+
1485+
# Allow executor to finish; this should complete producer, then cleanup
1486+
execute_finish.set()
1487+
await asyncio.wait_for(produced_task, timeout=0.2)
1488+
await asyncio.wait_for(cleanup_task, timeout=0.2)
1489+
1490+
# Cleanup task should be removed from tracking after completion
1491+
assert cleanup_task not in request_handler._background_tasks
1492+
1493+
13891494
@pytest.mark.asyncio
13901495
async def test_on_message_send_stream_task_id_mismatch():
13911496
"""Test on_message_send_stream raises error if yielded task ID mismatches."""

0 commit comments

Comments
 (0)