Skip to content

Commit 3e1e6f7

Browse files
authored
Merge branch 'main' into chore/improve-coverage-grpc-client
2 parents cfbd278 + db82a65 commit 3e1e6f7

File tree

7 files changed

+227
-14
lines changed

7 files changed

+227
-14
lines changed

src/a2a/server/events/event_consumer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ async def consume_all(self) -> AsyncGenerator[Event]:
135135
except asyncio.TimeoutError: # pyright: ignore [reportUnusedExcept]
136136
# This class was made an alias of build-in TimeoutError after 3.11
137137
continue
138-
except QueueClosed:
138+
except (QueueClosed, asyncio.QueueEmpty):
139139
# Confirm that the queue is closed, e.g. we aren't on
140140
# python 3.12 and get a queue empty error on an open queue
141141
if self.queue.is_closed():

src/a2a/server/events/event_queue.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,12 @@ async def dequeue_event(self, no_wait: bool = False) -> Event:
9090
asyncio.QueueShutDown: If the queue has been closed and is empty.
9191
"""
9292
async with self._lock:
93-
if self._is_closed and self.queue.empty():
93+
if (
94+
sys.version_info < (3, 13)
95+
and self._is_closed
96+
and self.queue.empty()
97+
):
98+
# On 3.13+, skip early raise; await self.queue.get() will raise QueueShutDown after shutdown()
9499
logger.warning('Queue is closed. Event will not be dequeued.')
95100
raise asyncio.QueueEmpty('Queue is closed.')
96101

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,11 +286,19 @@ async def on_message_send(
286286

287287
interrupted_or_non_blocking = False
288288
try:
289+
# Create async callback for push notifications
290+
async def push_notification_callback() -> None:
291+
await self._send_push_notification_if_needed(
292+
task_id, result_aggregator
293+
)
294+
289295
(
290296
result,
291297
interrupted_or_non_blocking,
292298
) = await result_aggregator.consume_and_break_on_interrupt(
293-
consumer, blocking=blocking
299+
consumer,
300+
blocking=blocking,
301+
event_callback=push_notification_callback,
294302
)
295303
if not result:
296304
raise ServerError(error=InternalError()) # noqa: TRY301

src/a2a/server/tasks/result_aggregator.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import logging
33

4-
from collections.abc import AsyncGenerator, AsyncIterator
4+
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
55

66
from a2a.server.events import Event, EventConsumer
77
from a2a.server.tasks.task_manager import TaskManager
@@ -24,7 +24,10 @@ class ResultAggregator:
2424
Task object and emit that Task object.
2525
"""
2626

