Skip to content

Commit e2d1923

Browse files
committed
feat: add A2AStarletteRouter for modular route definition
Signed-off-by: Shingo OKAWA <[email protected]>
1 parent d11ca2d commit e2d1923

File tree

1 file changed

+277
-0
lines changed

1 file changed

+277
-0
lines changed

src/a2a/server/apps/starlette_app.py

Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,3 +418,280 @@ def build(
418418
kwargs['routes'] = app_routes
419419

420420
return Starlette(**kwargs)
421+
422+
423+
class A2AStarletteRouter:
424+
"""Defines Starlette routes for serving A2A protocol endpoints.
425+
426+
This router provides the necessary HTTP routes for an A2A-compliant agent.
427+
It handles routing, dispatches requests to appropriate handler methods,
428+
and generates responses—including support for Server-Sent Events (SSE).
429+
430+
Note:
431+
As of 2025-05-24, this class is functionally equivalent to
432+
`A2AStarletteApplication` with the exception that it does not implement
433+
the `build()` method.
434+
"""
435+
436+
def __init__(self, agent_card: AgentCard, http_handler: RequestHandler):
437+
"""Initializes the A2AStarletteRouter.
438+
439+
Args:
440+
agent_card: The AgentCard describing the agent's capabilities.
441+
http_handler: The handler instance responsible for processing A2A
442+
requests via http.
443+
"""
444+
self.agent_card = agent_card
445+
self.handler = JSONRPCHandler(
446+
agent_card=agent_card, request_handler=http_handler
447+
)
448+
449+
def _generate_error_response(
450+
self, request_id: str | int | None, error: JSONRPCError | A2AError
451+
) -> JSONResponse:
452+
"""Creates a Starlette JSONResponse for a JSON-RPC error.
453+
454+
Logs the error based on its type.
455+
456+
Args:
457+
request_id: The ID of the request that caused the error.
458+
error: The `JSONRPCError` or `A2AError` object.
459+
460+
Returns:
461+
A `JSONResponse` object formatted as a JSON-RPC error response.
462+
"""
463+
error_resp = JSONRPCErrorResponse(
464+
id=request_id,
465+
error=error if isinstance(error, JSONRPCError) else error.root,
466+
)
467+
468+
log_level = (
469+
logging.ERROR
470+
if not isinstance(error, A2AError)
471+
or isinstance(error.root, InternalError)
472+
else logging.WARNING
473+
)
474+
logger.log(
475+
log_level,
476+
f'Request Error (ID: {request_id}): '
477+
f"Code={error_resp.error.code}, Message='{error_resp.error.message}'"
478+
f'{", Data=" + str(error_resp.error.data) if hasattr(error, "data") and error_resp.error.data else ""}',
479+
)
480+
return JSONResponse(
481+
error_resp.model_dump(mode='json', exclude_none=True),
482+
status_code=200,
483+
)
484+
485+
async def _handle_requests(self, request: Request) -> Response:
486+
"""Handles incoming POST requests to the main A2A endpoint.
487+
488+
Parses the request body as JSON, validates it against A2A request types,
489+
dispatches it to the appropriate handler method, and returns the response.
490+
Handles JSON parsing errors, validation errors, and other exceptions,
491+
returning appropriate JSON-RPC error responses.
492+
493+
Args:
494+
request: The incoming Starlette Request object.
495+
496+
Returns:
497+
A Starlette Response object (JSONResponse or EventSourceResponse).
498+
499+
Raises:
500+
(Implicitly handled): Various exceptions are caught and converted
501+
into JSON-RPC error responses by this method.
502+
"""
503+
request_id = None
504+
body = None
505+
506+
try:
507+
body = await request.json()
508+
a2a_request = A2ARequest.model_validate(body)
509+
510+
request_id = a2a_request.root.id
511+
request_obj = a2a_request.root
512+
513+
if isinstance(
514+
request_obj,
515+
TaskResubscriptionRequest | SendStreamingMessageRequest,
516+
):
517+
return await self._process_streaming_request(
518+
request_id, a2a_request
519+
)
520+
521+
return await self._process_non_streaming_request(
522+
request_id, a2a_request
523+
)
524+
except MethodNotImplementedError:
525+
traceback.print_exc()
526+
return self._generate_error_response(
527+
request_id, A2AError(root=UnsupportedOperationError())
528+
)
529+
except json.decoder.JSONDecodeError as e:
530+
traceback.print_exc()
531+
return self._generate_error_response(
532+
None, A2AError(root=JSONParseError(message=str(e)))
533+
)
534+
except ValidationError as e:
535+
traceback.print_exc()
536+
return self._generate_error_response(
537+
request_id,
538+
A2AError(root=InvalidRequestError(data=json.loads(e.json()))),
539+
)
540+
except Exception as e:
541+
logger.error(f'Unhandled exception: {e}')
542+
traceback.print_exc()
543+
return self._generate_error_response(
544+
request_id, A2AError(root=InternalError(message=str(e)))
545+
)
546+
547+
async def _process_streaming_request(
548+
self, request_id: str | int | None, a2a_request: A2ARequest
549+
) -> Response:
550+
"""Processes streaming requests (message/stream or tasks/resubscribe).
551+
552+
Args:
553+
request_id: The ID of the request.
554+
a2a_request: The validated A2ARequest object.
555+
556+
Returns:
557+
An `EventSourceResponse` object to stream results to the client.
558+
"""
559+
request_obj = a2a_request.root
560+
handler_result: Any = None
561+
if isinstance(
562+
request_obj,
563+
SendStreamingMessageRequest,
564+
):
565+
handler_result = self.handler.on_message_send_stream(request_obj)
566+
elif isinstance(request_obj, TaskResubscriptionRequest):
567+
handler_result = self.handler.on_resubscribe_to_task(request_obj)
568+
569+
return self._create_response(handler_result)
570+
571+
async def _process_non_streaming_request(
572+
self, request_id: str | int | None, a2a_request: A2ARequest
573+
) -> Response:
574+
"""Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*).
575+
576+
Args:
577+
request_id: The ID of the request.
578+
a2a_request: The validated A2ARequest object.
579+
580+
Returns:
581+
A `JSONResponse` object containing the result or error.
582+
"""
583+
request_obj = a2a_request.root
584+
handler_result: Any = None
585+
match request_obj:
586+
case SendMessageRequest():
587+
handler_result = await self.handler.on_message_send(request_obj)
588+
case CancelTaskRequest():
589+
handler_result = await self.handler.on_cancel_task(request_obj)
590+
case GetTaskRequest():
591+
handler_result = await self.handler.on_get_task(request_obj)
592+
case SetTaskPushNotificationConfigRequest():
593+
handler_result = await self.handler.set_push_notification(
594+
request_obj
595+
)
596+
case GetTaskPushNotificationConfigRequest():
597+
handler_result = await self.handler.get_push_notification(
598+
request_obj
599+
)
600+
case _:
601+
logger.error(
602+
f'Unhandled validated request type: {type(request_obj)}'
603+
)
604+
error = UnsupportedOperationError(
605+
message=f'Request type {type(request_obj).__name__} is unknown.'
606+
)
607+
handler_result = JSONRPCErrorResponse(
608+
id=request_id, error=error
609+
)
610+
611+
return self._create_response(handler_result)
612+
613+
def _create_response(
614+
self,
615+
handler_result: (
616+
AsyncGenerator[SendStreamingMessageResponse]
617+
| JSONRPCErrorResponse
618+
| JSONRPCResponse
619+
),
620+
) -> Response:
621+
"""Creates a Starlette Response based on the result from the request handler.
622+
623+
Handles:
624+
- AsyncGenerator for Server-Sent Events (SSE).
625+
- JSONRPCErrorResponse for explicit errors returned by handlers.
626+
- Pydantic RootModels (like GetTaskResponse) containing success or error
627+
payloads.
628+
629+
Args:
630+
handler_result: The result from a request handler method. Can be an
631+
async generator for streaming or a Pydantic model for non-streaming.
632+
633+
Returns:
634+
A Starlette JSONResponse or EventSourceResponse.
635+
"""
636+
if isinstance(handler_result, AsyncGenerator):
637+
# Result is a stream of SendStreamingMessageResponse objects
638+
async def event_generator(
639+
stream: AsyncGenerator[SendStreamingMessageResponse],
640+
) -> AsyncGenerator[dict[str, str]]:
641+
async for item in stream:
642+
yield {'data': item.root.model_dump_json(exclude_none=True)}
643+
644+
return EventSourceResponse(event_generator(handler_result))
645+
if isinstance(handler_result, JSONRPCErrorResponse):
646+
return JSONResponse(
647+
handler_result.model_dump(
648+
mode='json',
649+
exclude_none=True,
650+
)
651+
)
652+
653+
return JSONResponse(
654+
handler_result.root.model_dump(mode='json', exclude_none=True)
655+
)
656+
657+
async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
658+
"""Handles GET requests for the agent card endpoint.
659+
660+
Args:
661+
request: The incoming Starlette Request object.
662+
663+
Returns:
664+
A JSONResponse containing the agent card data.
665+
"""
666+
return JSONResponse(
667+
self.agent_card.model_dump(mode='json', exclude_none=True)
668+
)
669+
670+
def routes(
671+
self,
672+
agent_card_path: str = '/agent.json',
673+
rpc_path: str = '/',
674+
) -> list[Route]:
675+
"""Returns the Starlette Routes for handling A2A requests.
676+
677+
Args:
678+
agent_card_path: The URL path for the agent card endpoint.
679+
rpc_path: The URL path for the A2A JSON-RPC endpoint (POST requests).
680+
681+
Returns:
682+
A list of Starlette Route objects.
683+
"""
684+
return [
685+
Route(
686+
rpc_url,
687+
self._handle_requests,
688+
methods=['POST'],
689+
name='a2a_handler',
690+
),
691+
Route(
692+
agent_card_url,
693+
self._handle_get_agent_card,
694+
methods=['GET'],
695+
name='agent_card',
696+
),
697+
]

0 commit comments

Comments
 (0)