Skip to content

Commit 1da4b3b

Browse files
add-fastapi-application
1 parent 36bfdfd commit 1da4b3b

File tree

6 files changed

+891
-771
lines changed

6 files changed

+891
-771
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ description = "A2A Python SDK"
55
readme = "README.md"
66
requires-python = ">=3.13"
77
dependencies = [
8+
"fastapi>=0.115.12",
89
"httpx>=0.28.1",
910
"httpx-sse>=0.4.0",
1011
"opentelemetry-api>=1.33.0",
@@ -13,6 +14,7 @@ dependencies = [
1314
"sse-starlette>=2.3.3",
1415
"starlette>=0.46.2",
1516
"typing-extensions>=4.13.2",
17+
"uvicorn>=0.34.2",
1618
]
1719

1820
[tool.hatch.build.targets.wheel]

src/a2a/server/apps/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
from a2a.server.apps.http_app import HttpApp
2+
from a2a.server.apps.default_app import DefaultA2AApplication
23
from a2a.server.apps.starlette_app import A2AStarletteApplication
4+
from a2a.server.apps.fastapi_app import A2AFastAPIApplication
35

46

5-
__all__ = ['A2AStarletteApplication', 'HttpApp']
7+
__all__ = [
8+
'DefaultA2AApplication',
9+
'A2AStarletteApplication',
10+
'A2AFastAPIApplication',
11+
'HttpApp'
12+
]

src/a2a/server/apps/default_app.py

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
import json
2+
import logging
3+
import traceback
4+
from abc import ABC, abstractmethod
5+
6+
from collections.abc import AsyncGenerator
7+
from typing import Any, Optional, Union
8+
9+
from pydantic import ValidationError
10+
from sse_starlette.sse import EventSourceResponse
11+
from starlette.applications import Starlette
12+
from fastapi import FastAPI
13+
from starlette.requests import Request
14+
from starlette.responses import JSONResponse, Response
15+
16+
from a2a.server.request_handlers.jsonrpc_handler import (
17+
JSONRPCHandler,
18+
RequestHandler,
19+
)
20+
from a2a.types import (
21+
A2AError,
22+
A2ARequest,
23+
AgentCard,
24+
CancelTaskRequest,
25+
GetTaskPushNotificationConfigRequest,
26+
GetTaskRequest,
27+
InternalError,
28+
InvalidRequestError,
29+
JSONParseError,
30+
JSONRPCError,
31+
JSONRPCErrorResponse,
32+
JSONRPCResponse,
33+
SendMessageRequest,
34+
SendStreamingMessageRequest,
35+
SendStreamingMessageResponse,
36+
SetTaskPushNotificationConfigRequest,
37+
TaskResubscriptionRequest,
38+
UnsupportedOperationError,
39+
)
40+
from a2a.utils.errors import MethodNotImplementedError
41+
42+
logger = logging.getLogger(__name__)
43+
44+
45+
class DefaultA2AApplication(ABC):
46+
"""Base class for A2A applications.
47+
48+
Args:
49+
agent_card: The AgentCard describing the agent's capabilities.
50+
http_handler: The handler instance responsible for processing A2A
51+
requests via http.
52+
"""
53+
54+
def __init__(self, agent_card: AgentCard, http_handler: RequestHandler):
55+
"""Initializes the A2AApplication.
56+
57+
Args:
58+
agent_card: The AgentCard describing the agent's capabilities.
59+
http_handler: The handler instance responsible for processing A2A
60+
requests via http.
61+
"""
62+
self.agent_card = agent_card
63+
self.handler = JSONRPCHandler(
64+
agent_card=agent_card, request_handler=http_handler
65+
)
66+
67+
def _generate_error_response(
68+
self, request_id: str | int | None, error: JSONRPCError | A2AError
69+
) -> JSONResponse:
70+
"""Creates a JSONResponse for a JSON-RPC error."""
71+
error_resp = JSONRPCErrorResponse(
72+
id=request_id,
73+
error=error if isinstance(error, JSONRPCError) else error.root,
74+
)
75+
76+
log_level = (
77+
logging.ERROR
78+
if not isinstance(error, A2AError)
79+
or isinstance(error.root, InternalError)
80+
else logging.WARNING
81+
)
82+
logger.log(
83+
log_level,
84+
f'Request Error (ID: {request_id}: '
85+
f"Code={error_resp.error.code}, Message='{error_resp.error.message}'"
86+
f'{", Data=" + str(error_resp.error.data) if hasattr(error, "data") and error_resp.error.data else ""}',
87+
)
88+
return JSONResponse(
89+
error_resp.model_dump(mode='json', exclude_none=True),
90+
status_code=200,
91+
)
92+
93+
async def _handle_requests(self, request: Request) -> Response:
94+
"""Handles incoming POST requests to the main A2A endpoint.
95+
96+
Parses the request body as JSON, validates it against A2A request types,
97+
dispatches it to the appropriate handler method, and returns the response.
98+
Handles JSON parsing errors, validation errors, and other exceptions,
99+
returning appropriate JSON-RPC error responses.
100+
"""
101+
request_id = None
102+
body = None
103+
104+
try:
105+
body = await request.json()
106+
a2a_request = A2ARequest.model_validate(body)
107+
108+
request_id = a2a_request.root.id
109+
request_obj = a2a_request.root
110+
111+
if isinstance(
112+
request_obj,
113+
TaskResubscriptionRequest | SendStreamingMessageRequest,
114+
):
115+
return await self._process_streaming_request(
116+
request_id, a2a_request
117+
)
118+
119+
return await self._process_non_streaming_request(
120+
request_id, a2a_request
121+
)
122+
except MethodNotImplementedError:
123+
traceback.print_exc()
124+
return self._generate_error_response(
125+
request_id, A2AError(root=UnsupportedOperationError())
126+
)
127+
except json.decoder.JSONDecodeError as e:
128+
traceback.print_exc()
129+
return self._generate_error_response(
130+
None, A2AError(root=JSONParseError(message=str(e)))
131+
)
132+
except ValidationError as e:
133+
traceback.print_exc()
134+
return self._generate_error_response(
135+
request_id,
136+
A2AError(root=InvalidRequestError(data=json.loads(e.json()))),
137+
)
138+
except Exception as e:
139+
logger.error(f'Unhandled exception: {e}')
140+
traceback.print_exc()
141+
return self._generate_error_response(
142+
request_id, A2AError(root=InternalError(message=str(e)))
143+
)
144+
145+
async def _process_streaming_request(
146+
self, request_id: str | int | None, a2a_request: A2ARequest
147+
) -> Response:
148+
"""Processes streaming requests.
149+
150+
Args:
151+
request_id: The ID of the request.
152+
a2a_request: The validated A2ARequest object.
153+
"""
154+
request_obj = a2a_request.root
155+
handler_result: Any = None
156+
if isinstance(
157+
request_obj,
158+
SendStreamingMessageRequest,
159+
):
160+
handler_result = self.handler.on_message_send_stream(request_obj)
161+
elif isinstance(request_obj, TaskResubscriptionRequest):
162+
handler_result = self.handler.on_resubscribe_to_task(request_obj)
163+
164+
return self._create_response(handler_result)
165+
166+
async def _process_non_streaming_request(
167+
self, request_id: str | int | None, a2a_request: A2ARequest
168+
) -> Response:
169+
"""Processes non-streaming requests.
170+
171+
Args:
172+
request_id: The ID of the request.
173+
a2a_request: The validated A2ARequest object.
174+
"""
175+
request_obj = a2a_request.root
176+
handler_result: Any = None
177+
match request_obj:
178+
case SendMessageRequest():
179+
handler_result = await self.handler.on_message_send(request_obj)
180+
case CancelTaskRequest():
181+
handler_result = await self.handler.on_cancel_task(request_obj)
182+
case GetTaskRequest():
183+
handler_result = await self.handler.on_get_task(request_obj)
184+
case SetTaskPushNotificationConfigRequest():
185+
handler_result = await self.handler.set_push_notification(
186+
request_obj
187+
)
188+
case GetTaskPushNotificationConfigRequest():
189+
handler_result = await self.handler.get_push_notification(
190+
request_obj
191+
)
192+
case _:
193+
logger.error(
194+
f'Unhandled validated request type: {type(request_obj)}'
195+
)
196+
error = UnsupportedOperationError(
197+
message=f'Request type {type(request_obj).__name__} is unknown.'
198+
)
199+
handler_result = JSONRPCErrorResponse(
200+
id=request_id, error=error
201+
)
202+
203+
return self._create_response(handler_result)
204+
205+
def _create_response(
206+
self,
207+
handler_result: (
208+
AsyncGenerator[SendStreamingMessageResponse, None]
209+
| JSONRPCErrorResponse
210+
| JSONRPCResponse
211+
),
212+
) -> Response:
213+
"""Creates a Starlette Response based on the result from the request handler.
214+
215+
Handles:
216+
- AsyncGenerator for Server-Sent Events (SSE).
217+
- JSONRPCErrorResponse for explicit errors returned by handlers.
218+
- Pydantic RootModels (like GetTaskResponse) containing success or error
219+
payloads.
220+
- Unexpected types by returning an InternalError.
221+
222+
Args:
223+
handler_result: AsyncGenerator of SendStreamingMessageResponse
224+
225+
Returns:
226+
A Starlette JSONResponse or EventSourceResponse.
227+
"""
228+
if isinstance(handler_result, AsyncGenerator):
229+
# Result is a stream of SendStreamingMessageResponse objects
230+
async def event_generator(
231+
stream: AsyncGenerator[SendStreamingMessageResponse, None],
232+
) -> AsyncGenerator[dict[str, str], None]:
233+
async for item in stream:
234+
yield {'data': item.root.model_dump_json(exclude_none=True)}
235+
236+
return EventSourceResponse(event_generator(handler_result))
237+
if isinstance(handler_result, JSONRPCErrorResponse):
238+
return JSONResponse(
239+
handler_result.model_dump(
240+
mode='json',
241+
exclude_none=True,
242+
)
243+
)
244+
245+
return JSONResponse(
246+
handler_result.root.model_dump(mode='json', exclude_none=True)
247+
)
248+
249+
async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
250+
"""Handles GET requests for the agent card."""
251+
return JSONResponse(
252+
self.agent_card.model_dump(mode='json', exclude_none=True)
253+
)
254+
255+
@abstractmethod
256+
def build(
257+
self,
258+
agent_card_url: str = '/.well-known/agent.json',
259+
rpc_url: str = '/',
260+
**kwargs: Any,
261+
) -> Union[Starlette, FastAPI]:
262+
"""Builds and returns the FastAPI application instance.
263+
264+
Args:
265+
agent_card_url: The URL for the agent card endpoint.
266+
rpc_url: The URL for the A2A JSON-RPC endpoint
267+
**kwargs: Additional keyword arguments to pass to the FastAPI constructor.
268+
269+
Returns:
270+
A configured FastAPI application instance.
271+
"""
272+
pass

src/a2a/server/apps/fastapi_app.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import logging
2+
from typing import Any
3+
4+
from fastapi import FastAPI, Request
5+
6+
from a2a.server.apps import DefaultA2AApplication
7+
from a2a.server.request_handlers.jsonrpc_handler import RequestHandler
8+
from a2a.types import AgentCard
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
class A2AFastAPIApplication(DefaultA2AApplication):
14+
"""A FastAPI application implementing the A2A protocol server endpoints.
15+
16+
Handles incoming JSON-RPC requests, routes them to the appropriate
17+
handler methods, and manages response generation including Server-Sent Events
18+
(SSE)."""
19+
20+
def __init__(self, agent_card: AgentCard, http_handler: RequestHandler):
21+
"""Initializes the A2A FastAPI application.
22+
23+
Args:
24+
agent_card: The AgentCard describing the agent's capabilities.
25+
http_handler: The handler instance responsible for processing A2A requests via http.
26+
"""
27+
super().__init__(agent_card, http_handler)
28+
29+
def build(
30+
self,
31+
agent_card_url: str = '/.well-known/agent.json',
32+
rpc_url: str = '/',
33+
**kwargs: Any,
34+
) -> FastAPI:
35+
"""Builds and returns the FastAPI application instance.
36+
37+
Args:
38+
agent_card_url: The URL for the agent card endpoint.
39+
rpc_url: The URL for the A2A JSON-RPC endpoint
40+
**kwargs: Additional keyword arguments to pass to the FastAPI constructor.
41+
42+
Returns:
43+
A configured FastAPI application instance.
44+
"""
45+
app = FastAPI(**kwargs)
46+
47+
@app.post(rpc_url)
48+
async def handle_a2a_request(request: Request):
49+
return await self._handle_requests(request)
50+
51+
@app.get(agent_card_url)
52+
async def get_agent_card(request: Request):
53+
return await self._handle_get_agent_card(request)
54+
55+
return app

0 commit comments

Comments
 (0)