Skip to content

Commit 6f592bb

Browse files
committed
Add unit tests for fastapi_app sendMessage api
1 parent 39a430d commit 6f592bb

File tree

6 files changed

+244
-12
lines changed

6 files changed

+244
-12
lines changed

Gemini.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
**A2A specification:** https://a2a-protocol.org/latest/specification/
2+
3+
## Project frameworks
4+
- uv as package manager
5+
6+
## How to run all tests
7+
1. If dependencies are not installed install them using following command
8+
```
9+
uv sync --all-extras
10+
```
11+
12+
2. Run tests
13+
```
14+
uv run pytest
15+
```
16+
17+
## Other instructions
18+
1. Whenever writing python code, write types as well.
19+
2. After making the changes run ruff to check and fix the formatting issues
20+
```
21+
uv run ruff check --fix
22+
```
23+
3. Run mypy type checkers to check for type errors
24+
```
25+
uv run mypy
26+
```
27+
4. Run the unit tests to make sure that none of the unit tests are broken.

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ authors = [{ name = "Google LLC", email = "[email protected]" }]
88
requires-python = ">=3.10"
99
keywords = ["A2A", "A2A SDK", "A2A Protocol", "Agent2Agent", "Agent 2 Agent"]
1010
dependencies = [
11-
"fastapi>=0.115.2",
11+
"fastapi>=0.116.1",
1212
"httpx>=0.28.1",
1313
"httpx-sse>=0.4.0",
1414
"opentelemetry-api>=1.33.0",
@@ -93,6 +93,7 @@ dev = [
9393
"pyupgrade",
9494
"autoflake",
9595
"no_implicit_optional",
96+
"trio",
9697
]
9798

9899
[[tool.uv.index]]

src/a2a/server/apps/rest/rest_app.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
77
from typing import Any
88

9+
from google.protobuf import message as message_pb2
910
from pydantic import ValidationError
1011
from sse_starlette.sse import EventSourceResponse
1112
from starlette.requests import Request
12-
from starlette.responses import JSONResponse
13+
from starlette.responses import JSONResponse, Response
1314

1415
from a2a.server.apps.jsonrpc import (
1516
CallContextBuilder,
@@ -111,9 +112,11 @@ def _handle_error(self, error: Exception) -> JSONResponse:
111112

112113
async def _handle_request(
113114
self,
114-
method: Callable[[Request, ServerCallContext], Awaitable[str]],
115+
method: Callable[
116+
[Request, ServerCallContext], Awaitable[dict[str, Any]]
117+
],
115118
request: Request,
116-
) -> JSONResponse:
119+
) -> Response:
117120
try:
118121
call_context = self._context_builder.build(request)
119122
response = await method(request, call_context)
@@ -123,7 +126,9 @@ async def _handle_request(
123126

124127
async def _handle_streaming_request(
125128
self,
126-
method: Callable[[Request, ServerCallContext], AsyncIterator[str]],
129+
method: Callable[
130+
[Request, ServerCallContext], AsyncIterator[message_pb2.Message]
131+
],
127132
request: Request,
128133
) -> EventSourceResponse:
129134
try:

src/a2a/server/request_handlers/rest_handler.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import logging
22

33
from collections.abc import AsyncIterable
4+
from typing import Any
45

5-
from google.protobuf.json_format import MessageToJson, Parse
6+
from google.protobuf.json_format import MessageToDict, MessageToJson, Parse
67
from starlette.requests import Request
78

89
from a2a.grpc import a2a_pb2
@@ -60,7 +61,7 @@ async def on_message_send(
6061
self,
6162
request: Request,
6263
context: ServerCallContext,
63-
) -> str:
64+
) -> dict[str, Any]:
6465
"""Handles the 'message/send' REST method.
6566
6667
Args:
@@ -85,7 +86,7 @@ async def on_message_send(
8586
task_or_message = await self.request_handler.on_message_send(
8687
a2a_request, context
8788
)
88-
return MessageToJson(
89+
return MessageToDict(
8990
proto_utils.ToProto.task_or_message(task_or_message)
9091
)
9192
except ServerError as e:
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
from collections.abc import AsyncGenerator
2+
from unittest.mock import MagicMock
3+
import logging
4+
5+
from google.protobuf import json_format
6+
import pytest
7+
from fastapi import FastAPI
8+
from httpx import ASGITransport, AsyncClient
9+
10+
from a2a.grpc import a2a_pb2
11+
from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication
12+
from a2a.server.request_handlers.request_handler import RequestHandler
13+
from a2a.types import (
14+
AgentCard,
15+
Message,
16+
Part,
17+
Role,
18+
SendMessageRequest,
19+
SendMessageResponse,
20+
SendMessageSuccessResponse,
21+
Task,
22+
TaskState,
23+
TaskStatus,
24+
TextPart,
25+
)
26+
from a2a.utils import proto_utils
27+
28+
29+
logger = logging.getLogger(__name__)
30+
31+
32+
@pytest.fixture
33+
async def agent_card() -> AgentCard:
34+
mock_agent_card = MagicMock(spec=AgentCard)
35+
mock_agent_card.url = 'http://mockurl.com'
36+
mock_agent_card.supports_authenticated_extended_card = False
37+
return mock_agent_card
38+
39+
40+
@pytest.fixture
41+
async def request_handler() -> RequestHandler:
42+
return MagicMock(spec=RequestHandler)
43+
44+
45+
@pytest.fixture
46+
async def app(
47+
agent_card: AgentCard, request_handler: RequestHandler
48+
) -> FastAPI:
49+
"""Builds the FastAPI application for testing."""
50+
51+
return A2ARESTFastAPIApplication(agent_card, request_handler).build(
52+
agent_card_url='/well-known/agent.json', rpc_url=''
53+
)
54+
55+
56+
@pytest.fixture
57+
async def client(app: FastAPI) -> AsyncClient:
58+
return AsyncClient(
59+
transport=ASGITransport(app=app), base_url='http://testapp'
60+
)
61+
62+
63+
@pytest.mark.anyio
64+
async def test_send_message_success_message(
65+
client: AsyncClient, request_handler: MagicMock
66+
) -> None:
67+
expected_response = a2a_pb2.SendMessageResponse(
68+
msg=a2a_pb2.Message(
69+
message_id='test',
70+
role=a2a_pb2.Role.ROLE_AGENT,
71+
content=[
72+
a2a_pb2.Part(text='response message'),
73+
],
74+
),
75+
)
76+
request_handler.on_message_send.return_value = Message(
77+
message_id='test',
78+
role=Role.agent,
79+
parts=[Part(TextPart(text='response message'))],
80+
)
81+
82+
request = a2a_pb2.SendMessageRequest(
83+
request=a2a_pb2.Message(),
84+
configuration=a2a_pb2.SendMessageConfiguration(),
85+
)
86+
# To see log output, run pytest with '--log-cli=true --log-cli-level=INFO'
87+
response = await client.post(
88+
'/v1/message:send', json=json_format.MessageToDict(request)
89+
)
90+
# request should always be successful
91+
response.raise_for_status()
92+
93+
actual_response = a2a_pb2.SendMessageResponse()
94+
json_format.Parse(response.text, actual_response)
95+
assert expected_response == actual_response
96+
97+
98+
@pytest.mark.anyio
99+
async def test_send_message_success_task(
100+
client: AsyncClient, request_handler: MagicMock
101+
) -> None:
102+
expected_response = a2a_pb2.SendMessageResponse(
103+
task=a2a_pb2.Task(
104+
id='test_task_id',
105+
context_id='test_context_id',
106+
status=a2a_pb2.TaskStatus(
107+
state=a2a_pb2.TaskState.TASK_STATE_COMPLETED,
108+
update=a2a_pb2.Message(
109+
message_id='test',
110+
role=a2a_pb2.ROLE_AGENT,
111+
content=[
112+
a2a_pb2.Part(text='response task message'),
113+
],
114+
),
115+
),
116+
),
117+
)
118+
request_handler.on_message_send.return_value = Task(
119+
id='test_task_id',
120+
context_id='test_context_id',
121+
status=TaskStatus(
122+
state=TaskState.completed,
123+
message=Message(
124+
message_id='test',
125+
role=Role.agent,
126+
parts=[Part(TextPart(text='response task message'))],
127+
),
128+
),
129+
)
130+
131+
request = a2a_pb2.SendMessageRequest(
132+
request=a2a_pb2.Message(),
133+
configuration=a2a_pb2.SendMessageConfiguration(),
134+
)
135+
# To see log output, run pytest with '--log-cli=true --log-cli-level=INFO'
136+
response = await client.post(
137+
'/v1/message:send', json=json_format.MessageToDict(request)
138+
)
139+
# request should always be successful
140+
response.raise_for_status()
141+
142+
actual_response = a2a_pb2.SendMessageResponse()
143+
json_format.Parse(response.text, actual_response)
144+
assert expected_response == actual_response
145+
146+
147+
if __name__ == '__main__':
148+
pytest.main([__file__])

uv.lock

Lines changed: 54 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)