Skip to content

Commit bd555ee

Browse files
add tracking for cleanup background tasks
1 parent b426524 commit bd555ee

File tree

2 files changed

+153
-4
lines changed

2 files changed

+153
-4
lines changed

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class DefaultRequestHandler(RequestHandler):
6767
"""
6868

6969
_running_agents: dict[str, asyncio.Task]
70+
_background_tasks: set[asyncio.Task]
7071

7172
def __init__( # noqa: PLR0913
7273
self,
@@ -102,6 +103,9 @@ def __init__( # noqa: PLR0913
102103
# TODO: Likely want an interface for managing this, like AgentExecutionManager.
103104
self._running_agents = {}
104105
self._running_agents_lock = asyncio.Lock()
106+
# Tracks background tasks (e.g., deferred cleanups) to avoid orphaning
107+
# asyncio tasks and to surface unexpected exceptions.
108+
self._background_tasks = set()
105109

106110
async def on_get_task(
107111
self,
@@ -355,10 +359,11 @@ async def push_notification_callback() -> None:
355359
raise
356360
finally:
357361
if interrupted_or_non_blocking:
358-
# TODO: Track this disconnected cleanup task.
359-
asyncio.create_task( # noqa: RUF006
362+
cleanup_task = asyncio.create_task(
360363
self._cleanup_producer(producer_task, task_id)
361364
)
365+
cleanup_task.set_name(f'cleanup_producer:{task_id}')
366+
self._track_background_task(cleanup_task)
362367
else:
363368
await self._cleanup_producer(producer_task, task_id)
364369

@@ -394,10 +399,11 @@ async def on_message_send_stream(
394399
)
395400
yield event
396401
finally:
397-
# TODO: Track this disconnected cleanup task.
398-
asyncio.create_task( # noqa: RUF006
402+
cleanup_task = asyncio.create_task(
399403
self._cleanup_producer(producer_task, task_id)
400404
)
405+
cleanup_task.set_name(f'cleanup_producer:{task_id}')
406+
self._track_background_task(cleanup_task)
401407

402408
async def _register_producer(
403409
self, task_id: str, producer_task: asyncio.Task
@@ -406,6 +412,29 @@ async def _register_producer(
406412
async with self._running_agents_lock:
407413
self._running_agents[task_id] = producer_task
408414

415+
def _track_background_task(self, task: asyncio.Task) -> None:
416+
"""Tracks a background task and logs exceptions on completion.
417+
418+
This avoids unreferenced tasks (and associated lint warnings) while
419+
ensuring any exceptions are surfaced in logs.
420+
"""
421+
self._background_tasks.add(task)
422+
423+
def _on_done(completed: asyncio.Task) -> None:
424+
try:
425+
# Retrieve result to raise exceptions, if any
426+
completed.result()
427+
except asyncio.CancelledError:
428+
name = completed.get_name()
429+
logger.debug('Background task %s cancelled', name)
430+
except Exception:
431+
name = completed.get_name()
432+
logger.exception('Background task %s failed', name)
433+
finally:
434+
self._background_tasks.discard(completed)
435+
436+
task.add_done_callback(_on_done)
437+
409438
async def _cleanup_producer(
410439
self,
411440
producer_task: asyncio.Task,

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,6 +1386,126 @@ def create_task_spy(coro):
13861386
assert task_id not in request_handler._running_agents
13871387

13881388

1389+
async def wait_until(predicate, timeout: float = 0.2, interval: float = 0.0):
1390+
"""Await until predicate() is True or timeout elapses."""
1391+
loop = asyncio.get_running_loop()
1392+
end = loop.time() + timeout
1393+
while True:
1394+
if predicate():
1395+
return
1396+
if loop.time() >= end:
1397+
raise AssertionError('condition not met within timeout')
1398+
await asyncio.sleep(interval)
1399+
1400+
1401+
@pytest.mark.asyncio
1402+
async def test_background_cleanup_task_is_tracked_and_cleared():
1403+
"""Ensure background cleanup task is tracked while pending and removed when done."""
1404+
# Arrange
1405+
mock_task_store = AsyncMock(spec=TaskStore)
1406+
mock_queue_manager = AsyncMock(spec=QueueManager)
1407+
mock_agent_executor = AsyncMock(spec=AgentExecutor)
1408+
mock_request_context_builder = AsyncMock(spec=RequestContextBuilder)
1409+
1410+
task_id = 'track_task_1'
1411+
context_id = 'track_ctx_1'
1412+
1413+
# RequestContext with IDs
1414+
mock_request_context = MagicMock(spec=RequestContext)
1415+
mock_request_context.task_id = task_id
1416+
mock_request_context.context_id = context_id
1417+
mock_request_context_builder.build.return_value = mock_request_context
1418+
1419+
mock_queue = AsyncMock(spec=EventQueue)
1420+
mock_queue_manager.create_or_tap.return_value = mock_queue
1421+
1422+
request_handler = DefaultRequestHandler(
1423+
agent_executor=mock_agent_executor,
1424+
task_store=mock_task_store,
1425+
queue_manager=mock_queue_manager,
1426+
request_context_builder=mock_request_context_builder,
1427+
)
1428+
1429+
params = MessageSendParams(
1430+
message=Message(
1431+
role=Role.user,
1432+
message_id='mid_track',
1433+
parts=[],
1434+
task_id=task_id,
1435+
context_id=context_id,
1436+
)
1437+
)
1438+
1439+
# Agent executor runs in background until we allow it to finish
1440+
execute_started = asyncio.Event()
1441+
execute_finish = asyncio.Event()
1442+
1443+
async def exec_side_effect(*_args, **_kwargs):
1444+
execute_started.set()
1445+
await execute_finish.wait()
1446+
1447+
mock_agent_executor.execute.side_effect = exec_side_effect
1448+
1449+
# ResultAggregator emits one Task event (so the stream yields once)
1450+
first_event = create_sample_task(task_id=task_id, context_id=context_id)
1451+
1452+
async def single_event_stream():
1453+
yield first_event
1454+
1455+
mock_result_aggregator_instance = MagicMock(spec=ResultAggregator)
1456+
mock_result_aggregator_instance.consume_and_emit.return_value = (
1457+
single_event_stream()
1458+
)
1459+
1460+
produced_task: asyncio.Task | None = None
1461+
cleanup_task: asyncio.Task | None = None
1462+
1463+
orig_create_task = asyncio.create_task
1464+
1465+
def create_task_spy(coro):
1466+
nonlocal produced_task, cleanup_task
1467+
task = orig_create_task(coro)
1468+
if coro.__name__ == '_run_event_stream':
1469+
produced_task = task
1470+
elif coro.__name__ == '_cleanup_producer':
1471+
cleanup_task = task
1472+
return task
1473+
1474+
with (
1475+
patch(
1476+
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
1477+
return_value=mock_result_aggregator_instance,
1478+
),
1479+
patch('asyncio.create_task', side_effect=create_task_spy),
1480+
):
1481+
# Act: start stream and consume only the first event, then disconnect
1482+
agen = request_handler.on_message_send_stream(
1483+
params, create_server_call_context()
1484+
)
1485+
first = await agen.__anext__()
1486+
assert first == first_event
1487+
# Simulate client disconnect
1488+
await asyncio.wait_for(agen.aclose(), timeout=0.1)
1489+
1490+
assert produced_task is not None
1491+
assert cleanup_task is not None
1492+
1493+
# Background cleanup task should be tracked while producer is still running
1494+
await asyncio.wait_for(execute_started.wait(), timeout=0.1)
1495+
assert cleanup_task in request_handler._background_tasks
1496+
1497+
# Allow executor to finish; this should complete producer, then cleanup
1498+
execute_finish.set()
1499+
await asyncio.wait_for(produced_task, timeout=0.1)
1500+
await asyncio.wait_for(cleanup_task, timeout=0.1)
1501+
1502+
# Wait for callback to remove task from tracking
1503+
await wait_until(
1504+
lambda: cleanup_task not in request_handler._background_tasks,
1505+
timeout=0.1,
1506+
)
1507+
1508+
13891509
@pytest.mark.asyncio
13901510
async def test_on_message_send_stream_task_id_mismatch():
13911511
"""Test on_message_send_stream raises error if yielded task ID mismatches."""

0 commit comments

Comments
 (0)