Skip to content

Commit 0239fbc

Browse files
committed
feat: add A2AStarletteRouter for modular route definition
Signed-off-by: Shingo OKAWA <[email protected]>
1 parent e87bb3d commit 0239fbc

File tree

1 file changed

+277
-0
lines changed

1 file changed

+277
-0
lines changed

src/a2a/server/apps/starlette_app.py

Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,3 +461,280 @@ def build(
461461
kwargs['routes'] = app_routes
462462

463463
return Starlette(**kwargs)
464+
465+
466+
class A2AStarletteRouter:
467+
"""Defines Starlette routes for serving A2A protocol endpoints.
468+
469+
This router provides the necessary HTTP routes for an A2A-compliant agent.
470+
It handles routing, dispatches requests to appropriate handler methods,
471+
and generates responses—including support for Server-Sent Events (SSE).
472+
473+
Note:
474+
As of 2025-05-24, this class is functionally equivalent to
475+
`A2AStarletteApplication` with the exception that it does not implement
476+
the `build()` method.
477+
"""
478+
479+
def __init__(self, agent_card: AgentCard, http_handler: RequestHandler):
480+
"""Initializes the A2AStarletteRouter.
481+
482+
Args:
483+
agent_card: The AgentCard describing the agent's capabilities.
484+
http_handler: The handler instance responsible for processing A2A
485+
requests via http.
486+
"""
487+
self.agent_card = agent_card
488+
self.handler = JSONRPCHandler(
489+
agent_card=agent_card, request_handler=http_handler
490+
)
491+
492+
def _generate_error_response(
493+
self, request_id: str | int | None, error: JSONRPCError | A2AError
494+
) -> JSONResponse:
495+
"""Creates a Starlette JSONResponse for a JSON-RPC error.
496+
497+
Logs the error based on its type.
498+
499+
Args:
500+
request_id: The ID of the request that caused the error.
501+
error: The `JSONRPCError` or `A2AError` object.
502+
503+
Returns:
504+
A `JSONResponse` object formatted as a JSON-RPC error response.
505+
"""
506+
error_resp = JSONRPCErrorResponse(
507+
id=request_id,
508+
error=error if isinstance(error, JSONRPCError) else error.root,
509+
)
510+
511+
log_level = (
512+
logging.ERROR
513+
if not isinstance(error, A2AError)
514+
or isinstance(error.root, InternalError)
515+
else logging.WARNING
516+
)
517+
logger.log(
518+
log_level,
519+
f'Request Error (ID: {request_id}): '
520+
f"Code={error_resp.error.code}, Message='{error_resp.error.message}'"
521+
f'{", Data=" + str(error_resp.error.data) if hasattr(error, "data") and error_resp.error.data else ""}',
522+
)
523+
return JSONResponse(
524+
error_resp.model_dump(mode='json', exclude_none=True),
525+
status_code=200,
526+
)
527+
528+
async def _handle_requests(self, request: Request) -> Response:
529+
"""Handles incoming POST requests to the main A2A endpoint.
530+
531+
Parses the request body as JSON, validates it against A2A request types,
532+
dispatches it to the appropriate handler method, and returns the response.
533+
Handles JSON parsing errors, validation errors, and other exceptions,
534+
returning appropriate JSON-RPC error responses.
535+
536+
Args:
537+
request: The incoming Starlette Request object.
538+
539+
Returns:
540+
A Starlette Response object (JSONResponse or EventSourceResponse).
541+
542+
Raises:
543+
(Implicitly handled): Various exceptions are caught and converted
544+
into JSON-RPC error responses by this method.
545+
"""
546+
request_id = None
547+
body = None
548+
549+
try:
550+
body = await request.json()
551+
a2a_request = A2ARequest.model_validate(body)
552+
553+
request_id = a2a_request.root.id
554+
request_obj = a2a_request.root
555+
556+
if isinstance(
557+
request_obj,
558+
TaskResubscriptionRequest | SendStreamingMessageRequest,
559+
):
560+
return await self._process_streaming_request(
561+
request_id, a2a_request
562+
)
563+
564+
return await self._process_non_streaming_request(
565+
request_id, a2a_request
566+
)
567+
except MethodNotImplementedError:
568+
traceback.print_exc()
569+
return self._generate_error_response(
570+
request_id, A2AError(root=UnsupportedOperationError())
571+
)
572+
except json.decoder.JSONDecodeError as e:
573+
traceback.print_exc()
574+
return self._generate_error_response(
575+
None, A2AError(root=JSONParseError(message=str(e)))
576+
)
577+
except ValidationError as e:
578+
traceback.print_exc()
579+
return self._generate_error_response(
580+
request_id,
581+
A2AError(root=InvalidRequestError(data=json.loads(e.json()))),
582+
)
583+
except Exception as e:
584+
logger.error(f'Unhandled exception: {e}')
585+
traceback.print_exc()
586+
return self._generate_error_response(
587+
request_id, A2AError(root=InternalError(message=str(e)))
588+
)
589+
590+
async def _process_streaming_request(
591+
self, request_id: str | int | None, a2a_request: A2ARequest
592+
) -> Response:
593+
"""Processes streaming requests (message/stream or tasks/resubscribe).
594+
595+
Args:
596+
request_id: The ID of the request.
597+
a2a_request: The validated A2ARequest object.
598+
599+
Returns:
600+
An `EventSourceResponse` object to stream results to the client.
601+
"""
602+
request_obj = a2a_request.root
603+
handler_result: Any = None
604+
if isinstance(
605+
request_obj,
606+
SendStreamingMessageRequest,
607+
):
608+
handler_result = self.handler.on_message_send_stream(request_obj)
609+
elif isinstance(request_obj, TaskResubscriptionRequest):
610+
handler_result = self.handler.on_resubscribe_to_task(request_obj)
611+
612+
return self._create_response(handler_result)
613+
614+
async def _process_non_streaming_request(
615+
self, request_id: str | int | None, a2a_request: A2ARequest
616+
) -> Response:
617+
"""Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*).
618+
619+
Args:
620+
request_id: The ID of the request.
621+
a2a_request: The validated A2ARequest object.
622+
623+
Returns:
624+
A `JSONResponse` object containing the result or error.
625+
"""
626+
request_obj = a2a_request.root
627+
handler_result: Any = None
628+
match request_obj:
629+
case SendMessageRequest():
630+
handler_result = await self.handler.on_message_send(request_obj)
631+
case CancelTaskRequest():
632+
handler_result = await self.handler.on_cancel_task(request_obj)
633+
case GetTaskRequest():
634+
handler_result = await self.handler.on_get_task(request_obj)
635+
case SetTaskPushNotificationConfigRequest():
636+
handler_result = await self.handler.set_push_notification(
637+
request_obj
638+
)
639+
case GetTaskPushNotificationConfigRequest():
640+
handler_result = await self.handler.get_push_notification(
641+
request_obj
642+
)
643+
case _:
644+
logger.error(
645+
f'Unhandled validated request type: {type(request_obj)}'
646+
)
647+
error = UnsupportedOperationError(
648+
message=f'Request type {type(request_obj).__name__} is unknown.'
649+
)
650+
handler_result = JSONRPCErrorResponse(
651+
id=request_id, error=error
652+
)
653+
654+
return self._create_response(handler_result)
655+
656+
def _create_response(
657+
self,
658+
handler_result: (
659+
AsyncGenerator[SendStreamingMessageResponse]
660+
| JSONRPCErrorResponse
661+
| JSONRPCResponse
662+
),
663+
) -> Response:
664+
"""Creates a Starlette Response based on the result from the request handler.
665+
666+
Handles:
667+
- AsyncGenerator for Server-Sent Events (SSE).
668+
- JSONRPCErrorResponse for explicit errors returned by handlers.
669+
- Pydantic RootModels (like GetTaskResponse) containing success or error
670+
payloads.
671+
672+
Args:
673+
handler_result: The result from a request handler method. Can be an
674+
async generator for streaming or a Pydantic model for non-streaming.
675+
676+
Returns:
677+
A Starlette JSONResponse or EventSourceResponse.
678+
"""
679+
if isinstance(handler_result, AsyncGenerator):
680+
# Result is a stream of SendStreamingMessageResponse objects
681+
async def event_generator(
682+
stream: AsyncGenerator[SendStreamingMessageResponse],
683+
) -> AsyncGenerator[dict[str, str]]:
684+
async for item in stream:
685+
yield {'data': item.root.model_dump_json(exclude_none=True)}
686+
687+
return EventSourceResponse(event_generator(handler_result))
688+
if isinstance(handler_result, JSONRPCErrorResponse):
689+
return JSONResponse(
690+
handler_result.model_dump(
691+
mode='json',
692+
exclude_none=True,
693+
)
694+
)
695+
696+
return JSONResponse(
697+
handler_result.root.model_dump(mode='json', exclude_none=True)
698+
)
699+
700+
async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
701+
"""Handles GET requests for the agent card endpoint.
702+
703+
Args:
704+
request: The incoming Starlette Request object.
705+
706+
Returns:
707+
A JSONResponse containing the agent card data.
708+
"""
709+
return JSONResponse(
710+
self.agent_card.model_dump(mode='json', exclude_none=True)
711+
)
712+
713+
def routes(
714+
self,
715+
agent_card_path: str = '/agent.json',
716+
rpc_path: str = '/',
717+
) -> list[Route]:
718+
"""Returns the Starlette Routes for handling A2A requests.
719+
720+
Args:
721+
agent_card_path: The URL path for the agent card endpoint.
722+
rpc_path: The URL path for the A2A JSON-RPC endpoint (POST requests).
723+
724+
Returns:
725+
A list of Starlette Route objects.
726+
"""
727+
return [
728+
Route(
729+
rpc_url,
730+
self._handle_requests,
731+
methods=['POST'],
732+
name='a2a_handler',
733+
),
734+
Route(
735+
agent_card_url,
736+
self._handle_get_agent_card,
737+
methods=['GET'],
738+
name='agent_card',
739+
),
740+
]

0 commit comments

Comments
 (0)