Skip to content

Commit d2cbb05

Browse files
Merge branch 'main' into add-fastapi-app
2 parents 920f97f + 5e7d418 commit d2cbb05

File tree

9 files changed

+329
-217
lines changed

9 files changed

+329
-217
lines changed

.github/workflows/update-a2a-types.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ jobs:
6868
token: ${{ secrets.A2A_BOT_PAT }}
6969
committer: "a2a-bot <[email protected]>"
7070
author: "a2a-bot <[email protected]>"
71-
commit-message: "${{github.event.client_payload.message}} 🤖"
72-
title: "${{github.event.client_payload.message}} 🤖"
71+
commit-message: "chore: Update A2A types from specification 🤖"
72+
title: "chore: Update A2A types from specification 🤖"
7373
body: |
7474
This PR updates `src/a2a/types.py` based on the latest `specification/json/a2a.json` from [google-a2a/A2A](https://github.com/google-a2a/A2A/commit/${{ github.event.client_payload.sha }}).
7575
branch: "auto-update-a2a-types-${{ github.event.client_payload.sha }}"

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
<h2 align="center">
1111
<img src="https://raw.githubusercontent.com/google-a2a/A2A/refs/heads/main/docs/assets/a2a-logo-black.svg" width="256" alt="A2A Logo"/>
1212
</h2>
13-
<h3 align="center">A Python library that helps run agentic applications as A2AServers following the <a href="https://google.github.io/A2A">Agent2Agent (A2A) Protocol</a>.</h3>
13+
<h3 align="center">A Python library that helps run agentic applications as A2AServers following the <a href="https://google-a2a.github.io/A2A">Agent2Agent (A2A) Protocol</a>.</h3>
1414
</html>
1515

1616
<!-- markdownlint-enable no-inline-html -->

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ classifiers = [
3232
]
3333

3434
[project.urls]
35-
homepage = "https://google.github.io/A2A/"
35+
homepage = "https://google-a2a.github.io/A2A/"
3636
repository = "https://github.com/google-a2a/a2a-python"
3737
changelog = "https://github.com/google-a2a/a2a-python/blob/main/CHANGELOG.md"
38-
documentation = "https://google.github.io/A2A/"
38+
documentation = "https://google-a2a.github.io/A2A/sdk/python/"
3939

4040
[tool.hatch.build.targets.wheel]
4141
packages = ["src/a2a"]

src/a2a/server/events/event_queue.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
Event = Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
1818
"""Type alias for events that can be enqueued."""
1919

20+
DEFAULT_MAX_QUEUE_SIZE = 1024
21+
2022

2123
@trace_class(kind=SpanKind.SERVER)
2224
class EventQueue:
@@ -27,27 +29,37 @@ class EventQueue:
2729
to create child queues that receive the same events.
2830
"""
2931

30-
def __init__(self) -> None:
32+
def __init__(self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE) -> None:
3133
"""Initializes the EventQueue."""
32-
self.queue: asyncio.Queue[Event] = asyncio.Queue()
34+
# Make sure the `asyncio.Queue` is bounded.
35+
# If it's unbounded (maxsize=0), then `queue.put()` never needs to wait,
36+
# and so the streaming won't work correctly.
37+
if max_queue_size <= 0:
38+
raise ValueError('max_queue_size must be greater than 0')
39+
40+
self.queue: asyncio.Queue[Event] = asyncio.Queue(maxsize=max_queue_size)
3341
self._children: list[EventQueue] = []
3442
self._is_closed = False
3543
self._lock = asyncio.Lock()
3644
logger.debug('EventQueue initialized.')
3745

38-
def enqueue_event(self, event: Event):
46+
async def enqueue_event(self, event: Event):
3947
"""Enqueues an event to this queue and all its children.
4048
4149
Args:
4250
event: The event object to enqueue.
4351
"""
44-
if self._is_closed:
45-
logger.warning('Queue is closed. Event will not be enqueued.')
46-
return
52+
async with self._lock:
53+
if self._is_closed:
54+
logger.warning('Queue is closed. Event will not be enqueued.')
55+
return
56+
4757
logger.debug(f'Enqueuing event of type: {type(event)}')
48-
self.queue.put_nowait(event)
58+
59+
# Make sure to use put instead of put_nowait to avoid blocking the event loop.
60+
await self.queue.put(event)
4961
for child in self._children:
50-
child.enqueue_event(event)
62+
await child.enqueue_event(event)
5163

5264
async def dequeue_event(self, no_wait: bool = False) -> Event:
5365
"""Dequeues an event from the queue.

src/a2a/server/tasks/task_updater.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(self, event_queue: EventQueue, task_id: str, context_id: str):
3434
self.task_id = task_id
3535
self.context_id = context_id
3636

37-
def update_status(
37+
async def update_status(
3838
self,
3939
state: TaskState,
4040
message: Message | None = None,
@@ -52,7 +52,7 @@ def update_status(
5252
current_timestamp = (
5353
timestamp if timestamp else datetime.now(timezone.utc).isoformat()
5454
)
55-
self.event_queue.enqueue_event(
55+
await self.event_queue.enqueue_event(
5656
TaskStatusUpdateEvent(
5757
taskId=self.task_id,
5858
contextId=self.context_id,
@@ -65,7 +65,7 @@ def update_status(
6565
)
6666
)
6767

68-
def add_artifact(
68+
async def add_artifact(
6969
self,
7070
parts: list[Part],
7171
artifact_id: str = str(uuid.uuid4()),
@@ -82,7 +82,7 @@ def add_artifact(
8282
append: Optional boolean indicating if this chunk appends to a previous one.
8383
last_chunk: Optional boolean indicating if this is the last chunk.
8484
"""
85-
self.event_queue.enqueue_event(
85+
await self.event_queue.enqueue_event(
8686
TaskArtifactUpdateEvent(
8787
taskId=self.task_id,
8888
contextId=self.context_id,
@@ -95,32 +95,34 @@ def add_artifact(
9595
)
9696
)
9797

98-
def complete(self, message: Message | None = None):
98+
async def complete(self, message: Message | None = None):
9999
"""Marks the task as completed and publishes a final status update."""
100-
self.update_status(
100+
await self.update_status(
101101
TaskState.completed,
102102
message=message,
103103
final=True,
104104
)
105105

106-
def failed(self, message: Message | None = None):
106+
async def failed(self, message: Message | None = None):
107107
"""Marks the task as failed and publishes a final status update."""
108-
self.update_status(TaskState.failed, message=message, final=True)
108+
await self.update_status(TaskState.failed, message=message, final=True)
109109

110-
def reject(self, message: Message | None = None):
110+
async def reject(self, message: Message | None = None):
111111
"""Marks the task as rejected and publishes a final status update."""
112-
self.update_status(TaskState.rejected, message=message, final=True)
112+
await self.update_status(
113+
TaskState.rejected, message=message, final=True
114+
)
113115

114-
def submit(self, message: Message | None = None):
116+
async def submit(self, message: Message | None = None):
115117
"""Marks the task as submitted and publishes a status update."""
116-
self.update_status(
118+
await self.update_status(
117119
TaskState.submitted,
118120
message=message,
119121
)
120122

121-
def start_work(self, message: Message | None = None):
123+
async def start_work(self, message: Message | None = None):
122124
"""Marks the task as working and publishes a status update."""
123-
self.update_status(
125+
await self.update_status(
124126
TaskState.working,
125127
message=message,
126128
)

src/a2a/types.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,15 +100,15 @@ class AgentSkill(BaseModel):
100100
"""
101101
The set of interaction modes that the skill supports
102102
(if different than the default).
103-
Supported mime types for input.
103+
Supported media types for input.
104104
"""
105105
name: str
106106
"""
107107
Human readable name of the skill.
108108
"""
109109
outputModes: list[str] | None = None
110110
"""
111-
Supported mime types for output.
111+
Supported media types for output.
112112
"""
113113
tags: list[str]
114114
"""
@@ -1372,11 +1372,11 @@ class AgentCard(BaseModel):
13721372
defaultInputModes: list[str]
13731373
"""
13741374
The set of interaction modes that the agent supports across all skills. This can be overridden per-skill.
1375-
Supported mime types for input.
1375+
Supported media types for input.
13761376
"""
13771377
defaultOutputModes: list[str]
13781378
"""
1379-
Supported mime types for output.
1379+
Supported media types for output.
13801380
"""
13811381
description: str
13821382
"""

tests/server/events/test_event_queue.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def event_queue() -> EventQueue:
3939
async def test_enqueue_and_dequeue_event(event_queue: EventQueue) -> None:
4040
"""Test that an event can be enqueued and dequeued."""
4141
event = Message(**MESSAGE_PAYLOAD)
42-
event_queue.enqueue_event(event)
42+
await event_queue.enqueue_event(event)
4343
dequeued_event = await event_queue.dequeue_event()
4444
assert dequeued_event == event
4545

@@ -48,7 +48,7 @@ async def test_enqueue_and_dequeue_event(event_queue: EventQueue) -> None:
4848
async def test_dequeue_event_no_wait(event_queue: EventQueue) -> None:
4949
"""Test dequeue_event with no_wait=True."""
5050
event = Task(**MINIMAL_TASK)
51-
event_queue.enqueue_event(event)
51+
await event_queue.enqueue_event(event)
5252
dequeued_event = await event_queue.dequeue_event(no_wait=True)
5353
assert dequeued_event == event
5454

@@ -71,7 +71,7 @@ async def test_dequeue_event_wait(event_queue: EventQueue) -> None:
7171
status=TaskStatus(state=TaskState.working),
7272
final=True,
7373
)
74-
event_queue.enqueue_event(event)
74+
await event_queue.enqueue_event(event)
7575
dequeued_event = await event_queue.dequeue_event()
7676
assert dequeued_event == event
7777

@@ -84,7 +84,7 @@ async def test_task_done(event_queue: EventQueue) -> None:
8484
contextId='session-xyz',
8585
artifact=Artifact(artifactId='11', parts=[Part(TextPart(text='text'))]),
8686
)
87-
event_queue.enqueue_event(event)
87+
await event_queue.enqueue_event(event)
8888
_ = await event_queue.dequeue_event()
8989
event_queue.task_done()
9090

@@ -99,6 +99,6 @@ async def test_enqueue_different_event_types(
9999
JSONRPCError(code=111, message='rpc error'),
100100
]
101101
for event in events:
102-
event_queue.enqueue_event(event)
102+
await event_queue.enqueue_event(event)
103103
dequeued_event = await event_queue.dequeue_event()
104104
assert dequeued_event == event
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import time
2+
3+
import pytest
4+
5+
from a2a.server.agent_execution import AgentExecutor, RequestContext
6+
from a2a.server.events import EventQueue
7+
from a2a.server.request_handlers import DefaultRequestHandler
8+
from a2a.server.tasks import InMemoryTaskStore, TaskUpdater
9+
from a2a.types import (
10+
Message,
11+
MessageSendParams,
12+
Part,
13+
Role,
14+
TaskState,
15+
TextPart,
16+
)
17+
18+
19+
class DummyAgentExecutor(AgentExecutor):
20+
async def execute(self, context: RequestContext, event_queue: EventQueue):
21+
task_updater = TaskUpdater(
22+
event_queue, context.task_id, context.context_id
23+
)
24+
async for i in self._run():
25+
parts = [Part(root=TextPart(text=f'Event {i}'))]
26+
try:
27+
await task_updater.update_status(
28+
TaskState.working,
29+
message=task_updater.new_agent_message(parts),
30+
)
31+
except RuntimeError:
32+
# Stop processing when the event loop is closed
33+
break
34+
35+
async def _run(self):
36+
for i in range(1_000_000): # Simulate a long-running stream
37+
yield i
38+
39+
async def cancel(self, context: RequestContext, event_queue: EventQueue):
40+
pass
41+
42+
43+
@pytest.mark.asyncio
44+
async def test_on_message_send_stream():
45+
request_handler = DefaultRequestHandler(
46+
DummyAgentExecutor(), InMemoryTaskStore()
47+
)
48+
message_params = MessageSendParams(
49+
message=Message(
50+
role=Role.user,
51+
messageId='msg-123',
52+
parts=[Part(root=TextPart(text='How are you?'))],
53+
),
54+
)
55+
56+
async def consume_stream():
57+
events = []
58+
async for event in request_handler.on_message_send_stream(
59+
message_params
60+
):
61+
events.append(event)
62+
if len(events) >= 3:
63+
break # Stop after a few events
64+
65+
return events
66+
67+
# Consume first 3 events from the stream and measure time
68+
start = time.perf_counter()
69+
events = await consume_stream()
70+
elapsed = time.perf_counter() - start
71+
72+
# Assert we received events quickly
73+
assert len(events) == 3
74+
assert elapsed < 0.5
75+
76+
texts = [p.root.text for e in events for p in e.status.message.parts]
77+
assert texts == ['Event 0', 'Event 1', 'Event 2']

0 commit comments

Comments
 (0)