Skip to content

Commit 3362d43

Browse files
committed
feat: add A2AStarletteRouter for modular route definition
Signed-off-by: Shingo OKAWA <[email protected]>
1 parent 0bc1c99 commit 3362d43

File tree

1 file changed

+277
-0
lines changed

1 file changed

+277
-0
lines changed

src/a2a/server/apps/starlette_app.py

Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,3 +449,280 @@ def build(
449449
kwargs['routes'] = app_routes
450450

451451
return Starlette(**kwargs)
452+
453+
454+
class A2AStarletteRouter:
455+
"""Defines Starlette routes for serving A2A protocol endpoints.
456+
457+
This router provides the necessary HTTP routes for an A2A-compliant agent.
458+
It handles routing, dispatches requests to appropriate handler methods,
459+
and generates responses—including support for Server-Sent Events (SSE).
460+
461+
Note:
462+
As of 2025-05-24, this class is functionally equivalent to
463+
`A2AStarletteApplication` with the exception that it does not implement
464+
the `build()` method.
465+
"""
466+
467+
def __init__(self, agent_card: AgentCard, http_handler: RequestHandler):
468+
"""Initializes the A2AStarletteRouter.
469+
470+
Args:
471+
agent_card: The AgentCard describing the agent's capabilities.
472+
http_handler: The handler instance responsible for processing A2A
473+
requests via http.
474+
"""
475+
self.agent_card = agent_card
476+
self.handler = JSONRPCHandler(
477+
agent_card=agent_card, request_handler=http_handler
478+
)
479+
480+
def _generate_error_response(
481+
self, request_id: str | int | None, error: JSONRPCError | A2AError
482+
) -> JSONResponse:
483+
"""Creates a Starlette JSONResponse for a JSON-RPC error.
484+
485+
Logs the error based on its type.
486+
487+
Args:
488+
request_id: The ID of the request that caused the error.
489+
error: The `JSONRPCError` or `A2AError` object.
490+
491+
Returns:
492+
A `JSONResponse` object formatted as a JSON-RPC error response.
493+
"""
494+
error_resp = JSONRPCErrorResponse(
495+
id=request_id,
496+
error=error if isinstance(error, JSONRPCError) else error.root,
497+
)
498+
499+
log_level = (
500+
logging.ERROR
501+
if not isinstance(error, A2AError)
502+
or isinstance(error.root, InternalError)
503+
else logging.WARNING
504+
)
505+
logger.log(
506+
log_level,
507+
f'Request Error (ID: {request_id}): '
508+
f"Code={error_resp.error.code}, Message='{error_resp.error.message}'"
509+
f'{", Data=" + str(error_resp.error.data) if hasattr(error, "data") and error_resp.error.data else ""}',
510+
)
511+
return JSONResponse(
512+
error_resp.model_dump(mode='json', exclude_none=True),
513+
status_code=200,
514+
)
515+
516+
async def _handle_requests(self, request: Request) -> Response:
517+
"""Handles incoming POST requests to the main A2A endpoint.
518+
519+
Parses the request body as JSON, validates it against A2A request types,
520+
dispatches it to the appropriate handler method, and returns the response.
521+
Handles JSON parsing errors, validation errors, and other exceptions,
522+
returning appropriate JSON-RPC error responses.
523+
524+
Args:
525+
request: The incoming Starlette Request object.
526+
527+
Returns:
528+
A Starlette Response object (JSONResponse or EventSourceResponse).
529+
530+
Raises:
531+
(Implicitly handled): Various exceptions are caught and converted
532+
into JSON-RPC error responses by this method.
533+
"""
534+
request_id = None
535+
body = None
536+
537+
try:
538+
body = await request.json()
539+
a2a_request = A2ARequest.model_validate(body)
540+
541+
request_id = a2a_request.root.id
542+
request_obj = a2a_request.root
543+
544+
if isinstance(
545+
request_obj,
546+
TaskResubscriptionRequest | SendStreamingMessageRequest,
547+
):
548+
return await self._process_streaming_request(
549+
request_id, a2a_request
550+
)
551+
552+
return await self._process_non_streaming_request(
553+
request_id, a2a_request
554+
)
555+
except MethodNotImplementedError:
556+
traceback.print_exc()
557+
return self._generate_error_response(
558+
request_id, A2AError(root=UnsupportedOperationError())
559+
)
560+
except json.decoder.JSONDecodeError as e:
561+
traceback.print_exc()
562+
return self._generate_error_response(
563+
None, A2AError(root=JSONParseError(message=str(e)))
564+
)
565+
except ValidationError as e:
566+
traceback.print_exc()
567+
return self._generate_error_response(
568+
request_id,
569+
A2AError(root=InvalidRequestError(data=json.loads(e.json()))),
570+
)
571+
except Exception as e:
572+
logger.error(f'Unhandled exception: {e}')
573+
traceback.print_exc()
574+
return self._generate_error_response(
575+
request_id, A2AError(root=InternalError(message=str(e)))
576+
)
577+
578+
async def _process_streaming_request(
579+
self, request_id: str | int | None, a2a_request: A2ARequest
580+
) -> Response:
581+
"""Processes streaming requests (message/stream or tasks/resubscribe).
582+
583+
Args:
584+
request_id: The ID of the request.
585+
a2a_request: The validated A2ARequest object.
586+
587+
Returns:
588+
An `EventSourceResponse` object to stream results to the client.
589+
"""
590+
request_obj = a2a_request.root
591+
handler_result: Any = None
592+
if isinstance(
593+
request_obj,
594+
SendStreamingMessageRequest,
595+
):
596+
handler_result = self.handler.on_message_send_stream(request_obj)
597+
elif isinstance(request_obj, TaskResubscriptionRequest):
598+
handler_result = self.handler.on_resubscribe_to_task(request_obj)
599+
600+
return self._create_response(handler_result)
601+
602+
async def _process_non_streaming_request(
603+
self, request_id: str | int | None, a2a_request: A2ARequest
604+
) -> Response:
605+
"""Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*).
606+
607+
Args:
608+
request_id: The ID of the request.
609+
a2a_request: The validated A2ARequest object.
610+
611+
Returns:
612+
A `JSONResponse` object containing the result or error.
613+
"""
614+
request_obj = a2a_request.root
615+
handler_result: Any = None
616+
match request_obj:
617+
case SendMessageRequest():
618+
handler_result = await self.handler.on_message_send(request_obj)
619+
case CancelTaskRequest():
620+
handler_result = await self.handler.on_cancel_task(request_obj)
621+
case GetTaskRequest():
622+
handler_result = await self.handler.on_get_task(request_obj)
623+
case SetTaskPushNotificationConfigRequest():
624+
handler_result = await self.handler.set_push_notification(
625+
request_obj
626+
)
627+
case GetTaskPushNotificationConfigRequest():
628+
handler_result = await self.handler.get_push_notification(
629+
request_obj
630+
)
631+
case _:
632+
logger.error(
633+
f'Unhandled validated request type: {type(request_obj)}'
634+
)
635+
error = UnsupportedOperationError(
636+
message=f'Request type {type(request_obj).__name__} is unknown.'
637+
)
638+
handler_result = JSONRPCErrorResponse(
639+
id=request_id, error=error
640+
)
641+
642+
return self._create_response(handler_result)
643+
644+
def _create_response(
645+
self,
646+
handler_result: (
647+
AsyncGenerator[SendStreamingMessageResponse]
648+
| JSONRPCErrorResponse
649+
| JSONRPCResponse
650+
),
651+
) -> Response:
652+
"""Creates a Starlette Response based on the result from the request handler.
653+
654+
Handles:
655+
- AsyncGenerator for Server-Sent Events (SSE).
656+
- JSONRPCErrorResponse for explicit errors returned by handlers.
657+
- Pydantic RootModels (like GetTaskResponse) containing success or error
658+
payloads.
659+
660+
Args:
661+
handler_result: The result from a request handler method. Can be an
662+
async generator for streaming or a Pydantic model for non-streaming.
663+
664+
Returns:
665+
A Starlette JSONResponse or EventSourceResponse.
666+
"""
667+
if isinstance(handler_result, AsyncGenerator):
668+
# Result is a stream of SendStreamingMessageResponse objects
669+
async def event_generator(
670+
stream: AsyncGenerator[SendStreamingMessageResponse],
671+
) -> AsyncGenerator[dict[str, str]]:
672+
async for item in stream:
673+
yield {'data': item.root.model_dump_json(exclude_none=True)}
674+
675+
return EventSourceResponse(event_generator(handler_result))
676+
if isinstance(handler_result, JSONRPCErrorResponse):
677+
return JSONResponse(
678+
handler_result.model_dump(
679+
mode='json',
680+
exclude_none=True,
681+
)
682+
)
683+
684+
return JSONResponse(
685+
handler_result.root.model_dump(mode='json', exclude_none=True)
686+
)
687+
688+
async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
689+
"""Handles GET requests for the agent card endpoint.
690+
691+
Args:
692+
request: The incoming Starlette Request object.
693+
694+
Returns:
695+
A JSONResponse containing the agent card data.
696+
"""
697+
return JSONResponse(
698+
self.agent_card.model_dump(mode='json', exclude_none=True)
699+
)
700+
701+
def routes(
702+
self,
703+
agent_card_path: str = '/agent.json',
704+
rpc_path: str = '/',
705+
) -> list[Route]:
706+
"""Returns the Starlette Routes for handling A2A requests.
707+
708+
Args:
709+
agent_card_path: The URL path for the agent card endpoint.
710+
rpc_path: The URL path for the A2A JSON-RPC endpoint (POST requests).
711+
712+
Returns:
713+
A list of Starlette Route objects.
714+
"""
715+
return [
716+
Route(
717+
rpc_url,
718+
self._handle_requests,
719+
methods=['POST'],
720+
name='a2a_handler',
721+
),
722+
Route(
723+
agent_card_url,
724+
self._handle_get_agent_card,
725+
methods=['GET'],
726+
name='agent_card',
727+
),
728+
]

0 commit comments

Comments
 (0)