Skip to content

Commit 2d7bcba

Browse files
committed
fix(stream): don't block event loop in EventQueue
1 parent 1107151 commit 2d7bcba

File tree

5 files changed

+323
-208
lines changed

5 files changed

+323
-208
lines changed

src/a2a/server/events/event_queue.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
Event = Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
1818
"""Type alias for events that can be enqueued."""
1919

20+
DEFAULT_MAX_QUEUE_SIZE = 1024
21+
2022

2123
@trace_class(kind=SpanKind.SERVER)
2224
class EventQueue:
@@ -27,27 +29,38 @@ class EventQueue:
2729
to create child queues that receive the same events.
2830
"""
2931

30-
def __init__(self) -> None:
32+
def __init__(self, max_queue_size=DEFAULT_MAX_QUEUE_SIZE) -> None:
3133
"""Initializes the EventQueue."""
32-
self.queue: asyncio.Queue[Event] = asyncio.Queue()
34+
35+
# Make sure the `asyncio.Queue` is bounded.
36+
# If it's unbounded (maxsize=0), then `queue.put()` never needs to wait,
37+
# and so the streaming won't work correctly.
38+
if max_queue_size <= 0:
39+
raise ValueError('max_queue_size must be greater than 0')
40+
41+
self.queue: asyncio.Queue[Event] = asyncio.Queue(maxsize=max_queue_size)
3342
self._children: list[EventQueue] = []
3443
self._is_closed = False
3544
self._lock = asyncio.Lock()
3645
logger.debug('EventQueue initialized.')
3746

38-
def enqueue_event(self, event: Event):
47+
async def enqueue_event(self, event: Event):
3948
"""Enqueues an event to this queue and all its children.
4049
4150
Args:
4251
event: The event object to enqueue.
4352
"""
44-
if self._is_closed:
45-
logger.warning('Queue is closed. Event will not be enqueued.')
46-
return
53+
async with self._lock:
54+
if self._is_closed:
55+
logger.warning('Queue is closed. Event will not be enqueued.')
56+
return
57+
4758
logger.debug(f'Enqueuing event of type: {type(event)}')
48-
self.queue.put_nowait(event)
59+
60+
# Make sure to use put instead of put_nowait to avoid blocking the event loop.
61+
await self.queue.put(event)
4962
for child in self._children:
50-
child.enqueue_event(event)
63+
await child.enqueue_event(event)
5164

5265
async def dequeue_event(self, no_wait: bool = False) -> Event:
5366
"""Dequeues an event from the queue.

src/a2a/server/tasks/task_updater.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(self, event_queue: EventQueue, task_id: str, context_id: str):
3434
self.task_id = task_id
3535
self.context_id = context_id
3636

