Skip to content

Commit a9b3010

Browse files
fix: Task state is not persisted to TaskStore after client disconnect
1 parent b2e3a29 commit a9b3010

File tree

5 files changed

+200
-21
lines changed

5 files changed

+200
-21
lines changed

src/a2a/client/transports/rest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ async def get_task(
206206
context: ClientCallContext | None = None,
207207
) -> Task:
208208
"""Retrieves the current state and history of a specific task."""
209-
payload, modified_kwargs = await self._apply_interceptors(
209+
_payload, modified_kwargs = await self._apply_interceptors(
210210
request.model_dump(mode='json', exclude_none=True),
211211
self._get_http_args(context),
212212
context,

src/a2a/server/events/event_queue.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,18 @@ def tap(self) -> 'EventQueue':
135135
async def close(self, immediate: bool = False) -> None:
136136
"""Closes the queue for future push events and also closes all child queues.
137137
138-
Once closed, no new events can be enqueued. For Python 3.13+, this will trigger
139-
`asyncio.QueueShutDown` when the queue is empty and a consumer tries to dequeue.
140-
For lower versions, the queue will be marked as closed and optionally cleared.
138+
Once closed, no new events can be enqueued. Behavior is consistent across
139+
Python versions:
140+
- Python >= 3.13: Uses `asyncio.Queue.shutdown` to stop the queue. With
141+
`immediate=True` the queue is shut down and pending events are cleared; with
142+
`immediate=False` the queue is shut down and we wait for it to drain via
143+
`queue.join()`.
144+
- Python < 3.13: Emulates the same semantics by clearing on `immediate=True`
145+
or awaiting `queue.join()` on `immediate=False`.
146+
147+
Consumers attempting to dequeue after close on an empty queue will observe
148+
`asyncio.QueueShutDown` on Python >= 3.13 and `asyncio.QueueEmpty` on
149+
Python < 3.13.
141150
142151
Args:
143152
immediate (bool):
@@ -152,11 +161,22 @@ async def close(self, immediate: bool = False) -> None:
152161
return
153162
if not self._is_closed:
154163
self._is_closed = True
155-
# If using python 3.13 or higher, use the shutdown method
164+
# If using python 3.13 or higher, use shutdown but match <3.13 semantics
156165
if sys.version_info >= (3, 13):
157-
self.queue.shutdown(immediate)
158-
for child in self._children:
159-
await child.close(immediate)
166+
if immediate:
167+
# Immediate: stop queue and clear any pending events, then close children
168+
self.queue.shutdown(True)
169+
await self.clear_events(True)
170+
for child in self._children:
171+
await child.close(True)
172+
return
173+
# Graceful: prevent further gets/puts via shutdown, then wait for drain and children
174+
self.queue.shutdown(False)
175+
tasks = [asyncio.create_task(self.queue.join())]
176+
tasks.extend(
177+
asyncio.create_task(child.close()) for child in self._children
178+
)
179+
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
160180
# Otherwise, join the queue
161181
else:
162182
if immediate:

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ async def on_message_send(
314314
result (Task or Message).
315315
"""
316316
(
317-
task_manager,
317+
_task_manager,
318318
task_id,
319319
queue,
320320
result_aggregator,
@@ -379,16 +379,16 @@ async def on_message_send_stream(
379379
by the agent.
380380
"""
381381
(
382-
task_manager,
382+
_task_manager,
383383
task_id,
384384
queue,
385385
result_aggregator,
386386
producer_task,
387387
) = await self._setup_message_execution(params, context)
388+
consumer = EventConsumer(queue)
389+
producer_task.add_done_callback(consumer.agent_task_callback)
388390

389391
try:
390-
consumer = EventConsumer(queue)
391-
producer_task.add_done_callback(consumer.agent_task_callback)
392392
async for event in result_aggregator.consume_and_emit(consumer):
393393
if isinstance(event, Task):
394394
self._validate_task_id_match(task_id, event.id)
@@ -397,6 +397,14 @@ async def on_message_send_stream(
397397
task_id, result_aggregator
398398
)
399399
yield event
400+
except (asyncio.CancelledError, GeneratorExit):
401+
# Client disconnected: continue consuming and persisting events in the background
402+
bg_task = asyncio.create_task(
403+
result_aggregator.consume_all(consumer)
404+
)
405+
bg_task.set_name(f'background_consume:{task_id}')
406+
self._track_background_task(bg_task)
407+
raise
400408
finally:
401409
cleanup_task = asyncio.create_task(
402410
self._cleanup_producer(producer_task, task_id)

tests/server/events/test_event_queue.py

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -300,14 +300,59 @@ async def test_close_sets_flag_and_handles_internal_queue_new_python(
300300
event_queue: EventQueue,
301301
) -> None:
302302
"""Test close behavior on Python >= 3.13 (using queue.shutdown)."""
303-
with patch('sys.version_info', (3, 13, 0)): # Simulate Python 3.13+
304-
# Mock queue.shutdown as it's called in newer versions
305-
event_queue.queue.shutdown = MagicMock() # shutdown is not async
303+
with patch('sys.version_info', (3, 13, 0)):
304+
# Inject a dummy shutdown method for non-3.13 runtimes
305+
from typing import cast
306306

307+
q_any = cast('Any', event_queue.queue)
308+
q_any.shutdown = MagicMock() # type: ignore[attr-defined]
307309
await event_queue.close()
308-
309310
assert event_queue.is_closed() is True
310-
event_queue.queue.shutdown.assert_called_once() # specific to >=3.13
311+
312+
313+
@pytest.mark.asyncio
314+
@patch('asyncio.wait')
315+
@patch('asyncio.create_task')
316+
async def test_close_graceful_py313_waits_for_join_and_children(
317+
mock_create_task: AsyncMock,
318+
mock_asyncio_wait: AsyncMock,
319+
event_queue: EventQueue,
320+
) -> None:
321+
"""For Python >=3.13 and immediate=False, close should shutdown(False), then wait for join and children."""
322+
with patch('sys.version_info', (3, 13, 0)):
323+
# Arrange
324+
from typing import cast
325+
326+
q_any = cast('Any', event_queue.queue)
327+
q_any.shutdown = MagicMock() # type: ignore[attr-defined]
328+
event_queue.queue.join = AsyncMock()
329+
330+
child = event_queue.tap()
331+
child.close = AsyncMock()
332+
333+
# Ensure created tasks actually run their coroutines
334+
async def _runner(coro):
335+
await coro
336+
337+
def _create_task_side_effect(coro):
338+
loop = asyncio.get_running_loop()
339+
return loop.create_task(_runner(coro))
340+
341+
mock_create_task.side_effect = _create_task_side_effect
342+
343+
async def _wait_side_effect(tasks, return_when=None):
344+
await asyncio.gather(*tasks, return_exceptions=True)
345+
return (set(tasks), set())
346+
347+
mock_asyncio_wait.side_effect = _wait_side_effect
348+
349+
# Act
350+
await event_queue.close(immediate=False)
351+
352+
# Assert
353+
event_queue.queue.join.assert_awaited_once()
354+
child.close.assert_awaited_once()
355+
mock_asyncio_wait.assert_called()
311356

312357

313358
@pytest.mark.asyncio
@@ -345,15 +390,16 @@ async def test_close_idempotent(event_queue: EventQueue) -> None:
345390

346391
# Reset for new Python version test
347392
event_queue_new = EventQueue() # New queue for fresh state
348-
with patch('sys.version_info', (3, 13, 0)): # Test with newer version logic
349-
event_queue_new.queue.shutdown = MagicMock()
393+
with patch('sys.version_info', (3, 13, 0)):
394+
from typing import cast
395+
396+
q_any2 = cast('Any', event_queue_new.queue)
397+
q_any2.shutdown = MagicMock() # type: ignore[attr-defined]
350398
await event_queue_new.close()
351399
assert event_queue_new.is_closed() is True
352-
event_queue_new.queue.shutdown.assert_called_once()
353400

354401
await event_queue_new.close()
355402
assert event_queue_new.is_closed() is True
356-
event_queue_new.queue.shutdown.assert_called_once() # Still only called once
357403

358404

359405
@pytest.mark.asyncio

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import contextlib
23
import logging
34
import time
45

@@ -48,6 +49,7 @@
4849
TaskQueryParams,
4950
TaskState,
5051
TaskStatus,
52+
TaskStatusUpdateEvent,
5153
TextPart,
5254
UnsupportedOperationError,
5355
)
@@ -1331,6 +1333,15 @@ async def single_event_stream():
13311333
mock_result_aggregator_instance.consume_and_emit.return_value = (
13321334
single_event_stream()
13331335
)
1336+
# Signal when background consume_all is started
1337+
bg_started = asyncio.Event()
1338+
1339+
async def mock_consume_all(_consumer):
1340+
bg_started.set()
1341+
# emulate short-running background work
1342+
await asyncio.sleep(0)
1343+
1344+
mock_result_aggregator_instance.consume_all = mock_consume_all
13341345

13351346
produced_task: asyncio.Task | None = None
13361347
cleanup_task: asyncio.Task | None = None
@@ -1367,6 +1378,9 @@ def create_task_spy(coro):
13671378
assert produced_task is not None
13681379
assert cleanup_task is not None
13691380

1381+
# Assert background consume_all started
1382+
await asyncio.wait_for(bg_started.wait(), timeout=0.2)
1383+
13701384
# execute should have started
13711385
await asyncio.wait_for(execute_started.wait(), timeout=0.1)
13721386

@@ -1385,6 +1399,91 @@ def create_task_spy(coro):
13851399
# Running agents is cleared
13861400
assert task_id not in request_handler._running_agents
13871401

1402+
# Cleanup any lingering background tasks started by on_message_send_stream
1403+
# (e.g., background_consume)
1404+
for t in list(request_handler._background_tasks):
1405+
t.cancel()
1406+
with contextlib.suppress(asyncio.CancelledError):
1407+
await t
1408+
1409+
1410+
@pytest.mark.asyncio
1411+
async def test_disconnect_persists_final_task_to_store():
1412+
"""After client disconnect, ensure background consumer persists final Task to store."""
1413+
task_store = InMemoryTaskStore()
1414+
queue_manager = InMemoryQueueManager()
1415+
1416+
# Custom agent that emits a working update then a completed final update
1417+
class FinishingAgent(AgentExecutor):
1418+
def __init__(self):
1419+
self.allow_finish = asyncio.Event()
1420+
1421+
async def execute(
1422+
self, context: RequestContext, event_queue: EventQueue
1423+
):
1424+
from typing import cast
1425+
1426+
updater = TaskUpdater(
1427+
event_queue,
1428+
cast('str', context.task_id),
1429+
cast('str', context.context_id),
1430+
)
1431+
await updater.update_status(TaskState.working)
1432+
await self.allow_finish.wait()
1433+
await updater.update_status(TaskState.completed)
1434+
1435+
async def cancel(
1436+
self, context: RequestContext, event_queue: EventQueue
1437+
):
1438+
return None
1439+
1440+
agent = FinishingAgent()
1441+
1442+
handler = DefaultRequestHandler(
1443+
agent_executor=agent, task_store=task_store, queue_manager=queue_manager
1444+
)
1445+
1446+
params = MessageSendParams(
1447+
message=Message(
1448+
role=Role.user,
1449+
message_id='msg_persist',
1450+
parts=[],
1451+
)
1452+
)
1453+
1454+
# Start streaming and consume the first event (working)
1455+
agen = handler.on_message_send_stream(params, create_server_call_context())
1456+
first = await agen.__anext__()
1457+
if isinstance(first, TaskStatusUpdateEvent):
1458+
assert first.status.state == TaskState.working
1459+
task_id = first.task_id
1460+
else:
1461+
assert (
1462+
isinstance(first, Task) and first.status.state == TaskState.working
1463+
)
1464+
task_id = first.id
1465+
1466+
# Disconnect client
1467+
await asyncio.wait_for(agen.aclose(), timeout=0.1)
1468+
1469+
# Finish agent and allow background consumer to persist final state
1470+
agent.allow_finish.set()
1471+
1472+
# Wait until background_consume task for this task_id is gone
1473+
await wait_until(
1474+
lambda: all(
1475+
not t.get_name().startswith(f'background_consume:{task_id}')
1476+
for t in handler._background_tasks
1477+
),
1478+
timeout=1.0,
1479+
interval=0.01,
1480+
)
1481+
1482+
# Verify task is persisted as completed
1483+
persisted = await task_store.get(task_id, create_server_call_context())
1484+
assert persisted is not None
1485+
assert persisted.status.state == TaskState.completed
1486+
13881487

13891488
async def wait_until(predicate, timeout: float = 0.2, interval: float = 0.0):
13901489
"""Await until predicate() is True or timeout elapses."""
@@ -1505,6 +1604,12 @@ def create_task_spy(coro):
15051604
timeout=0.1,
15061605
)
15071606

1607+
# Cleanup any lingering background tasks
1608+
for t in list(request_handler._background_tasks):
1609+
t.cancel()
1610+
with contextlib.suppress(asyncio.CancelledError):
1611+
await t
1612+
15081613

15091614
@pytest.mark.asyncio
15101615
async def test_on_message_send_stream_task_id_mismatch():

0 commit comments

Comments
 (0)