Skip to content

Commit 66526b9

Browse files
authored
Merge branch 'main' into fix-bug#367-client_hangs
2 parents 516e4d2 + dec4b48 commit 66526b9

File tree

5 files changed

+67
-22
lines changed

5 files changed

+67
-22
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ dependencies = [
1414
"pydantic>=2.11.3",
1515
"sse-starlette",
1616
"starlette",
17-
"protobuf==5.29.5",
17+
"protobuf>=5.29.5",
1818
"google-api-core>=1.26.0",
1919
]
2020

src/a2a/client/transports/grpc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ async def send_message(
8787
metadata=proto_utils.ToProto.metadata(request.metadata),
8888
)
8989
)
90-
if response.task:
90+
if response.HasField('task'):
9191
return proto_utils.FromProto.task(response.task)
9292
return proto_utils.FromProto.message(response.msg)
9393

src/a2a/server/apps/jsonrpc/fastapi_app.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import logging
22

3-
from collections.abc import AsyncIterator
4-
from contextlib import asynccontextmanager
53
from typing import Any
64

75
from fastapi import FastAPI
@@ -21,6 +19,28 @@
2119
logger = logging.getLogger(__name__)
2220

2321

22+
class A2AFastAPI(FastAPI):
23+
"""A FastAPI application that adds A2A-specific OpenAPI components."""
24+
25+
_a2a_components_added: bool = False
26+
27+
def openapi(self) -> dict[str, Any]:
28+
"""Generates the OpenAPI schema for the application."""
29+
openapi_schema = super().openapi()
30+
if not self._a2a_components_added:
31+
a2a_request_schema = A2ARequest.model_json_schema(
32+
ref_template='#/components/schemas/{model}'
33+
)
34+
defs = a2a_request_schema.pop('$defs', {})
35+
component_schemas = openapi_schema.setdefault(
36+
'components', {}
37+
).setdefault('schemas', {})
38+
component_schemas.update(defs)
39+
component_schemas['A2ARequest'] = a2a_request_schema
40+
self._a2a_components_added = True
41+
return openapi_schema
42+
43+
2444
class A2AFastAPIApplication(JSONRPCApplication):
2545
"""A FastAPI application implementing the A2A protocol server endpoints.
2646
@@ -92,23 +112,7 @@ def build(
92112
Returns:
93113
A configured FastAPI application instance.
94114
"""
95-
96-
@asynccontextmanager
97-
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
98-
a2a_request_schema = A2ARequest.model_json_schema(
99-
ref_template='#/components/schemas/{model}'
100-
)
101-
defs = a2a_request_schema.pop('$defs', {})
102-
openapi_schema = app.openapi()
103-
component_schemas = openapi_schema.setdefault(
104-
'components', {}
105-
).setdefault('schemas', {})
106-
component_schemas.update(defs)
107-
component_schemas['A2ARequest'] = a2a_request_schema
108-
109-
yield
110-
111-
app = FastAPI(lifespan=lifespan, **kwargs)
115+
app = A2AFastAPI(**kwargs)
112116

113117
self.add_routes_to_app(
114118
app, agent_card_url, rpc_url, extended_agent_card_url

tests/client/test_grpc_client.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
TaskStatus,
1919
TextPart,
2020
)
21-
from a2a.utils import proto_utils
21+
from a2a.utils import get_text_parts, proto_utils
2222

2323

2424
# Fixtures
@@ -112,6 +112,28 @@ async def test_send_message_task_response(
112112
assert response.id == sample_task.id
113113

114114

115+
@pytest.mark.asyncio
116+
async def test_send_message_message_response(
117+
grpc_transport: GrpcTransport,
118+
mock_grpc_stub: AsyncMock,
119+
sample_message_send_params: MessageSendParams,
120+
sample_message: Message,
121+
):
122+
"""Test send_message that returns a Message."""
123+
mock_grpc_stub.SendMessage.return_value = a2a_pb2.SendMessageResponse(
124+
msg=proto_utils.ToProto.message(sample_message)
125+
)
126+
127+
response = await grpc_transport.send_message(sample_message_send_params)
128+
129+
mock_grpc_stub.SendMessage.assert_awaited_once()
130+
assert isinstance(response, Message)
131+
assert response.message_id == sample_message.message_id
132+
assert get_text_parts(response.parts) == get_text_parts(
133+
sample_message.parts
134+
)
135+
136+
115137
@pytest.mark.asyncio
116138
async def test_get_task(
117139
grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task: Task

tests/server/apps/jsonrpc/test_serialization.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from unittest import mock
22

33
import pytest
4+
from fastapi import FastAPI
45

56
from pydantic import ValidationError
67
from starlette.testclient import TestClient
@@ -183,3 +184,21 @@ def test_handle_unicode_characters(agent_card_with_api_key: AgentCard):
183184
data = response.json()
184185
assert 'error' not in data or data['error'] is None
185186
assert data['result']['parts'][0]['text'] == f'Received: {unicode_text}'
187+
188+
189+
def test_fastapi_sub_application(agent_card_with_api_key: AgentCard):
190+
"""
191+
Tests that the A2AFastAPIApplication endpoint correctly passes the url in sub-application.
192+
"""
193+
handler = mock.AsyncMock()
194+
sub_app_instance = A2AFastAPIApplication(agent_card_with_api_key, handler)
195+
app_instance = FastAPI()
196+
app_instance.mount('/a2a', sub_app_instance.build())
197+
client = TestClient(app_instance)
198+
199+
response = client.get('/a2a/openapi.json')
200+
assert response.status_code == 200
201+
response_data = response.json()
202+
203+
assert 'servers' in response_data
204+
assert response_data['servers'] == [{'url': '/a2a'}]

0 commit comments

Comments
 (0)