Skip to content

Commit abf68d0

Browse files
pstephengoogleGerrit Code Review
authored andcommitted
Merge changes I78b86be0,I50f02088 into main
* changes: Run `uv run ruff format` from repo root Turn Message -> str into a resuable utility method
2 parents cbbb7bb + 946381c commit abf68d0

File tree

10 files changed

+91
-23
lines changed

10 files changed

+91
-23
lines changed

examples/helloworld/__main__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@
3737
task_store=InMemoryTaskStore(),
3838
)
3939

40-
server = A2AStarletteApplication(agent_card=agent_card, http_handler=request_handler)
40+
server = A2AStarletteApplication(
41+
agent_card=agent_card, http_handler=request_handler
42+
)
4143
import uvicorn
4244

4345
uvicorn.run(server.build(), host='0.0.0.0', port=9999)

examples/langgraph/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ def main(host: str, port: int):
3636
)
3737

3838
server = A2AStarletteApplication(
39-
agent_card=get_agent_card(host, port),
40-
http_handler=request_handler)
39+
agent_card=get_agent_card(host, port), http_handler=request_handler
40+
)
4141
import uvicorn
4242

4343
uvicorn.run(server.build(), host=host, port=port)

src/a2a/server/agent_execution/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,3 @@
22
from a2a.server.agent_execution.context import RequestContext
33

44
__all__ = ['AgentExecutor', 'RequestContext']
5-

src/a2a/server/agent_execution/context.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Task,
88
TextPart,
99
)
10+
from a2a.utils import get_message_text
1011
from a2a.utils.errors import ServerError
1112

1213

@@ -48,11 +49,7 @@ def get_user_input(self, delimiter='\n') -> str:
4849
if not self._params:
4950
return ''
5051

51-
parts = []
52-
for part in self._params.message.parts:
53-
if isinstance(part.root, TextPart):
54-
parts.append(part.root.text)
55-
return delimiter.join(parts)
52+
return get_message_text(self._params.message, delimiter)
5653

5754
def attach_related_task(self, task: Task):
5855
self._related_tasks.append(task)

src/a2a/utils/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,20 @@
44
build_text_artifact,
55
create_task_obj,
66
)
7-
from a2a.utils.message import new_agent_text_message
7+
from a2a.utils.message import (
8+
get_message_text,
9+
get_text_parts,
10+
new_agent_text_message,
11+
)
812
from a2a.utils.task import new_task
913

1014

1115
__all__ = [
1216
'append_artifact_to_task',
1317
'build_text_artifact',
1418
'create_task_obj',
19+
'get_message_text',
20+
'get_text_parts',
1521
'new_agent_text_message',
1622
'new_task',
1723
'new_text_artifact',

src/a2a/utils/message.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,12 @@ def new_agent_text_message(
1919
taskId=task_id,
2020
contextId=context_id,
2121
)
22+
23+
24+
def get_text_parts(parts: list[Part]) -> list[str]:
25+
"""Return all text parts from a list of parts."""
26+
return [part.root.text for part in parts if isinstance(part.root, TextPart)]
27+
28+
29+
def get_message_text(message: Message, delimiter='\n') -> str:
30+
return delimiter.join(get_text_parts(message.parts))

tests/server/events/test_event_consumer.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
import asyncio
32
import pytest
43
from unittest.mock import AsyncMock, MagicMock, patch
@@ -36,16 +35,17 @@
3635
'messageId': '111',
3736
}
3837

38+
3939
@pytest.fixture
4040
def mock_event_queue():
4141
return AsyncMock(spec=EventQueue)
4242

43+
4344
@pytest.fixture
44-
def event_consumer(
45-
mock_event_queue: EventQueue
46-
):
45+
def event_consumer(mock_event_queue: EventQueue):
4746
return EventConsumer(queue=mock_event_queue)
4847

