Skip to content

Commit efd9080

Browse files
authored
fix(stream): don't block event loop in EventQueue (#151)
# Description Fixes #111 🦕 > [!WARNING] > Backwards incompatible change, as `EventQueue.enqueue_event()` becomes async. The implementation of the `AgentExecutor` usually looks like this: ``` class SomeAgentExecutor(AgentExecutor): async def execute(self, context: RequestContext, event_queue: EventQueue) -> None: task_updater = TaskUpdater(event_queue, context.task_id, context.context_id) task_updater.start_work() async for event in self._run_async(...): if event.is_final_response(): task_updater.add_artifact(parts) task_updater.complete() break task_updater.update_status( TaskState.working, message=task_updater.new_agent_message(parts), ) ``` The issue is that the loop inside `execute()` calls `task_updater.update_status()` synchronously. If `update_status` is not an async function, then the `execute()` method does not yield control back to the event loop during each iteration. That’s why everything gets blocked until the loop finishes. So the first step is to make sure that `update_status()` and `event_queue.enqueue_event()` are async functions. Still, this is not enough, since `EventQueue` uses `put_nowait()`, so it never actually suspends. So the 2nd step is to switch from `queue.put_nowait()` to `queue.put()`. This is still not enough. By default, an `asyncio.Queue` is unbounded , so `queue.put` never actually suspends the producer. Therefore, the producer loop in `SomeAgentExecutor` still runs full-speed without yielding. So the 3rd step is to make the queue bounded. I have set the size to 1024, but it can be actually lower or higher if needed.
1 parent 98c200e commit efd9080

File tree

5 files changed

+323
-208
lines changed

5 files changed

+323
-208
lines changed

src/a2a/server/events/event_queue.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
Event = Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
1818
"""Type alias for events that can be enqueued."""
1919

20+
DEFAULT_MAX_QUEUE_SIZE = 1024
21+
2022

2123
@trace_class(kind=SpanKind.SERVER)
2224
class EventQueue:
@@ -27,27 +29,38 @@ class EventQueue:
2729
to create child queues that receive the same events.
2830
"""
2931

30-
def __init__(self) -> None:
32+
def __init__(self, max_queue_size=DEFAULT_MAX_QUEUE_SIZE) -> None:
3133
"""Initializes the EventQueue."""
32-
self.queue: asyncio.Queue[Event] = asyncio.Queue()
34+
35+
# Make sure the `asyncio.Queue` is bounded.
36+
# If it's unbounded (maxsize=0), then `queue.put()` never needs to wait,
37+
# and so the streaming won't work correctly.
38+
if max_queue_size <= 0:
39+
raise ValueError('max_queue_size must be greater than 0')
40+
41+
self.queue: asyncio.Queue[Event] = asyncio.Queue(maxsize=max_queue_size)
3342
self._children: list[EventQueue] = []
3443
self._is_closed = False
3544
self._lock = asyncio.Lock()
3645
logger.debug('EventQueue initialized.')
3746

38-
def enqueue_event(self, event: Event):
47+
async def enqueue_event(self, event: Event):
3948
"""Enqueues an event to this queue and all its children.
4049
4150
Args:
4251
event: The event object to enqueue.
4352
"""
44-
if self._is_closed:
45-
logger.warning('Queue is closed. Event will not be enqueued.')
46-
return
53+
async with self._lock:
54+
if self._is_closed:
55+
logger.warning('Queue is closed. Event will not be enqueued.')
56+
return
57+
4758
logger.debug(f'Enqueuing event of type: {type(event)}')
48-
self.queue.put_nowait(event)
59+
60+
# Make sure to use put instead of put_nowait to avoid blocking the event loop.
61+
await self.queue.put(event)
4962
for child in self._children:
50-
child.enqueue_event(event)
63+
await child.enqueue_event(event)
5164

5265
async def dequeue_event(self, no_wait: bool = False) -> Event:
5366
"""Dequeues an event from the queue.

src/a2a/server/tasks/task_updater.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(self, event_queue: EventQueue, task_id: str, context_id: str):
3434
self.task_id = task_id
3535
self.context_id = context_id
3636

