diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 3a72a5b1..082c21cc 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -206,7 +206,7 @@ async def get_task( context: ClientCallContext | None = None, ) -> Task: """Retrieves the current state and history of a specific task.""" - payload, modified_kwargs = await self._apply_interceptors( + _payload, modified_kwargs = await self._apply_interceptors( request.model_dump(mode='json', exclude_none=True), self._get_http_args(context), context, diff --git a/src/a2a/server/events/event_queue.py b/src/a2a/server/events/event_queue.py index 814bc879..f6599cca 100644 --- a/src/a2a/server/events/event_queue.py +++ b/src/a2a/server/events/event_queue.py @@ -135,9 +135,18 @@ def tap(self) -> 'EventQueue': async def close(self, immediate: bool = False) -> None: """Closes the queue for future push events and also closes all child queues. - Once closed, no new events can be enqueued. For Python 3.13+, this will trigger - `asyncio.QueueShutDown` when the queue is empty and a consumer tries to dequeue. - For lower versions, the queue will be marked as closed and optionally cleared. + Once closed, no new events can be enqueued. Behavior is consistent across + Python versions: + - Python >= 3.13: Uses `asyncio.Queue.shutdown` to stop the queue. With + `immediate=True` the queue is shut down and pending events are cleared; with + `immediate=False` the queue is shut down and we wait for it to drain via + `queue.join()`. + - Python < 3.13: Emulates the same semantics by clearing on `immediate=True` + or awaiting `queue.join()` on `immediate=False`. + + Consumers attempting to dequeue after close on an empty queue will observe + `asyncio.QueueShutDown` on Python >= 3.13 and `asyncio.QueueEmpty` on + Python < 3.13. Args: immediate (bool): @@ -152,11 +161,20 @@ async def close(self, immediate: bool = False) -> None: return if not self._is_closed: self._is_closed = True - # If using python 3.13 or higher, use the shutdown method + # If using python 3.13 or higher, use shutdown but match <3.13 semantics if sys.version_info >= (3, 13): - self.queue.shutdown(immediate) - for child in self._children: - await child.close(immediate) + if immediate: + # Immediate: stop queue and clear any pending events, then close children + self.queue.shutdown(True) + await self.clear_events(True) + for child in self._children: + await child.close(True) + return + # Graceful: prevent further gets/puts via shutdown, then wait for drain and children + self.queue.shutdown(False) + await asyncio.gather( + self.queue.join(), *(child.close() for child in self._children) + ) # Otherwise, join the queue else: if immediate: @@ -164,11 +182,9 @@ async def close(self, immediate: bool = False) -> None: for child in self._children: await child.close(immediate) return - tasks = [asyncio.create_task(self.queue.join())] - tasks.extend( - asyncio.create_task(child.close()) for child in self._children + await asyncio.gather( + self.queue.join(), *(child.close() for child in self._children) ) - await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED) def is_closed(self) -> bool: """Checks if the queue is closed.""" diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index ee406d6b..5e21fe8b 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -314,7 +314,7 @@ async def on_message_send( result (Task or Message). """ ( - task_manager, + _task_manager, task_id, queue, result_aggregator, @@ -379,16 +379,16 @@ async def on_message_send_stream( by the agent. """ ( - task_manager, + _task_manager, task_id, queue, result_aggregator, producer_task, ) = await self._setup_message_execution(params, context) + consumer = EventConsumer(queue) + producer_task.add_done_callback(consumer.agent_task_callback) try: - consumer = EventConsumer(queue) - producer_task.add_done_callback(consumer.agent_task_callback) async for event in result_aggregator.consume_and_emit(consumer): if isinstance(event, Task): self._validate_task_id_match(task_id, event.id) @@ -397,6 +397,14 @@ async def on_message_send_stream( task_id, result_aggregator ) yield event + except (asyncio.CancelledError, GeneratorExit): + # Client disconnected: continue consuming and persisting events in the background + bg_task = asyncio.create_task( + result_aggregator.consume_all(consumer) + ) + bg_task.set_name(f'background_consume:{task_id}') + self._track_background_task(bg_task) + raise finally: cleanup_task = asyncio.create_task( self._cleanup_producer(producer_task, task_id) diff --git a/tests/server/events/test_event_queue.py b/tests/server/events/test_event_queue.py index fc139ecc..18ebf72b 100644 --- a/tests/server/events/test_event_queue.py +++ b/tests/server/events/test_event_queue.py @@ -271,15 +271,7 @@ async def test_tap_creates_child_queue(event_queue: EventQueue) -> None: @pytest.mark.asyncio -@patch( - 'asyncio.wait' -) # To monitor calls to asyncio.wait for older Python versions -@patch( - 'asyncio.create_task' -) # To monitor calls to asyncio.create_task for older Python versions async def test_close_sets_flag_and_handles_internal_queue_old_python( - mock_create_task: MagicMock, - mock_asyncio_wait: AsyncMock, event_queue: EventQueue, ) -> None: """Test close behavior on Python < 3.13 (using queue.join).""" @@ -290,9 +282,7 @@ async def test_close_sets_flag_and_handles_internal_queue_old_python( await event_queue.close() assert event_queue.is_closed() is True - event_queue.queue.join.assert_called_once() # specific to <3.13 - mock_create_task.assert_called_once() # create_task for join - mock_asyncio_wait.assert_called_once() # wait for join + event_queue.queue.join.assert_awaited_once() # waited for drain @pytest.mark.asyncio @@ -300,14 +290,39 @@ async def test_close_sets_flag_and_handles_internal_queue_new_python( event_queue: EventQueue, ) -> None: """Test close behavior on Python >= 3.13 (using queue.shutdown).""" - with patch('sys.version_info', (3, 13, 0)): # Simulate Python 3.13+ - # Mock queue.shutdown as it's called in newer versions - event_queue.queue.shutdown = MagicMock() # shutdown is not async + with patch('sys.version_info', (3, 13, 0)): + # Inject a dummy shutdown method for non-3.13 runtimes + from typing import cast + queue = cast('Any', event_queue.queue) + queue.shutdown = MagicMock() # type: ignore[attr-defined] await event_queue.close() - assert event_queue.is_closed() is True - event_queue.queue.shutdown.assert_called_once() # specific to >=3.13 + queue.shutdown.assert_called_once_with(False) + + +@pytest.mark.asyncio +async def test_close_graceful_py313_waits_for_join_and_children( + event_queue: EventQueue, +) -> None: + """For Python >=3.13 and immediate=False, close should shutdown(False), then wait for join and children.""" + with patch('sys.version_info', (3, 13, 0)): + # Arrange + from typing import cast + + q_any = cast('Any', event_queue.queue) + q_any.shutdown = MagicMock() # type: ignore[attr-defined] + event_queue.queue.join = AsyncMock() + + child = event_queue.tap() + child.close = AsyncMock() + + # Act + await event_queue.close(immediate=False) + + # Assert + event_queue.queue.join.assert_awaited_once() + child.close.assert_awaited_once() @pytest.mark.asyncio @@ -345,15 +360,18 @@ async def test_close_idempotent(event_queue: EventQueue) -> None: # Reset for new Python version test event_queue_new = EventQueue() # New queue for fresh state - with patch('sys.version_info', (3, 13, 0)): # Test with newer version logic - event_queue_new.queue.shutdown = MagicMock() + with patch('sys.version_info', (3, 13, 0)): + from typing import cast + + queue = cast('Any', event_queue_new.queue) + queue.shutdown = MagicMock() # type: ignore[attr-defined] await event_queue_new.close() assert event_queue_new.is_closed() is True - event_queue_new.queue.shutdown.assert_called_once() + queue.shutdown.assert_called_once() await event_queue_new.close() assert event_queue_new.is_closed() is True - event_queue_new.queue.shutdown.assert_called_once() # Still only called once + queue.shutdown.assert_called_once() # Still only called once @pytest.mark.asyncio diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index f96ce5e6..6765000c 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import logging import time @@ -48,6 +49,7 @@ TaskQueryParams, TaskState, TaskStatus, + TaskStatusUpdateEvent, TextPart, UnsupportedOperationError, ) @@ -1331,6 +1333,15 @@ async def single_event_stream(): mock_result_aggregator_instance.consume_and_emit.return_value = ( single_event_stream() ) + # Signal when background consume_all is started + bg_started = asyncio.Event() + + async def mock_consume_all(_consumer): + bg_started.set() + # emulate short-running background work + await asyncio.sleep(0) + + mock_result_aggregator_instance.consume_all = mock_consume_all produced_task: asyncio.Task | None = None cleanup_task: asyncio.Task | None = None @@ -1367,6 +1378,9 @@ def create_task_spy(coro): assert produced_task is not None assert cleanup_task is not None + # Assert background consume_all started + await asyncio.wait_for(bg_started.wait(), timeout=0.2) + # execute should have started await asyncio.wait_for(execute_started.wait(), timeout=0.1) @@ -1385,6 +1399,91 @@ def create_task_spy(coro): # Running agents is cleared assert task_id not in request_handler._running_agents + # Cleanup any lingering background tasks started by on_message_send_stream + # (e.g., background_consume) + for t in list(request_handler._background_tasks): + t.cancel() + with contextlib.suppress(asyncio.CancelledError): + await t + + +@pytest.mark.asyncio +async def test_disconnect_persists_final_task_to_store(): + """After client disconnect, ensure background consumer persists final Task to store.""" + task_store = InMemoryTaskStore() + queue_manager = InMemoryQueueManager() + + # Custom agent that emits a working update then a completed final update + class FinishingAgent(AgentExecutor): + def __init__(self): + self.allow_finish = asyncio.Event() + + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + from typing import cast + + updater = TaskUpdater( + event_queue, + cast('str', context.task_id), + cast('str', context.context_id), + ) + await updater.update_status(TaskState.working) + await self.allow_finish.wait() + await updater.update_status(TaskState.completed) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + return None + + agent = FinishingAgent() + + handler = DefaultRequestHandler( + agent_executor=agent, task_store=task_store, queue_manager=queue_manager + ) + + params = MessageSendParams( + message=Message( + role=Role.user, + message_id='msg_persist', + parts=[], + ) + ) + + # Start streaming and consume the first event (working) + agen = handler.on_message_send_stream(params, create_server_call_context()) + first = await agen.__anext__() + if isinstance(first, TaskStatusUpdateEvent): + assert first.status.state == TaskState.working + task_id = first.task_id + else: + assert ( + isinstance(first, Task) and first.status.state == TaskState.working + ) + task_id = first.id + + # Disconnect client + await asyncio.wait_for(agen.aclose(), timeout=0.1) + + # Finish agent and allow background consumer to persist final state + agent.allow_finish.set() + + # Wait until background_consume task for this task_id is gone + await wait_until( + lambda: all( + not t.get_name().startswith(f'background_consume:{task_id}') + for t in handler._background_tasks + ), + timeout=1.0, + interval=0.01, + ) + + # Verify task is persisted as completed + persisted = await task_store.get(task_id, create_server_call_context()) + assert persisted is not None + assert persisted.status.state == TaskState.completed + async def wait_until(predicate, timeout: float = 0.2, interval: float = 0.0): """Await until predicate() is True or timeout elapses.""" @@ -1505,6 +1604,12 @@ def create_task_spy(coro): timeout=0.1, ) + # Cleanup any lingering background tasks + for t in list(request_handler._background_tasks): + t.cancel() + with contextlib.suppress(asyncio.CancelledError): + await t + @pytest.mark.asyncio async def test_on_message_send_stream_task_id_mismatch():