48+
4949
@pytest.mark.asyncio
5050
async def test_consume_one_task_event(
5151
event_consumer: MagicMock,
@@ -57,6 +57,7 @@ async def test_consume_one_task_event(
5757
assert result == task_event
5858
mock_event_queue.task_done.assert_called_once()
5959

60+
6061
@pytest.mark.asyncio
6162
async def test_consume_one_message_event(
6263
event_consumer: MagicMock,
@@ -68,6 +69,7 @@ async def test_consume_one_message_event(
6869
assert result == message_event
6970
mock_event_queue.task_done.assert_called_once()
7071

72+
7173
@pytest.mark.asyncio
7274
async def test_consume_one_a2a_error_event(
7375
event_consumer: MagicMock,
@@ -79,6 +81,7 @@ async def test_consume_one_a2a_error_event(
7981
assert result == error_event
8082
mock_event_queue.task_done.assert_called_once()
8183

84+
8285
@pytest.mark.asyncio
8386
async def test_consume_one_jsonrpc_error_event(
8487
event_consumer: MagicMock,
@@ -90,6 +93,7 @@ async def test_consume_one_jsonrpc_error_event(
9093
assert result == error_event
9194
mock_event_queue.task_done.assert_called_once()
9295

96+
9397
@pytest.mark.asyncio
9498
async def test_consume_one_queue_empty(
9599
event_consumer: MagicMock,
@@ -103,6 +107,7 @@ async def test_consume_one_queue_empty(
103107
pass
104108
mock_event_queue.task_done.assert_not_called()
105109

110+
106111
@pytest.mark.asyncio
107112
async def test_consume_all_multiple_events(
108113
event_consumer: MagicMock,
@@ -125,12 +130,14 @@ async def test_consume_all_multiple_events(
125130
),
126131
]
127132
cursor = 0
133+
128134
async def mock_dequeue() -> Any:
129135
nonlocal cursor
130136
if cursor < len(events):
131137
event = events[cursor]
132138
cursor += 1
133139
return event
140+
134141
mock_event_queue.dequeue_event = mock_dequeue
135142
consumed_events: list[Any] = []
136143
async for event in event_consumer.consume_all():
@@ -141,6 +148,7 @@ async def mock_dequeue() -> Any:
141148
assert consumed_events[2] == events[2]
142149
assert mock_event_queue.task_done.call_count == 3
143150

151+
144152
@pytest.mark.asyncio
145153
async def test_consume_until_message(
146154
event_consumer: MagicMock,
@@ -164,12 +172,14 @@ async def test_consume_until_message(
164172
),
165173
]
166174
cursor = 0
175+
167176
async def mock_dequeue() -> Any:
168177
nonlocal cursor
169178
if cursor < len(events):
170179
event = events[cursor]
171180
cursor += 1
172181
return event
182+
173183
mock_event_queue.dequeue_event = mock_dequeue
174184
consumed_events: list[Any] = []
175185
async for event in event_consumer.consume_all():
@@ -180,6 +190,7 @@ async def mock_dequeue() -> Any:
180190
assert consumed_events[2] == events[2]
181191
assert mock_event_queue.task_done.call_count == 3
182192

193+
183194
@pytest.mark.asyncio
184195
async def test_consume_message_events(
185196
event_consumer: MagicMock,
@@ -190,12 +201,14 @@ async def test_consume_message_events(
190201
Message(**MESSAGE_PAYLOAD, final=True),
191202
]
192203
cursor = 0
204+
193205
async def mock_dequeue() -> Any:
194206
nonlocal cursor
195207
if cursor < len(events):
196208
event = events[cursor]
197209
cursor += 1
198210
return event
211+
199212
mock_event_queue.dequeue_event = mock_dequeue
200213
consumed_events: list[Any] = []
201214
async for event in event_consumer.consume_all():

tests/server/events/test_event_queue.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
import asyncio
32
import pytest
43
from a2a.server.events.event_queue import EventQueue
@@ -17,6 +16,7 @@
1716
TaskNotFoundError,
1817
)
1918
from typing import Any
19+
2020
MINIMAL_TASK: dict[str, Any] = {
2121
'id': '123',
2222
'contextId': 'session-xyz',
@@ -28,30 +28,40 @@
2828
'parts': [{'text': 'test message'}],
2929
'messageId': '111',
3030
}
31+
32+
3133
@pytest.fixture
3234
def event_queue() -> EventQueue:
3335
return EventQueue()
36+
37+
3438
@pytest.mark.asyncio
3539
async def test_enqueue_and_dequeue_event(event_queue: EventQueue) -> None:
3640
"""Test that an event can be enqueued and dequeued."""
3741
event = Message(**MESSAGE_PAYLOAD)
3842
event_queue.enqueue_event(event)
3943
dequeued_event = await event_queue.dequeue_event()
4044
assert dequeued_event == event
45+
46+
4147
@pytest.mark.asyncio
4248
async def test_dequeue_event_no_wait(event_queue: EventQueue) -> None:
4349
"""Test dequeue_event with no_wait=True."""
4450
event = Task(**MINIMAL_TASK)
4551
event_queue.enqueue_event(event)
4652
dequeued_event = await event_queue.dequeue_event(no_wait=True)
4753
assert dequeued_event == event
54+
55+
4856
@pytest.mark.asyncio
4957
async def test_dequeue_event_empty_queue_no_wait(
5058
event_queue: EventQueue,
5159
) -> None:
5260
"""Test dequeue_event with no_wait=True when the queue is empty."""
5361
with pytest.raises(asyncio.QueueEmpty):
5462
await event_queue.dequeue_event(no_wait=True)
63+
64+
5565
@pytest.mark.asyncio
5666
async def test_dequeue_event_wait(event_queue: EventQueue) -> None:
5767
"""Test dequeue_event with the default wait behavior."""
@@ -64,6 +74,8 @@ async def test_dequeue_event_wait(event_queue: EventQueue) -> None:
6474
event_queue.enqueue_event(event)
6575
dequeued_event = await event_queue.dequeue_event()
6676
assert dequeued_event == event
77+
78+
6779
@pytest.mark.asyncio
6880
async def test_task_done(event_queue: EventQueue) -> None:
6981
"""Test the task_done method."""
@@ -75,6 +87,8 @@ async def test_task_done(event_queue: EventQueue) -> None:
7587
event_queue.enqueue_event(event)
7688
_ = await event_queue.dequeue_event()
7789
event_queue.task_done()
90+
91+
7892
@pytest.mark.asyncio
7993
async def test_enqueue_different_event_types(
8094
event_queue: EventQueue,
@@ -88,4 +102,3 @@ async def test_enqueue_different_event_types(
88102
event_queue.enqueue_event(event)
89103
dequeued_event = await event_queue.dequeue_event()
90104
assert dequeued_event == event
91-

tests/server/tasks/test_inmemory_task_store.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
'type': 'task',
1212
}
1313

14+
1415
@pytest.mark.asyncio
1516
async def test_in_memory_task_store_save_and_get() -> None:
1617
"""Test saving and retrieving a task from the in-memory store."""
@@ -20,13 +21,15 @@ async def test_in_memory_task_store_save_and_get() -> None:
2021
retrieved_task = await store.get(MINIMAL_TASK['id'])
2122
assert retrieved_task == task
2223

24+
2325
@pytest.mark.asyncio
2426
async def test_in_memory_task_store_get_nonexistent() -> None:
2527
"""Test retrieving a non-existent task."""
2628
store = InMemoryTaskStore()
2729
retrieved_task = await store.get('nonexistent')
2830
assert retrieved_task is None
2931

32+
3033
@pytest.mark.asyncio
3134
async def test_in_memory_task_store_delete() -> None:
3235
"""Test deleting a task from the store."""
@@ -37,9 +40,9 @@ async def test_in_memory_task_store_delete() -> None:
3740
retrieved_task = await store.get(MINIMAL_TASK['id'])
3841
assert retrieved_task is None
3942

43+
4044
@pytest.mark.asyncio
4145
async def test_in_memory_task_store_delete_nonexistent() -> None:
4246
"""Test deleting a non-existent task."""
4347
store = InMemoryTaskStore()
4448
await store.delete('nonexistent')
45-

0 commit comments

Comments
 (0)