Skip to content

Commit 1b9da65

Browse files
authored
Merge branch 'main' into fix/logging-and-string-formatting
2 parents ecd9b3c + c147a83 commit 1b9da65

File tree

7 files changed

+426
-23
lines changed

7 files changed

+426
-23
lines changed

src/a2a/server/events/event_consumer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ async def consume_all(self) -> AsyncGenerator[Event]:
125125
# other part is waiting for an event or a closed queue.
126126
if is_final_event:
127127
logger.debug('Stopping event consumption in consume_all.')
128-
await self.queue.close()
128+
await self.queue.close(True)
129129
yield event
130130
break
131131
yield event
@@ -135,7 +135,7 @@ async def consume_all(self) -> AsyncGenerator[Event]:
135135
except asyncio.TimeoutError: # pyright: ignore [reportUnusedExcept]
136136
# This class was made an alias of build-in TimeoutError after 3.11
137137
continue
138-
except QueueClosed:
138+
except (QueueClosed, asyncio.QueueEmpty):
139139
# Confirm that the queue is closed, e.g. we aren't on
140140
# python 3.12 and get a queue empty error on an open queue
141141
if self.queue.is_closed():

src/a2a/server/events/event_queue.py

Lines changed: 77 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,12 @@ async def dequeue_event(self, no_wait: bool = False) -> Event:
9090
asyncio.QueueShutDown: If the queue has been closed and is empty.
9191
"""
9292
async with self._lock:
93-
if self._is_closed and self.queue.empty():
93+
if (
94+
sys.version_info < (3, 13)
95+
and self._is_closed
96+
and self.queue.empty()
97+
):
98+
# On 3.13+, skip early raise; await self.queue.get() will raise QueueShutDown after shutdown()
9499
logger.warning('Queue is closed. Event will not be dequeued.')
95100
raise asyncio.QueueEmpty('Queue is closed.')
96101

@@ -127,25 +132,38 @@ def tap(self) -> 'EventQueue':
127132
self._children.append(queue)
128133
return queue
129134

130-
async def close(self) -> None:
131-
"""Closes the queue for future push events.
135+
async def close(self, immediate: bool = False) -> None:
136+
"""Closes the queue for future push events and also closes all child queues.
137+
138+
Once closed, no new events can be enqueued. For Python 3.13+, this will trigger
139+
`asyncio.QueueShutDown` when the queue is empty and a consumer tries to dequeue.
140+
For lower versions, the queue will be marked as closed and optionally cleared.
141+
142+
Args:
143+
immediate (bool):
144+
- True: Immediately closes the queue and clears all unprocessed events without waiting for them to be consumed. This is suitable for scenarios where you need to forcefully interrupt and quickly release resources.
145+
- False (default): Gracefully closes the queue, waiting for all queued events to be processed (i.e., the queue is drained) before closing. This is suitable when you want to ensure all events are handled.
132146
133-
Once closed, `dequeue_event` will eventually raise `asyncio.QueueShutDown`
134-
when the queue is empty. Also closes all child queues.
135147
"""
136148
logger.debug('Closing EventQueue.')
137149
async with self._lock:
138150
# If already closed, just return.
139-
if self._is_closed:
151+
if self._is_closed and not immediate:
140152
return
141-
self._is_closed = True
153+
if not self._is_closed:
154+
self._is_closed = True
142155
# If using python 3.13 or higher, use the shutdown method
143156
if sys.version_info >= (3, 13):
144-
self.queue.shutdown()
157+
self.queue.shutdown(immediate)
145158
for child in self._children:
146-
await child.close()
159+
await child.close(immediate)
147160
# Otherwise, join the queue
148161
else:
162+
if immediate:
163+
await self.clear_events(True)
164+
for child in self._children:
165+
await child.close(immediate)
166+
return
149167
tasks = [asyncio.create_task(self.queue.join())]
150168
tasks.extend(
151169
asyncio.create_task(child.close()) for child in self._children
@@ -155,3 +173,53 @@ async def close(self) -> None:
155173
def is_closed(self) -> bool:
156174
"""Checks if the queue is closed."""
157175
return self._is_closed
176+
177+
async def clear_events(self, clear_child_queues: bool = True) -> None:
178+
"""Clears all events from the current queue and optionally all child queues.
179+
180+
This method removes all pending events from the queue without processing them.
181+
Child queues can be optionally cleared based on the clear_child_queues parameter.
182+
183+
Args:
184+
clear_child_queues: If True (default), clear all child queues as well.
185+
If False, only clear the current queue, leaving child queues untouched.
186+
"""
187+
logger.debug('Clearing all events from EventQueue and child queues.')
188+
189+
# Clear all events from the queue, even if closed
190+
cleared_count = 0
191+
async with self._lock:
192+
try:
193+
while True:
194+
event = self.queue.get_nowait()
195+
logger.debug(
196+
f'Discarding unprocessed event of type: {type(event)}, content: {event}'
197+
)
198+
self.queue.task_done()
199+
cleared_count += 1
200+
except asyncio.QueueEmpty:
201+
pass
202+
except Exception as e:
203+
# Handle Python 3.13+ QueueShutDown
204+
if (
205+
sys.version_info >= (3, 13)
206+
and type(e).__name__ == 'QueueShutDown'
207+
):
208+
pass
209+
else:
210+
raise
211+
212+
if cleared_count > 0:
213+
logger.debug(
214+
f'Cleared {cleared_count} unprocessed events from EventQueue.'
215+
)
216+
217+
# Clear all child queues (lock released before awaiting child tasks)
218+
if clear_child_queues and self._children:
219+
child_tasks = [
220+
asyncio.create_task(child.clear_events())
221+
for child in self._children
222+
]
223+
224+
if child_tasks:
225+
await asyncio.gather(*child_tasks, return_exceptions=True)

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,11 +286,19 @@ async def on_message_send(
286286

287287
interrupted_or_non_blocking = False
288288
try:
289+
# Create async callback for push notifications
290+
async def push_notification_callback() -> None:
291+
await self._send_push_notification_if_needed(
292+
task_id, result_aggregator
293+
)
294+
289295
(
290296
result,
291297
interrupted_or_non_blocking,
292298
) = await result_aggregator.consume_and_break_on_interrupt(
293-
consumer, blocking=blocking
299+
consumer,
300+
blocking=blocking,
301+
event_callback=push_notification_callback,
294302
)
295303
if not result:
296304
raise ServerError(error=InternalError()) # noqa: TRY301

src/a2a/server/tasks/result_aggregator.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import logging
33

4-
from collections.abc import AsyncGenerator, AsyncIterator
4+
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
55

66
from a2a.server.events import Event, EventConsumer
77
from a2a.server.tasks.task_manager import TaskManager
@@ -24,7 +24,10 @@ class ResultAggregator:
2424
Task object and emit that Task object.
2525
"""
2626

27-
def __init__(self, task_manager: TaskManager):
27+
def __init__(
28+
self,
29+
task_manager: TaskManager,
30+
) -> None:
2831
"""Initializes the ResultAggregator.
2932
3033
Args:
@@ -92,7 +95,10 @@ async def consume_all(
9295
return await self.task_manager.get_task()
9396

9497
async def consume_and_break_on_interrupt(
95-
self, consumer: EventConsumer, blocking: bool = True
98+
self,
99+
consumer: EventConsumer,
100+
blocking: bool = True,
101+
event_callback: Callable[[], Awaitable[None]] | None = None,
96102
) -> tuple[Task | Message | None, bool]:
97103
"""Processes the event stream until completion or an interruptable state is encountered.
98104
@@ -105,6 +111,9 @@ async def consume_and_break_on_interrupt(
105111
consumer: The `EventConsumer` to read events from.
106112
blocking: If `False`, the method returns as soon as a task/message
107113
is available. If `True`, it waits for a terminal state.
114+
event_callback: Optional async callback function to be called after each event
115+
is processed in the background continuation.
116+
Mainly used for push notifications currently.
108117
109118
Returns:
110119
A tuple containing:
@@ -150,13 +159,17 @@ async def consume_and_break_on_interrupt(
150159
if should_interrupt:
151160
# Continue consuming the rest of the events in the background.
152161
# TODO: We should track all outstanding tasks to ensure they eventually complete.
153-
asyncio.create_task(self._continue_consuming(event_stream)) # noqa: RUF006
162+
asyncio.create_task( # noqa: RUF006
163+
self._continue_consuming(event_stream, event_callback)
164+
)
154165
interrupted = True
155166
break
156167
return await self.task_manager.get_task(), interrupted
157168

158169
async def _continue_consuming(
159-
self, event_stream: AsyncIterator[Event]
170+
self,
171+
event_stream: AsyncIterator[Event],
172+
event_callback: Callable[[], Awaitable[None]] | None = None,
160173
) -> None:
161174
"""Continues processing an event stream in a background task.
162175
@@ -165,6 +178,9 @@ async def _continue_consuming(
165178
166179
Args:
167180
event_stream: The remaining `AsyncIterator` of events from the consumer.
181+
event_callback: Optional async callback function to be called after each event is processed.
168182
"""
169183
async for event in event_stream:
170184
await self.task_manager.process(event)
185+
if event_callback:
186+
await event_callback()

tests/server/events/test_event_consumer.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,59 @@ async def test_consume_all_continues_on_queue_empty_if_not_really_closed(
324324
assert mock_event_queue.is_closed.call_count == 1
325325

326326

327+
@pytest.mark.asyncio
328+
async def test_consume_all_handles_queue_empty_when_closed_python_version_agnostic(
329+
event_consumer: EventConsumer, mock_event_queue: AsyncMock, monkeypatch
330+
):
331+
"""Ensure consume_all stops with no events when queue is closed and dequeue_event raises asyncio.QueueEmpty (Python version-agnostic)."""
332+
# Make QueueClosed a distinct exception (not QueueEmpty) to emulate py3.13 semantics
333+
from a2a.server.events import event_consumer as ec
334+
335+
class QueueShutDown(Exception):
336+
pass
337+
338+
monkeypatch.setattr(ec, 'QueueClosed', QueueShutDown, raising=True)
339+
340+
# Simulate queue reporting closed while dequeue raises QueueEmpty
341+
mock_event_queue.dequeue_event.side_effect = asyncio.QueueEmpty(
342+
'closed/empty'
343+
)
344+
mock_event_queue.is_closed.return_value = True
345+
346+
consumed_events = []
347+
async for event in event_consumer.consume_all():
348+
consumed_events.append(event)
349+
350+
assert consumed_events == []
351+
mock_event_queue.dequeue_event.assert_called_once()
352+
mock_event_queue.is_closed.assert_called_once()
353+
354+
355+
@pytest.mark.asyncio
356+
async def test_consume_all_continues_on_queue_empty_when_not_closed(
357+
event_consumer: EventConsumer, mock_event_queue: AsyncMock, monkeypatch
358+
):
359+
"""Ensure consume_all continues after asyncio.QueueEmpty when queue is open, yielding the next (final) event."""
360+
# First dequeue raises QueueEmpty (transient empty), then a final Message arrives
361+
final = Message(role='agent', parts=[{'text': 'done'}], message_id='final')
362+
mock_event_queue.dequeue_event.side_effect = [
363+
asyncio.QueueEmpty('temporarily empty'),
364+
final,
365+
]
366+
mock_event_queue.is_closed.return_value = False
367+
368+
# Make the polling responsive in tests
369+
event_consumer._timeout = 0.001
370+
371+
consumed = []
372+
async for ev in event_consumer.consume_all():
373+
consumed.append(ev)
374+
375+
assert consumed == [final]
376+
assert mock_event_queue.dequeue_event.call_count == 2
377+
mock_event_queue.is_closed.assert_called_once()
378+
379+
327380
def test_agent_task_callback_sets_exception(event_consumer: EventConsumer):
328381
"""Test that agent_task_callback sets _exception if the task had one."""
329382
mock_task = MagicMock(spec=asyncio.Task)

0 commit comments

Comments
 (0)