Skip to content

Commit 00be6cb

Browse files
committed
Merge remote-tracking branch 'upstream/main' into chore/improve-coverage-grpc-client
Signed-off-by: Shingo OKAWA <[email protected]>
2 parents e984261 + dec4b48 commit 00be6cb

File tree

4 files changed

+47
-21
lines changed

4 files changed

+47
-21
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ dependencies = [
1414
"pydantic>=2.11.3",
1515
"sse-starlette",
1616
"starlette",
17-
"protobuf==5.29.5",
17+
"protobuf>=5.29.5",
1818
"google-api-core>=1.26.0",
1919
]
2020

src/a2a/server/apps/jsonrpc/fastapi_app.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import logging
22

3-
from collections.abc import AsyncIterator
4-
from contextlib import asynccontextmanager
53
from typing import Any
64

75
from fastapi import FastAPI
@@ -21,6 +19,28 @@
2119
logger = logging.getLogger(__name__)
2220

2321

22+
class A2AFastAPI(FastAPI):
23+
"""A FastAPI application that adds A2A-specific OpenAPI components."""
24+
25+
_a2a_components_added: bool = False
26+
27+
def openapi(self) -> dict[str, Any]:
28+
"""Generates the OpenAPI schema for the application."""
29+
openapi_schema = super().openapi()
30+
if not self._a2a_components_added:
31+
a2a_request_schema = A2ARequest.model_json_schema(
32+
ref_template='#/components/schemas/{model}'
33+
)
34+
defs = a2a_request_schema.pop('$defs', {})
35+
component_schemas = openapi_schema.setdefault(
36+
'components', {}
37+
).setdefault('schemas', {})
38+
component_schemas.update(defs)
39+
component_schemas['A2ARequest'] = a2a_request_schema
40+
self._a2a_components_added = True
41+
return openapi_schema
42+
43+
2444
class A2AFastAPIApplication(JSONRPCApplication):
2545
"""A FastAPI application implementing the A2A protocol server endpoints.
2646
@@ -92,23 +112,7 @@ def build(
92112
Returns:
93113
A configured FastAPI application instance.
94114
"""
95-
96-
@asynccontextmanager
97-
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
98-
a2a_request_schema = A2ARequest.model_json_schema(
99-
ref_template='#/components/schemas/{model}'
100-
)
101-
defs = a2a_request_schema.pop('$defs', {})
102-
openapi_schema = app.openapi()
103-
component_schemas = openapi_schema.setdefault(
104-
'components', {}
105-
).setdefault('schemas', {})
106-
component_schemas.update(defs)
107-
component_schemas['A2ARequest'] = a2a_request_schema
108-
109-
yield
110-
111-
app = FastAPI(lifespan=lifespan, **kwargs)
115+
app = A2AFastAPI(**kwargs)
112116

113117
self.add_routes_to_app(
114118
app, agent_card_url, rpc_url, extended_agent_card_url

tests/client/test_grpc_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
TaskStatusUpdateEvent,
2727
TextPart,
2828
)
29-
from a2a.utils import proto_utils
29+
from a2a.utils import get_text_parts, proto_utils
3030
from a2a.utils.errors import ServerError
3131

3232

@@ -209,6 +209,9 @@ async def test_send_message_message_response(
209209
mock_grpc_stub.SendMessage.assert_awaited_once()
210210
assert isinstance(response, Message)
211211
assert response.message_id == sample_message.message_id
212+
assert get_text_parts(response.parts) == get_text_parts(
213+
sample_message.parts
214+
)
212215

213216

214217
@pytest.mark.asyncio

tests/server/apps/jsonrpc/test_serialization.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from unittest import mock
22

33
import pytest
4+
from fastapi import FastAPI
45

56
from pydantic import ValidationError
67
from starlette.testclient import TestClient
@@ -183,3 +184,21 @@ def test_handle_unicode_characters(agent_card_with_api_key: AgentCard):
183184
data = response.json()
184185
assert 'error' not in data or data['error'] is None
185186
assert data['result']['parts'][0]['text'] == f'Received: {unicode_text}'
187+
188+
189+
def test_fastapi_sub_application(agent_card_with_api_key: AgentCard):
190+
"""
191+
Tests that the A2AFastAPIApplication endpoint correctly passes the url in sub-application.
192+
"""
193+
handler = mock.AsyncMock()
194+
sub_app_instance = A2AFastAPIApplication(agent_card_with_api_key, handler)
195+
app_instance = FastAPI()
196+
app_instance.mount('/a2a', sub_app_instance.build())
197+
client = TestClient(app_instance)
198+
199+
response = client.get('/a2a/openapi.json')
200+
assert response.status_code == 200
201+
response_data = response.json()
202+
203+
assert 'servers' in response_data
204+
assert response_data['servers'] == [{'url': '/a2a'}]

0 commit comments

Comments
 (0)