37-
def update_status(
37+
async def update_status(
3838
self,
3939
state: TaskState,
4040
message: Message | None = None,
@@ -52,7 +52,7 @@ def update_status(
5252
current_timestamp = (
5353
timestamp if timestamp else datetime.now(timezone.utc).isoformat()
5454
)
55-
self.event_queue.enqueue_event(
55+
await self.event_queue.enqueue_event(
5656
TaskStatusUpdateEvent(
5757
taskId=self.task_id,
5858
contextId=self.context_id,
@@ -65,7 +65,7 @@ def update_status(
6565
)
6666
)
6767

68-
def add_artifact(
68+
async def add_artifact(
6969
self,
7070
parts: list[Part],
7171
artifact_id: str = str(uuid.uuid4()),
@@ -82,7 +82,7 @@ def add_artifact(
8282
append: Optional boolean indicating if this chunk appends to a previous one.
8383
last_chunk: Optional boolean indicating if this is the last chunk.
8484
"""
85-
self.event_queue.enqueue_event(
85+
await self.event_queue.enqueue_event(
8686
TaskArtifactUpdateEvent(
8787
taskId=self.task_id,
8888
contextId=self.context_id,
@@ -95,32 +95,32 @@ def add_artifact(
9595
)
9696
)
9797

98-
def complete(self, message: Message | None = None):
98+
async def complete(self, message: Message | None = None):
9999
"""Marks the task as completed and publishes a final status update."""
100-
self.update_status(
100+
await self.update_status(
101101
TaskState.completed,
102102
message=message,
103103
final=True,
104104
)
105105

106-
def failed(self, message: Message | None = None):
106+
async def failed(self, message: Message | None = None):
107107
"""Marks the task as failed and publishes a final status update."""
108-
self.update_status(TaskState.failed, message=message, final=True)
108+
await self.update_status(TaskState.failed, message=message, final=True)
109109

110-
def reject(self, message: Message | None = None):
110+
async def reject(self, message: Message | None = None):
111111
"""Marks the task as rejected and publishes a final status update."""
112-
self.update_status(TaskState.rejected, message=message, final=True)
112+
await self.update_status(TaskState.rejected, message=message, final=True)
113113

114-
def submit(self, message: Message | None = None):
114+
async def submit(self, message: Message | None = None):
115115
"""Marks the task as submitted and publishes a status update."""
116-
self.update_status(
116+
await self.update_status(
117117
TaskState.submitted,
118118
message=message,
119119
)
120120

121-
def start_work(self, message: Message | None = None):
121+
async def start_work(self, message: Message | None = None):
122122
"""Marks the task as working and publishes a status update."""
123-
self.update_status(
123+
await self.update_status(
124124
TaskState.working,
125125
message=message,
126126
)

tests/server/events/test_event_queue.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def event_queue() -> EventQueue:
3939
async def test_enqueue_and_dequeue_event(event_queue: EventQueue) -> None:
4040
"""Test that an event can be enqueued and dequeued."""
4141
event = Message(**MESSAGE_PAYLOAD)
42-
event_queue.enqueue_event(event)
42+
await event_queue.enqueue_event(event)
4343
dequeued_event = await event_queue.dequeue_event()
4444
assert dequeued_event == event
4545

@@ -48,7 +48,7 @@ async def test_enqueue_and_dequeue_event(event_queue: EventQueue) -> None:
4848
async def test_dequeue_event_no_wait(event_queue: EventQueue) -> None:
4949
"""Test dequeue_event with no_wait=True."""
5050
event = Task(**MINIMAL_TASK)
51-
event_queue.enqueue_event(event)
51+
await event_queue.enqueue_event(event)
5252
dequeued_event = await event_queue.dequeue_event(no_wait=True)
5353
assert dequeued_event == event
5454

@@ -71,7 +71,7 @@ async def test_dequeue_event_wait(event_queue: EventQueue) -> None:
7171
status=TaskStatus(state=TaskState.working),
7272
final=True,
7373
)
74-
event_queue.enqueue_event(event)
74+
await event_queue.enqueue_event(event)
7575
dequeued_event = await event_queue.dequeue_event()
7676
assert dequeued_event == event
7777

@@ -84,7 +84,7 @@ async def test_task_done(event_queue: EventQueue) -> None:
8484
contextId='session-xyz',
8585
artifact=Artifact(artifactId='11', parts=[Part(TextPart(text='text'))]),
8686
)
87-
event_queue.enqueue_event(event)
87+
await event_queue.enqueue_event(event)
8888
_ = await event_queue.dequeue_event()
8989
event_queue.task_done()
9090

@@ -99,6 +99,6 @@ async def test_enqueue_different_event_types(
9999
JSONRPCError(code=111, message='rpc error'),
100100
]
101101
for event in events:
102-
event_queue.enqueue_event(event)
102+
await event_queue.enqueue_event(event)
103103
dequeued_event = await event_queue.dequeue_event()
104104
assert dequeued_event == event
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import time
2+
3+
import pytest
4+
5+
from a2a.server.agent_execution import AgentExecutor, RequestContext
6+
from a2a.server.events import EventQueue
7+
from a2a.server.request_handlers import DefaultRequestHandler
8+
from a2a.server.tasks import InMemoryTaskStore, TaskUpdater
9+
from a2a.types import (
10+
Message,
11+
MessageSendParams,
12+
Part,
13+
Role,
14+
TaskState,
15+
TextPart,
16+
)
17+
18+
19+
class DummyAgentExecutor(AgentExecutor):
20+
async def execute(self, context: RequestContext, event_queue: EventQueue):
21+
task_updater = TaskUpdater(
22+
event_queue, context.task_id, context.context_id
23+
)
24+
async for i in self._run():
25+
parts = [Part(root=TextPart(text=f'Event {i}'))]
26+
try:
27+
await task_updater.update_status(
28+
TaskState.working,
29+
message=task_updater.new_agent_message(parts),
30+
)
31+
except RuntimeError:
32+
# Stop processing when the event loop is closed
33+
break
34+
35+
async def _run(self):
36+
for i in range(1_000_000): # Simulate a long-running stream
37+
yield i
38+
39+
async def cancel(self, context: RequestContext, event_queue: EventQueue):
40+
pass
41+
42+
43+
@pytest.mark.asyncio
44+
async def test_on_message_send_stream():
45+
request_handler = DefaultRequestHandler(
46+
DummyAgentExecutor(), InMemoryTaskStore()
47+
)
48+
message_params = MessageSendParams(
49+
message=Message(
50+
role=Role.user,
51+
messageId='msg-123',
52+
parts=[Part(root=TextPart(text='How are you?'))],
53+
),
54+
)
55+
56+
async def consume_stream():
57+
events = []
58+
async for event in request_handler.on_message_send_stream(
59+
message_params
60+
):
61+
events.append(event)
62+
if len(events) >= 3:
63+
break # Stop after a few events
64+
65+
return events
66+
67+
# Consume first 3 events from the stream and measure time
68+
start = time.perf_counter()
69+
events = await consume_stream()
70+
elapsed = time.perf_counter() - start
71+
72+
# Assert we received events quickly
73+
assert len(events) == 3
74+
assert elapsed < 0.5
75+
76+
texts = [p.root.text for e in events for p in e.status.message.parts]
77+
assert texts == ['Event 0', 'Event 1', 'Event 2']

0 commit comments

Comments
 (0)