27-
def __init__(self, task_manager: TaskManager):
27+
def __init__(
28+
self,
29+
task_manager: TaskManager,
30+
) -> None:
2831
"""Initializes the ResultAggregator.
2932
3033
Args:
@@ -92,7 +95,10 @@ async def consume_all(
9295
return await self.task_manager.get_task()
9396

9497
async def consume_and_break_on_interrupt(
95-
self, consumer: EventConsumer, blocking: bool = True
98+
self,
99+
consumer: EventConsumer,
100+
blocking: bool = True,
101+
event_callback: Callable[[], Awaitable[None]] | None = None,
96102
) -> tuple[Task | Message | None, bool]:
97103
"""Processes the event stream until completion or an interruptable state is encountered.
98104
@@ -105,6 +111,9 @@ async def consume_and_break_on_interrupt(
105111
consumer: The `EventConsumer` to read events from.
106112
blocking: If `False`, the method returns as soon as a task/message
107113
is available. If `True`, it waits for a terminal state.
114+
event_callback: Optional async callback function to be called after each event
115+
is processed in the background continuation.
116+
Mainly used for push notifications currently.
108117
109118
Returns:
110119
A tuple containing:
@@ -150,13 +159,17 @@ async def consume_and_break_on_interrupt(
150159
if should_interrupt:
151160
# Continue consuming the rest of the events in the background.
152161
# TODO: We should track all outstanding tasks to ensure they eventually complete.
153-
asyncio.create_task(self._continue_consuming(event_stream)) # noqa: RUF006
162+
asyncio.create_task( # noqa: RUF006
163+
self._continue_consuming(event_stream, event_callback)
164+
)
154165
interrupted = True
155166
break
156167
return await self.task_manager.get_task(), interrupted
157168

158169
async def _continue_consuming(
159-
self, event_stream: AsyncIterator[Event]
170+
self,
171+
event_stream: AsyncIterator[Event],
172+
event_callback: Callable[[], Awaitable[None]] | None = None,
160173
) -> None:
161174
"""Continues processing an event stream in a background task.
162175
@@ -165,6 +178,9 @@ async def _continue_consuming(
165178
166179
Args:
167180
event_stream: The remaining `AsyncIterator` of events from the consumer.
181+
event_callback: Optional async callback function to be called after each event is processed.
168182
"""
169183
async for event in event_stream:
170184
await self.task_manager.process(event)
185+
if event_callback:
186+
await event_callback()

tests/server/events/test_event_consumer.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,59 @@ async def test_consume_all_continues_on_queue_empty_if_not_really_closed(
324324
assert mock_event_queue.is_closed.call_count == 1
325325

326326

327+
@pytest.mark.asyncio
328+
async def test_consume_all_handles_queue_empty_when_closed_python_version_agnostic(
329+
event_consumer: EventConsumer, mock_event_queue: AsyncMock, monkeypatch
330+
):
331+
"""Ensure consume_all stops with no events when queue is closed and dequeue_event raises asyncio.QueueEmpty (Python version-agnostic)."""
332+
# Make QueueClosed a distinct exception (not QueueEmpty) to emulate py3.13 semantics
333+
from a2a.server.events import event_consumer as ec
334+
335+
class QueueShutDown(Exception):
336+
pass
337+
338+
monkeypatch.setattr(ec, 'QueueClosed', QueueShutDown, raising=True)
339+
340+
# Simulate queue reporting closed while dequeue raises QueueEmpty
341+
mock_event_queue.dequeue_event.side_effect = asyncio.QueueEmpty(
342+
'closed/empty'
343+
)
344+
mock_event_queue.is_closed.return_value = True
345+
346+
consumed_events = []
347+
async for event in event_consumer.consume_all():
348+
consumed_events.append(event)
349+
350+
assert consumed_events == []
351+
mock_event_queue.dequeue_event.assert_called_once()
352+
mock_event_queue.is_closed.assert_called_once()
353+
354+
355+
@pytest.mark.asyncio
356+
async def test_consume_all_continues_on_queue_empty_when_not_closed(
357+
event_consumer: EventConsumer, mock_event_queue: AsyncMock, monkeypatch
358+
):
359+
"""Ensure consume_all continues after asyncio.QueueEmpty when queue is open, yielding the next (final) event."""
360+
# First dequeue raises QueueEmpty (transient empty), then a final Message arrives
361+
final = Message(role='agent', parts=[{'text': 'done'}], message_id='final')
362+
mock_event_queue.dequeue_event.side_effect = [
363+
asyncio.QueueEmpty('temporarily empty'),
364+
final,
365+
]
366+
mock_event_queue.is_closed.return_value = False
367+
368+
# Make the polling responsive in tests
369+
event_consumer._timeout = 0.001
370+
371+
consumed = []
372+
async for ev in event_consumer.consume_all():
373+
consumed.append(ev)
374+
375+
assert consumed == [final]
376+
assert mock_event_queue.dequeue_event.call_count == 2
377+
mock_event_queue.is_closed.assert_called_once()
378+
379+
327380
def test_agent_task_callback_sets_exception(event_consumer: EventConsumer):
328381
"""Test that agent_task_callback sets _exception if the task had one."""
329382
mock_task = MagicMock(spec=asyncio.Task)

tests/server/events/test_event_queue.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,9 @@ async def test_enqueue_event_propagates_to_children(
169169

170170

171171
@pytest.mark.asyncio
172-
async def test_enqueue_event_when_closed(event_queue: EventQueue) -> None:
172+
async def test_enqueue_event_when_closed(
173+
event_queue: EventQueue, expected_queue_closed_exception
174+
) -> None:
173175
"""Test that no event is enqueued if the parent queue is closed."""
174176
await event_queue.close() # Close the queue first
175177

@@ -178,7 +180,7 @@ async def test_enqueue_event_when_closed(event_queue: EventQueue) -> None:
178180
await event_queue.enqueue_event(event)
179181

180182
# Verify the queue is still empty
181-
with pytest.raises(asyncio.QueueEmpty):
183+
with pytest.raises(expected_queue_closed_exception):
182184
await event_queue.dequeue_event(no_wait=True)
183185

184186
# Also verify child queues are not affected directly by parent's enqueue attempt when closed
@@ -192,7 +194,7 @@ async def test_enqueue_event_when_closed(event_queue: EventQueue) -> None:
192194
await (
193195
child_queue.close()
194196
) # ensure child is also seen as closed for this test's purpose
195-
with pytest.raises(asyncio.QueueEmpty):
197+
with pytest.raises(expected_queue_closed_exception):
196198
await child_queue.dequeue_event(no_wait=True)
197199

198200

@@ -214,7 +216,7 @@ async def test_dequeue_event_closed_and_empty_no_wait(
214216
with pytest.raises(expected_queue_closed_exception):
215217
event_queue.queue.get_nowait()
216218

217-
with pytest.raises(asyncio.QueueEmpty, match='Queue is closed.'):
219+
with pytest.raises(expected_queue_closed_exception):
218220
await event_queue.dequeue_event(no_wait=True)
219221

220222

@@ -230,7 +232,8 @@ async def test_dequeue_event_closed_and_empty_waits_then_raises(
230232

231233
# This test is tricky because await event_queue.dequeue_event() would hang if not for the close check.
232234
# The current implementation's dequeue_event checks `is_closed` first.
233-
# If closed and empty, it raises QueueEmpty immediately.
235+
# If closed and empty, it raises QueueEmpty immediately (on Python <= 3.12).
236+
# On Python 3.13+, this check is skipped and asyncio.Queue.get() raises QueueShutDown instead.
234237
# The "waits_then_raises" scenario described in the subtask implies the `get()` might wait.
235238
# However, the current code:
236239
# async with self._lock:
@@ -240,7 +243,7 @@ async def test_dequeue_event_closed_and_empty_waits_then_raises(
240243
# event = await self.queue.get() -> this line is not reached if closed and empty.
241244

242245
# So, for the current implementation, it will raise QueueEmpty immediately.
243-
with pytest.raises(asyncio.QueueEmpty, match='Queue is closed.'):
246+
with pytest.raises(expected_queue_closed_exception):
244247
await event_queue.dequeue_event(no_wait=False)
245248

246249
# If the implementation were to change to allow `await self.queue.get()`

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,134 @@ async def get_current_result():
405405
mock_agent_executor.execute.assert_awaited_once()
406406

407407

408+
@pytest.mark.asyncio
409+
async def test_on_message_send_with_push_notification_in_non_blocking_request():
410+
"""Test that push notification callback is called during background event processing for non-blocking requests."""
411+
mock_task_store = AsyncMock(spec=TaskStore)
412+
mock_push_notification_store = AsyncMock(spec=PushNotificationConfigStore)
413+
mock_agent_executor = AsyncMock(spec=AgentExecutor)
414+
mock_request_context_builder = AsyncMock(spec=RequestContextBuilder)
415+
mock_push_sender = AsyncMock()
416+
417+
task_id = 'non_blocking_task_1'
418+
context_id = 'non_blocking_ctx_1'
419+
420+
# Create a task that will be returned after the first event
421+
initial_task = create_sample_task(
422+
task_id=task_id, context_id=context_id, status_state=TaskState.working
423+
)
424+
425+
# Create a final task that will be available during background processing
426+
final_task = create_sample_task(
427+
task_id=task_id, context_id=context_id, status_state=TaskState.completed
428+
)
429+
430+
mock_task_store.get.return_value = None
431+
432+
# Mock request context
433+
mock_request_context = MagicMock(spec=RequestContext)
434+
mock_request_context.task_id = task_id
435+
mock_request_context.context_id = context_id
436+
mock_request_context_builder.build.return_value = mock_request_context
437+
438+
request_handler = DefaultRequestHandler(
439+
agent_executor=mock_agent_executor,
440+
task_store=mock_task_store,
441+
push_config_store=mock_push_notification_store,
442+
request_context_builder=mock_request_context_builder,
443+
push_sender=mock_push_sender,
444+
)
445+
446+
# Configure push notification
447+
push_config = PushNotificationConfig(url='http://callback.com/push')
448+
message_config = MessageSendConfiguration(
449+
push_notification_config=push_config,
450+
accepted_output_modes=['text/plain'],
451+
blocking=False, # Non-blocking request
452+
)
453+
params = MessageSendParams(
454+
message=Message(
455+
role=Role.user,
456+
message_id='msg_non_blocking',
457+
parts=[],
458+
task_id=task_id,
459+
context_id=context_id,
460+
),
461+
configuration=message_config,
462+
)
463+
464+
# Mock ResultAggregator with custom behavior
465+
mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator)
466+
467+
# First call returns the initial task and indicates interruption (non-blocking)
468+
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
469+
initial_task,
470+
True, # interrupted = True for non-blocking
471+
)
472+
473+
# Mock the current_result property to return the final task
474+
async def get_current_result():
475+
return final_task
476+
477+
type(mock_result_aggregator_instance).current_result = PropertyMock(
478+
return_value=get_current_result()
479+
)
480+
481+
# Track if the event_callback was passed to consume_and_break_on_interrupt
482+
event_callback_passed = False
483+
event_callback_received = None
484+
485+
async def mock_consume_and_break_on_interrupt(
486+
consumer, blocking=True, event_callback=None
487+
):
488+
nonlocal event_callback_passed, event_callback_received
489+
event_callback_passed = event_callback is not None
490+
event_callback_received = event_callback
491+
return initial_task, True # interrupted = True for non-blocking
492+
493+
mock_result_aggregator_instance.consume_and_break_on_interrupt = (
494+
mock_consume_and_break_on_interrupt
495+
)
496+
497+
with (
498+
patch(
499+
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
500+
return_value=mock_result_aggregator_instance,
501+
),
502+
patch(
503+
'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
504+
return_value=initial_task,
505+
),
506+
patch(
507+
'a2a.server.request_handlers.default_request_handler.TaskManager.update_with_message',
508+
return_value=initial_task,
509+
),
510+
):
511+
# Execute the non-blocking request
512+
result = await request_handler.on_message_send(
513+
params, create_server_call_context()
514+
)
515+
516+
# Verify the result is the initial task (non-blocking behavior)
517+
assert result == initial_task
518+
519+
# Verify that the event_callback was passed to consume_and_break_on_interrupt
520+
assert event_callback_passed, (
521+
'event_callback should have been passed to consume_and_break_on_interrupt'
522+
)
523+
assert event_callback_received is not None, (
524+
'event_callback should not be None'
525+
)
526+
527+
# Verify that the push notification was sent with the final task
528+
mock_push_sender.send_notification.assert_called_with(final_task)
529+
530+
# Verify that the push notification config was stored
531+
mock_push_notification_store.set_info.assert_awaited_once_with(
532+
task_id, push_config
533+
)
534+
535+
408536
@pytest.mark.asyncio
409537
async def test_on_message_send_with_push_notification_no_existing_Task():
410538
"""Test on_message_send for new task sets push notification info if provided."""

0 commit comments

Comments
 (0)