37-
def update_status(
37+
async def update_status(
3838
self,
3939
state: TaskState,
4040
message: Message | None = None,
@@ -52,7 +52,7 @@ def update_status(
5252
current_timestamp = (
5353
timestamp if timestamp else datetime.now(timezone.utc).isoformat()
5454
)
55-
self.event_queue.enqueue_event(
55+
await self.event_queue.enqueue_event(
5656
TaskStatusUpdateEvent(
5757
taskId=self.task_id,
5858
contextId=self.context_id,
@@ -65,7 +65,7 @@ def update_status(
6565
)
6666
)
6767

68-
def add_artifact(
68+
async def add_artifact(
6969
self,
7070
parts: list[Part],
7171
artifact_id: str = str(uuid.uuid4()),
@@ -82,7 +82,7 @@ def add_artifact(
8282
append: Optional boolean indicating if this chunk appends to a previous one.
8383
last_chunk: Optional boolean indicating if this is the last chunk.
8484
"""
85-
self.event_queue.enqueue_event(
85+
await self.event_queue.enqueue_event(
8686
TaskArtifactUpdateEvent(
8787
taskId=self.task_id,
8888
contextId=self.context_id,
@@ -95,32 +95,32 @@ def add_artifact(
9595
)
9696
)
9797

98-
def complete(self, message: Message | None = None):
98+
async def complete(self, message: Message | None = None):
9999
"""Marks the task as completed and publishes a final status update."""
100-
self.update_status(
100+
await self.update_status(
101101
TaskState.completed,
102102
message=message,
103103
final=True,
104104
)
105105

106-
def failed(self, message: Message | None = None):
106+
async def failed(self, message: Message | None = None):
107107
"""Marks the task as failed and publishes a final status update."""
108-
self.update_status(TaskState.failed, message=message, final=True)
108+
await self.update_status(TaskState.failed, message=message, final=True)
109109

110-
def reject(self, message: Message | None = None):
110+
async def reject(self, message: Message | None = None):
111111
"""Marks the task as rejected and publishes a final status update."""
112-
self.update_status(TaskState.rejected, message=message, final=True)
112+
await self.update_status(TaskState.rejected, message=message, final=True)
113113

114-
def submit(self, message: Message | None = None):
114+
async def submit(self, message: Message | None = None):
115115
"""Marks the task as submitted and publishes a status update."""
116-
self.update_status(
116+
await self.update_status(
117117
TaskState.submitted,
118118
message=message,
119119
)
120120

121-
def start_work(self, message: Message | None = None):
121+
async def start_work(self, message: Message | None = None):
122122
"""Marks the task as working and publishes a status update."""
123-
self.update_status(
123+
await self.update_status(
124124
TaskState.working,
125125
message=message,
126126
)

tests/server/events/test_event_queue.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def event_queue() -> EventQueue:
3939
async def test_enqueue_and_dequeue_event(event_queue: EventQueue) -> None:
4040
"""Test that an event can be enqueued and dequeued."""
4141
event = Message(**MESSAGE_PAYLOAD)
42-
event_queue.enqueue_event(event)
42+
await event_queue.enqueue_event(event)
4343
dequeued_event = await event_queue.dequeue_event()
4444
assert dequeued_event == event
4545

@@ -48,7 +48,7 @@ async def test_enqueue_and_dequeue_event(event_queue: EventQueue) -> None:
4848
async def test_dequeue_event_no_wait(event_queue: EventQueue) -> None:
4949
"""Test dequeue_event with no_wait=True."""
5050
event = Task(**MINIMAL_TASK)
51-
event_queue.enqueue_event(event)
51+
await event_queue.enqueue_event(event)
5252
dequeued_event = await event_queue.dequeue_event(no_wait=True)
5353
assert dequeued_event == event
5454

@@ -71,7 +71,7 @@ async def test_dequeue_event_wait(event_queue: EventQueue) -> None:
7171
status=TaskStatus(state=TaskState.working),
7272
final=True,
7373
)
74-
event_queue.enqueue_event(event)
74+
await event_queue.enqueue_event(event)
7575
dequeued_event = await event_queue.dequeue_event()
7676
assert dequeued_event == event
7777

@@ -84,7 +84,7 @@ async def test_task_done(event_queue: EventQueue) -> None:
8484
contextId='session-xyz',
8585
artifact=Artifact(artifactId='11', parts=[Part(TextPart(text='text'))]),
8686
)
87-
event_queue.enqueue_event(event)
87+
await event_queue.enqueue_event(event)
8888
_ = await event_queue.dequeue_event()
8989
event_queue.task_done()
9090

@@ -99,6 +99,6 @@ async def test_enqueue_different_event_types(
9999
JSONRPCError(code=111, message='rpc error'),
100100
]
101101
for event in events:
102-
event_queue.enqueue_event(event)
102+
await event_queue.enqueue_event(event)
103103
dequeued_event = await event_queue.dequeue_event()
104104
assert dequeued_event == event
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import time
2+
3+
import pytest
4+
5+
from a2a.server.agent_execution import AgentExecutor, RequestContext
6+
from a2a.server.events import EventQueue
7+
from a2a.server.request_handlers import DefaultRequestHandler
8+
from a2a.server.tasks import InMemoryTaskStore, TaskUpdater
9+
from a2a.types import (
10+
Message,
11+
MessageSendParams,
12+
Part,
13+
Role,
14+
TaskState,
15+
TextPart,
16+
)
17+
18+
19+
class DummyAgentExecutor(AgentExecutor):
20+
async def execute(self, context: RequestContext, event_queue: EventQueue):
21+
task_updater = TaskUpdater(
22+
event_queue, context.task_id, context.context_id
23+
)
24+
async for i in self._run():
25+
parts = [Part(root=TextPart(text=f'Event {i}'))]
26+
try:
27+
await task_updater.update_status(
28+
TaskState.working,
29+
message=task_updater.new_agent_message(parts),
30+
)
31+
except RuntimeError:
32+
# Stop processing when the event loop is closed
33+
break
34+
35+
async def _run(self):
36+
for i in range(1_000_000): # Simulate a long-running stream
37+
yield i
38+
39+
async def cancel(self, context: RequestContext, event_queue: EventQueue):
40+
pass
41+
42+
43+
@pytest.mark.asyncio
44+
async def test_on_message_send_stream():
45+
request_handler = DefaultRequestHandler(
46+
DummyAgentExecutor(), InMemoryTaskStore()
47+
)
48+
message_params = MessageSendParams(
49+
message=Message(
50+
role=Role.user,
51+
messageId='msg-123',
52+
parts=[Part(root=TextPart(text='How are you?'))],
53+
),
54+
)
55+
56+
async def consume_stream():
57+
events = []
58+
async for event in request_handler.on_message_send_stream(
59+
message_params
60+
):
61+
events.append(event)
62+
if len(events) >= 3:
63+
break # Stop after a few events
64+
65+
return events
66+
67+
# Consume first 3 events from the stream and measure time
68+
start = time.perf_counter()
69+
events = await consume_stream()
70+
elapsed = time.perf_counter() - start
71+
72+
# Assert we received events quickly
73+
assert len(events) == 3
74+
assert elapsed < 0.5
75+
76+
texts = [p.root.text for e in events for p in e.status.message.parts]
77+
assert texts == ['Event 0', 'Event 1', 'Event 2']

0 commit comments

Comments
 (0)