Skip to content

Commit 36321d9

Browse files
authored
Merge branch 'main' into test/refactor-tests
2 parents 76ea260 + 5342ca4 commit 36321d9

File tree

5 files changed

+183
-36
lines changed

5 files changed

+183
-36
lines changed

src/a2a/client/transports/rest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ async def get_task(
206206
context: ClientCallContext | None = None,
207207
) -> Task:
208208
"""Retrieves the current state and history of a specific task."""
209-
payload, modified_kwargs = await self._apply_interceptors(
209+
_payload, modified_kwargs = await self._apply_interceptors(
210210
request.model_dump(mode='json', exclude_none=True),
211211
self._get_http_args(context),
212212
context,

src/a2a/server/events/event_queue.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,18 @@ def tap(self) -> 'EventQueue':
135135
async def close(self, immediate: bool = False) -> None:
136136
"""Closes the queue for future push events and also closes all child queues.
137137
138-
Once closed, no new events can be enqueued. For Python 3.13+, this will trigger
139-
`asyncio.QueueShutDown` when the queue is empty and a consumer tries to dequeue.
140-
For lower versions, the queue will be marked as closed and optionally cleared.
138+
Once closed, no new events can be enqueued. Behavior is consistent across
139+
Python versions:
140+
- Python >= 3.13: Uses `asyncio.Queue.shutdown` to stop the queue. With
141+
`immediate=True` the queue is shut down and pending events are cleared; with
142+
`immediate=False` the queue is shut down and we wait for it to drain via
143+
`queue.join()`.
144+
- Python < 3.13: Emulates the same semantics by clearing on `immediate=True`
145+
or awaiting `queue.join()` on `immediate=False`.
146+
147+
Consumers attempting to dequeue after close on an empty queue will observe
148+
`asyncio.QueueShutDown` on Python >= 3.13 and `asyncio.QueueEmpty` on
149+
Python < 3.13.
141150
142151
Args:
143152
immediate (bool):
@@ -152,23 +161,30 @@ async def close(self, immediate: bool = False) -> None:
152161
return
153162
if not self._is_closed:
154163
self._is_closed = True
155-
# If using python 3.13 or higher, use the shutdown method
164+
# If using python 3.13 or higher, use shutdown but match <3.13 semantics
156165
if sys.version_info >= (3, 13):
157-
self.queue.shutdown(immediate)
158-
for child in self._children:
159-
await child.close(immediate)
166+
if immediate:
167+
# Immediate: stop queue and clear any pending events, then close children
168+
self.queue.shutdown(True)
169+
await self.clear_events(True)
170+
for child in self._children:
171+
await child.close(True)
172+
return
173+
# Graceful: prevent further gets/puts via shutdown, then wait for drain and children
174+
self.queue.shutdown(False)
175+
await asyncio.gather(
176+
self.queue.join(), *(child.close() for child in self._children)
177+
)
160178
# Otherwise, join the queue
161179
else:
162180
if immediate:
163181
await self.clear_events(True)
164182
for child in self._children:
165183
await child.close(immediate)
166184
return
167-
tasks = [asyncio.create_task(self.queue.join())]
168-
tasks.extend(
169-
asyncio.create_task(child.close()) for child in self._children
185+
await asyncio.gather(
186+
self.queue.join(), *(child.close() for child in self._children)
170187
)
171-
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
172188

173189
def is_closed(self) -> bool:
174190
"""Checks if the queue is closed."""

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ async def on_message_send(
314314
result (Task or Message).
315315
"""
316316
(
317-
task_manager,
317+
_task_manager,
318318
task_id,
319319
queue,
320320
result_aggregator,
@@ -379,16 +379,16 @@ async def on_message_send_stream(
379379
by the agent.
380380
"""
381381
(
382-
task_manager,
382+
_task_manager,
383383
task_id,
384384
queue,
385385
result_aggregator,
386386
producer_task,
387387
) = await self._setup_message_execution(params, context)
388+
consumer = EventConsumer(queue)
389+
producer_task.add_done_callback(consumer.agent_task_callback)
388390

389391
try:
390-
consumer = EventConsumer(queue)
391-
producer_task.add_done_callback(consumer.agent_task_callback)
392392
async for event in result_aggregator.consume_and_emit(consumer):
393393
if isinstance(event, Task):
394394
self._validate_task_id_match(task_id, event.id)
@@ -397,6 +397,14 @@ async def on_message_send_stream(
397397
task_id, result_aggregator
398398
)
399399
yield event
400+
except (asyncio.CancelledError, GeneratorExit):
401+
# Client disconnected: continue consuming and persisting events in the background
402+
bg_task = asyncio.create_task(
403+
result_aggregator.consume_all(consumer)
404+
)
405+
bg_task.set_name(f'background_consume:{task_id}')
406+
self._track_background_task(bg_task)
407+
raise
400408
finally:
401409
cleanup_task = asyncio.create_task(
402410
self._cleanup_producer(producer_task, task_id)

tests/server/events/test_event_queue.py

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -271,15 +271,7 @@ async def test_tap_creates_child_queue(event_queue: EventQueue) -> None:
271271

272272

273273
@pytest.mark.asyncio
274-
@patch(
275-
'asyncio.wait'
276-
) # To monitor calls to asyncio.wait for older Python versions
277-
@patch(
278-
'asyncio.create_task'
279-
) # To monitor calls to asyncio.create_task for older Python versions
280274
async def test_close_sets_flag_and_handles_internal_queue_old_python(
281-
mock_create_task: MagicMock,
282-
mock_asyncio_wait: AsyncMock,
283275
event_queue: EventQueue,
284276
) -> None:
285277
"""Test close behavior on Python < 3.13 (using queue.join)."""
@@ -290,24 +282,47 @@ async def test_close_sets_flag_and_handles_internal_queue_old_python(
290282
await event_queue.close()
291283

292284
assert event_queue.is_closed() is True
293-
event_queue.queue.join.assert_called_once() # specific to <3.13
294-
mock_create_task.assert_called_once() # create_task for join
295-
mock_asyncio_wait.assert_called_once() # wait for join
285+
event_queue.queue.join.assert_awaited_once() # waited for drain
296286

297287

298288
@pytest.mark.asyncio
299289
async def test_close_sets_flag_and_handles_internal_queue_new_python(
300290
event_queue: EventQueue,
301291
) -> None:
302292
"""Test close behavior on Python >= 3.13 (using queue.shutdown)."""
303-
with patch('sys.version_info', (3, 13, 0)): # Simulate Python 3.13+
304-
# Mock queue.shutdown as it's called in newer versions
305-
event_queue.queue.shutdown = MagicMock() # shutdown is not async
293+
with patch('sys.version_info', (3, 13, 0)):
294+
# Inject a dummy shutdown method for non-3.13 runtimes
295+
from typing import cast
306296

297+
queue = cast('Any', event_queue.queue)
298+
queue.shutdown = MagicMock() # type: ignore[attr-defined]
307299
await event_queue.close()
308-
309300
assert event_queue.is_closed() is True
310-
event_queue.queue.shutdown.assert_called_once() # specific to >=3.13
301+
queue.shutdown.assert_called_once_with(False)
302+
303+
304+
@pytest.mark.asyncio
305+
async def test_close_graceful_py313_waits_for_join_and_children(
306+
event_queue: EventQueue,
307+
) -> None:
308+
"""For Python >=3.13 and immediate=False, close should shutdown(False), then wait for join and children."""
309+
with patch('sys.version_info', (3, 13, 0)):
310+
# Arrange
311+
from typing import cast
312+
313+
q_any = cast('Any', event_queue.queue)
314+
q_any.shutdown = MagicMock() # type: ignore[attr-defined]
315+
event_queue.queue.join = AsyncMock()
316+
317+
child = event_queue.tap()
318+
child.close = AsyncMock()
319+
320+
# Act
321+
await event_queue.close(immediate=False)
322+
323+
# Assert
324+
event_queue.queue.join.assert_awaited_once()
325+
child.close.assert_awaited_once()
311326

312327

313328
@pytest.mark.asyncio
@@ -345,15 +360,18 @@ async def test_close_idempotent(event_queue: EventQueue) -> None:
345360

346361
# Reset for new Python version test
347362
event_queue_new = EventQueue() # New queue for fresh state
348-
with patch('sys.version_info', (3, 13, 0)): # Test with newer version logic
349-
event_queue_new.queue.shutdown = MagicMock()
363+
with patch('sys.version_info', (3, 13, 0)):
364+
from typing import cast
365+
366+
queue = cast('Any', event_queue_new.queue)
367+
queue.shutdown = MagicMock() # type: ignore[attr-defined]
350368
await event_queue_new.close()
351369
assert event_queue_new.is_closed() is True
352-
event_queue_new.queue.shutdown.assert_called_once()
370+
queue.shutdown.assert_called_once()
353371

354372
await event_queue_new.close()
355373
assert event_queue_new.is_closed() is True
356-
event_queue_new.queue.shutdown.assert_called_once() # Still only called once
374+
queue.shutdown.assert_called_once() # Still only called once
357375

358376

359377
@pytest.mark.asyncio

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import contextlib
23
import logging
34
import time
45

@@ -48,6 +49,7 @@
4849
TaskQueryParams,
4950
TaskState,
5051
TaskStatus,
52+
TaskStatusUpdateEvent,
5153
TextPart,
5254
UnsupportedOperationError,
5355
)
@@ -1331,6 +1333,15 @@ async def single_event_stream():
13311333
mock_result_aggregator_instance.consume_and_emit.return_value = (
13321334
single_event_stream()
13331335
)
1336+
# Signal when background consume_all is started
1337+
bg_started = asyncio.Event()
1338+
1339+
async def mock_consume_all(_consumer):
1340+
bg_started.set()
1341+
# emulate short-running background work
1342+
await asyncio.sleep(0)
1343+
1344+
mock_result_aggregator_instance.consume_all = mock_consume_all
13341345

13351346
produced_task: asyncio.Task | None = None
13361347
cleanup_task: asyncio.Task | None = None
@@ -1367,6 +1378,9 @@ def create_task_spy(coro):
13671378
assert produced_task is not None
13681379
assert cleanup_task is not None
13691380

1381+
# Assert background consume_all started
1382+
await asyncio.wait_for(bg_started.wait(), timeout=0.2)
1383+
13701384
# execute should have started
13711385
await asyncio.wait_for(execute_started.wait(), timeout=0.1)
13721386

@@ -1385,6 +1399,91 @@ def create_task_spy(coro):
13851399
# Running agents is cleared
13861400
assert task_id not in request_handler._running_agents
13871401

1402+
# Cleanup any lingering background tasks started by on_message_send_stream
1403+
# (e.g., background_consume)
1404+
for t in list(request_handler._background_tasks):
1405+
t.cancel()
1406+
with contextlib.suppress(asyncio.CancelledError):
1407+
await t
1408+
1409+
1410+
@pytest.mark.asyncio
1411+
async def test_disconnect_persists_final_task_to_store():
1412+
"""After client disconnect, ensure background consumer persists final Task to store."""
1413+
task_store = InMemoryTaskStore()
1414+
queue_manager = InMemoryQueueManager()
1415+
1416+
# Custom agent that emits a working update then a completed final update
1417+
class FinishingAgent(AgentExecutor):
1418+
def __init__(self):
1419+
self.allow_finish = asyncio.Event()
1420+
1421+
async def execute(
1422+
self, context: RequestContext, event_queue: EventQueue
1423+
):
1424+
from typing import cast
1425+
1426+
updater = TaskUpdater(
1427+
event_queue,
1428+
cast('str', context.task_id),
1429+
cast('str', context.context_id),
1430+
)
1431+
await updater.update_status(TaskState.working)
1432+
await self.allow_finish.wait()
1433+
await updater.update_status(TaskState.completed)
1434+
1435+
async def cancel(
1436+
self, context: RequestContext, event_queue: EventQueue
1437+
):
1438+
return None
1439+
1440+
agent = FinishingAgent()
1441+
1442+
handler = DefaultRequestHandler(
1443+
agent_executor=agent, task_store=task_store, queue_manager=queue_manager
1444+
)
1445+
1446+
params = MessageSendParams(
1447+
message=Message(
1448+
role=Role.user,
1449+
message_id='msg_persist',
1450+
parts=[],
1451+
)
1452+
)
1453+
1454+
# Start streaming and consume the first event (working)
1455+
agen = handler.on_message_send_stream(params, create_server_call_context())
1456+
first = await agen.__anext__()
1457+
if isinstance(first, TaskStatusUpdateEvent):
1458+
assert first.status.state == TaskState.working
1459+
task_id = first.task_id
1460+
else:
1461+
assert (
1462+
isinstance(first, Task) and first.status.state == TaskState.working
1463+
)
1464+
task_id = first.id
1465+
1466+
# Disconnect client
1467+
await asyncio.wait_for(agen.aclose(), timeout=0.1)
1468+
1469+
# Finish agent and allow background consumer to persist final state
1470+
agent.allow_finish.set()
1471+
1472+
# Wait until background_consume task for this task_id is gone
1473+
await wait_until(
1474+
lambda: all(
1475+
not t.get_name().startswith(f'background_consume:{task_id}')
1476+
for t in handler._background_tasks
1477+
),
1478+
timeout=1.0,
1479+
interval=0.01,
1480+
)
1481+
1482+
# Verify task is persisted as completed
1483+
persisted = await task_store.get(task_id, create_server_call_context())
1484+
assert persisted is not None
1485+
assert persisted.status.state == TaskState.completed
1486+
13881487

13891488
async def wait_until(predicate, timeout: float = 0.2, interval: float = 0.0):
13901489
"""Await until predicate() is True or timeout elapses."""
@@ -1505,6 +1604,12 @@ def create_task_spy(coro):
15051604
timeout=0.1,
15061605
)
15071606

1607+
# Cleanup any lingering background tasks
1608+
for t in list(request_handler._background_tasks):
1609+
t.cancel()
1610+
with contextlib.suppress(asyncio.CancelledError):
1611+
await t
1612+
15081613

15091614
@pytest.mark.asyncio
15101615
async def test_on_message_send_stream_task_id_mismatch():

0 commit comments

Comments
 (0)