Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/a2a/client/transports/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ async def get_task(
context: ClientCallContext | None = None,
) -> Task:
"""Retrieves the current state and history of a specific task."""
payload, modified_kwargs = await self._apply_interceptors(
_payload, modified_kwargs = await self._apply_interceptors(
request.model_dump(mode='json', exclude_none=True),
self._get_http_args(context),
context,
Expand Down
38 changes: 27 additions & 11 deletions src/a2a/server/events/event_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,18 @@ def tap(self) -> 'EventQueue':
async def close(self, immediate: bool = False) -> None:
"""Closes the queue for future push events and also closes all child queues.

Once closed, no new events can be enqueued. For Python 3.13+, this will trigger
`asyncio.QueueShutDown` when the queue is empty and a consumer tries to dequeue.
For lower versions, the queue will be marked as closed and optionally cleared.
Once closed, no new events can be enqueued. Behavior is consistent across
Python versions:
- Python >= 3.13: Uses `asyncio.Queue.shutdown` to stop the queue. With
`immediate=True` the queue is shut down and pending events are cleared; with
`immediate=False` the queue is shut down and we wait for it to drain via
`queue.join()`.
- Python < 3.13: Emulates the same semantics by clearing on `immediate=True`
or awaiting `queue.join()` on `immediate=False`.

Consumers attempting to dequeue after close on an empty queue will observe
`asyncio.QueueShutDown` on Python >= 3.13 and `asyncio.QueueEmpty` on
Python < 3.13.

Args:
immediate (bool):
Expand All @@ -152,23 +161,30 @@ async def close(self, immediate: bool = False) -> None:
return
if not self._is_closed:
self._is_closed = True
# If using python 3.13 or higher, use the shutdown method
# If using python 3.13 or higher, use shutdown but match <3.13 semantics
if sys.version_info >= (3, 13):
self.queue.shutdown(immediate)
for child in self._children:
await child.close(immediate)
if immediate:
# Immediate: stop queue and clear any pending events, then close children
self.queue.shutdown(True)
await self.clear_events(True)
for child in self._children:
await child.close(True)
return
# Graceful: prevent further gets/puts via shutdown, then wait for drain and children
self.queue.shutdown(False)
await asyncio.gather(
self.queue.join(), *(child.close() for child in self._children)
)
# Otherwise, join the queue
else:
if immediate:
await self.clear_events(True)
for child in self._children:
await child.close(immediate)
return
tasks = [asyncio.create_task(self.queue.join())]
tasks.extend(
asyncio.create_task(child.close()) for child in self._children
await asyncio.gather(
self.queue.join(), *(child.close() for child in self._children)
)
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)

def is_closed(self) -> bool:
"""Checks if the queue is closed."""
Expand Down
16 changes: 12 additions & 4 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ async def on_message_send(
result (Task or Message).
"""
(
task_manager,
_task_manager,
task_id,
queue,
result_aggregator,
Expand Down Expand Up @@ -379,16 +379,16 @@ async def on_message_send_stream(
by the agent.
"""
(
task_manager,
_task_manager,
task_id,
queue,
result_aggregator,
producer_task,
) = await self._setup_message_execution(params, context)
consumer = EventConsumer(queue)
producer_task.add_done_callback(consumer.agent_task_callback)

try:
consumer = EventConsumer(queue)
producer_task.add_done_callback(consumer.agent_task_callback)
async for event in result_aggregator.consume_and_emit(consumer):
if isinstance(event, Task):
self._validate_task_id_match(task_id, event.id)
Expand All @@ -397,6 +397,14 @@ async def on_message_send_stream(
task_id, result_aggregator
)
yield event
except (asyncio.CancelledError, GeneratorExit):
# Client disconnected: continue consuming and persisting events in the background
bg_task = asyncio.create_task(
result_aggregator.consume_all(consumer)
)
bg_task.set_name(f'background_consume:{task_id}')
self._track_background_task(bg_task)
raise
finally:
cleanup_task = asyncio.create_task(
self._cleanup_producer(producer_task, task_id)
Expand Down
55 changes: 35 additions & 20 deletions tests/server/events/test_event_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,15 +271,7 @@ async def test_tap_creates_child_queue(event_queue: EventQueue) -> None:


@pytest.mark.asyncio
@patch(
'asyncio.wait'
) # To monitor calls to asyncio.wait for older Python versions
@patch(
'asyncio.create_task'
) # To monitor calls to asyncio.create_task for older Python versions
async def test_close_sets_flag_and_handles_internal_queue_old_python(
mock_create_task: MagicMock,
mock_asyncio_wait: AsyncMock,
event_queue: EventQueue,
) -> None:
"""Test close behavior on Python < 3.13 (using queue.join)."""
Expand All @@ -290,24 +282,46 @@ async def test_close_sets_flag_and_handles_internal_queue_old_python(
await event_queue.close()

assert event_queue.is_closed() is True
event_queue.queue.join.assert_called_once() # specific to <3.13
mock_create_task.assert_called_once() # create_task for join
mock_asyncio_wait.assert_called_once() # wait for join
event_queue.queue.join.assert_awaited_once() # waited for drain


@pytest.mark.asyncio
async def test_close_sets_flag_and_handles_internal_queue_new_python(
event_queue: EventQueue,
) -> None:
"""Test close behavior on Python >= 3.13 (using queue.shutdown)."""
with patch('sys.version_info', (3, 13, 0)): # Simulate Python 3.13+
# Mock queue.shutdown as it's called in newer versions
event_queue.queue.shutdown = MagicMock() # shutdown is not async
with patch('sys.version_info', (3, 13, 0)):
# Inject a dummy shutdown method for non-3.13 runtimes
from typing import cast

queue = cast('Any', event_queue.queue)
queue.shutdown = MagicMock() # type: ignore[attr-defined]
await event_queue.close()

assert event_queue.is_closed() is True
event_queue.queue.shutdown.assert_called_once() # specific to >=3.13


@pytest.mark.asyncio
async def test_close_graceful_py313_waits_for_join_and_children(
event_queue: EventQueue,
) -> None:
"""For Python >=3.13 and immediate=False, close should shutdown(False), then wait for join and children."""
with patch('sys.version_info', (3, 13, 0)):
# Arrange
from typing import cast

q_any = cast('Any', event_queue.queue)
q_any.shutdown = MagicMock() # type: ignore[attr-defined]
event_queue.queue.join = AsyncMock()

child = event_queue.tap()
child.close = AsyncMock()

# Act
await event_queue.close(immediate=False)

# Assert
event_queue.queue.join.assert_awaited_once()
child.close.assert_awaited_once()


@pytest.mark.asyncio
Expand Down Expand Up @@ -345,15 +359,16 @@ async def test_close_idempotent(event_queue: EventQueue) -> None:

# Reset for new Python version test
event_queue_new = EventQueue() # New queue for fresh state
with patch('sys.version_info', (3, 13, 0)): # Test with newer version logic
event_queue_new.queue.shutdown = MagicMock()
with patch('sys.version_info', (3, 13, 0)):
from typing import cast

queue = cast('Any', event_queue_new.queue)
queue.shutdown = MagicMock() # type: ignore[attr-defined]
await event_queue_new.close()
assert event_queue_new.is_closed() is True
event_queue_new.queue.shutdown.assert_called_once()

await event_queue_new.close()
assert event_queue_new.is_closed() is True
event_queue_new.queue.shutdown.assert_called_once() # Still only called once


@pytest.mark.asyncio
Expand Down
105 changes: 105 additions & 0 deletions tests/server/request_handlers/test_default_request_handler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import contextlib
import logging
import time

Expand Down Expand Up @@ -48,6 +49,7 @@
TaskQueryParams,
TaskState,
TaskStatus,
TaskStatusUpdateEvent,
TextPart,
UnsupportedOperationError,
)
Expand Down Expand Up @@ -1331,6 +1333,15 @@ async def single_event_stream():
mock_result_aggregator_instance.consume_and_emit.return_value = (
single_event_stream()
)
# Signal when background consume_all is started
bg_started = asyncio.Event()

async def mock_consume_all(_consumer):
bg_started.set()
# emulate short-running background work
await asyncio.sleep(0)

mock_result_aggregator_instance.consume_all = mock_consume_all

produced_task: asyncio.Task | None = None
cleanup_task: asyncio.Task | None = None
Expand Down Expand Up @@ -1367,6 +1378,9 @@ def create_task_spy(coro):
assert produced_task is not None
assert cleanup_task is not None

# Assert background consume_all started
await asyncio.wait_for(bg_started.wait(), timeout=0.2)

# execute should have started
await asyncio.wait_for(execute_started.wait(), timeout=0.1)

Expand All @@ -1385,6 +1399,91 @@ def create_task_spy(coro):
# Running agents is cleared
assert task_id not in request_handler._running_agents

# Cleanup any lingering background tasks started by on_message_send_stream
# (e.g., background_consume)
for t in list(request_handler._background_tasks):
t.cancel()
with contextlib.suppress(asyncio.CancelledError):
await t


@pytest.mark.asyncio
async def test_disconnect_persists_final_task_to_store():
"""After client disconnect, ensure background consumer persists final Task to store."""
task_store = InMemoryTaskStore()
queue_manager = InMemoryQueueManager()

# Custom agent that emits a working update then a completed final update
class FinishingAgent(AgentExecutor):
def __init__(self):
self.allow_finish = asyncio.Event()

async def execute(
self, context: RequestContext, event_queue: EventQueue
):
from typing import cast

updater = TaskUpdater(
event_queue,
cast('str', context.task_id),
cast('str', context.context_id),
)
await updater.update_status(TaskState.working)
await self.allow_finish.wait()
await updater.update_status(TaskState.completed)

async def cancel(
self, context: RequestContext, event_queue: EventQueue
):
return None

agent = FinishingAgent()

handler = DefaultRequestHandler(
agent_executor=agent, task_store=task_store, queue_manager=queue_manager
)

params = MessageSendParams(
message=Message(
role=Role.user,
message_id='msg_persist',
parts=[],
)
)

# Start streaming and consume the first event (working)
agen = handler.on_message_send_stream(params, create_server_call_context())
first = await agen.__anext__()
if isinstance(first, TaskStatusUpdateEvent):
assert first.status.state == TaskState.working
task_id = first.task_id
else:
assert (
isinstance(first, Task) and first.status.state == TaskState.working
)
task_id = first.id

# Disconnect client
await asyncio.wait_for(agen.aclose(), timeout=0.1)

# Finish agent and allow background consumer to persist final state
agent.allow_finish.set()

# Wait until background_consume task for this task_id is gone
await wait_until(
lambda: all(
not t.get_name().startswith(f'background_consume:{task_id}')
for t in handler._background_tasks
),
timeout=1.0,
interval=0.01,
)

# Verify task is persisted as completed
persisted = await task_store.get(task_id, create_server_call_context())
assert persisted is not None
assert persisted.status.state == TaskState.completed


async def wait_until(predicate, timeout: float = 0.2, interval: float = 0.0):
"""Await until predicate() is True or timeout elapses."""
Expand Down Expand Up @@ -1505,6 +1604,12 @@ def create_task_spy(coro):
timeout=0.1,
)

# Cleanup any lingering background tasks
for t in list(request_handler._background_tasks):
t.cancel()
with contextlib.suppress(asyncio.CancelledError):
await t


@pytest.mark.asyncio
async def test_on_message_send_stream_task_id_mismatch():
Expand Down
Loading