Skip to content

Commit ed6d5bb

Browse files
committed
feat: support extended_agent_card in A2AStarletteRouter
Signed-off-by: Shingo OKAWA <[email protected]>
1 parent e829c96 commit ed6d5bb

File tree

1 file changed

+102
-29
lines changed

1 file changed

+102
-29
lines changed

src/a2a/server/apps/starlette_app.py

Lines changed: 102 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -468,24 +468,34 @@ def __init__(
468468
self,
469469
agent_card: AgentCard,
470470
http_handler: RequestHandler,
471-
agent_card_path: str = '/agent.json',
472-
rpc_path: str = '/',
471+
extended_agent_card: AgentCard | None = None,
472+
context_builder: CallContextBuilder | None = None,
473473
):
474474
"""Initializes the A2AStarletteRouter.
475475
476476
Args:
477477
agent_card: The AgentCard describing the agent's capabilities.
478478
http_handler: The handler instance responsible for processing A2A
479479
requests via http.
480-
agent_card_path: The URL path for the agent card endpoint.
481-
rpc_path: The URL path for the A2A JSON-RPC endpoint (POST requests).
480+
extended_agent_card: An optional, distinct AgentCard to be served
481+
at the authenticated extended card endpoint.
482+
context_builder: The CallContextBuilder used to construct the
483+
ServerCallContext passed to the http_handler. If None, no
484+
ServerCallContext is passed.
482485
"""
483486
self.agent_card = agent_card
487+
self.extended_agent_card = extended_agent_card
484488
self.handler = JSONRPCHandler(
485489
agent_card=agent_card, request_handler=http_handler
486490
)
487-
self.agent_card_path = agent_card_path
488-
self.rpc_path = rpc_path
491+
if (
492+
self.agent_card.supportsAuthenticatedExtendedCard
493+
and self.extended_agent_card is None
494+
):
495+
logger.error(
496+
'AgentCard.supportsAuthenticatedExtendedCard is True, but no extended_agent_card was provided. The /agent/authenticatedExtendedCard endpoint will return 404.'
497+
)
498+
self._context_builder = context_builder
489499

490500
def _generate_error_response(
491501
self, request_id: str | int | None, error: JSONRPCError | A2AError
@@ -505,7 +515,6 @@ def _generate_error_response(
505515
id=request_id,
506516
error=error if isinstance(error, JSONRPCError) else error.root,
507517
)
508-
509518
log_level = (
510519
logging.ERROR
511520
if not isinstance(error, A2AError)
@@ -543,24 +552,25 @@ async def _handle_requests(self, request: Request) -> Response:
543552
"""
544553
request_id = None
545554
body = None
546-
547555
try:
548556
body = await request.json()
549557
a2a_request = A2ARequest.model_validate(body)
550-
558+
call_context = (
559+
self._context_builder.build(request)
560+
if self._context_builder
561+
else None
562+
)
551563
request_id = a2a_request.root.id
552564
request_obj = a2a_request.root
553-
554565
if isinstance(
555566
request_obj,
556567
TaskResubscriptionRequest | SendStreamingMessageRequest,
557568
):
558569
return await self._process_streaming_request(
559-
request_id, a2a_request
570+
request_id, a2a_request, call_context
560571
)
561-
562572
return await self._process_non_streaming_request(
563-
request_id, a2a_request
573+
request_id, a2a_request, call_context
564574
)
565575
except MethodNotImplementedError:
566576
traceback.print_exc()
@@ -586,7 +596,10 @@ async def _handle_requests(self, request: Request) -> Response:
586596
)
587597

588598
async def _process_streaming_request(
589-
self, request_id: str | int | None, a2a_request: A2ARequest
599+
self,
600+
request_id: str | int | None,
601+
a2a_request: A2ARequest,
602+
context: ServerCallContext,
590603
) -> Response:
591604
"""Processes streaming requests (message/stream or tasks/resubscribe).
592605
@@ -603,14 +616,20 @@ async def _process_streaming_request(
603616
request_obj,
604617
SendStreamingMessageRequest,
605618
):
606-
handler_result = self.handler.on_message_send_stream(request_obj)
619+
handler_result = self.handler.on_message_send_stream(
620+
request_obj, context
621+
)
607622
elif isinstance(request_obj, TaskResubscriptionRequest):
608-
handler_result = self.handler.on_resubscribe_to_task(request_obj)
609-
623+
handler_result = self.handler.on_resubscribe_to_task(
624+
request_obj, context
625+
)
610626
return self._create_response(handler_result)
611627

612628
async def _process_non_streaming_request(
613-
self, request_id: str | int | None, a2a_request: A2ARequest
629+
self,
630+
request_id: str | int | None,
631+
a2a_request: A2ARequest,
632+
context: ServerCallContext,
614633
) -> Response:
615634
"""Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*).
616635
@@ -625,18 +644,26 @@ async def _process_non_streaming_request(
625644
handler_result: Any = None
626645
match request_obj:
627646
case SendMessageRequest():
628-
handler_result = await self.handler.on_message_send(request_obj)
647+
handler_result = await self.handler.on_message_send(
648+
request_obj, context
649+
)
629650
case CancelTaskRequest():
630-
handler_result = await self.handler.on_cancel_task(request_obj)
651+
handler_result = await self.handler.on_cancel_task(
652+
request_obj, context
653+
)
631654
case GetTaskRequest():
632-
handler_result = await self.handler.on_get_task(request_obj)
655+
handler_result = await self.handler.on_get_task(
656+
request_obj, context
657+
)
633658
case SetTaskPushNotificationConfigRequest():
634659
handler_result = await self.handler.set_push_notification(
635-
request_obj
660+
request_obj,
661+
context,
636662
)
637663
case GetTaskPushNotificationConfigRequest():
638664
handler_result = await self.handler.get_push_notification(
639-
request_obj
665+
request_obj,
666+
context,
640667
)
641668
case _:
642669
logger.error(
@@ -648,7 +675,6 @@ async def _process_non_streaming_request(
648675
handler_result = JSONRPCErrorResponse(
649676
id=request_id, error=error
650677
)
651-
652678
return self._create_response(handler_result)
653679

654680
def _create_response(
@@ -690,7 +716,6 @@ async def event_generator(
690716
exclude_none=True,
691717
)
692718
)
693-
694719
return JSONResponse(
695720
handler_result.root.model_dump(mode='json', exclude_none=True)
696721
)
@@ -704,27 +729,75 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
704729
Returns:
705730
A JSONResponse containing the agent card data.
706731
"""
732+
# The public agent card is a direct serialization of the agent_card
733+
# provided at initialization.
707734
return JSONResponse(
708735
self.agent_card.model_dump(mode='json', exclude_none=True)
709736
)
710737

711-
def routes(self) -> list[Route]:
738+
async def _handle_get_authenticated_extended_agent_card(
739+
self, request: Request
740+
) -> JSONResponse:
741+
"""Handles GET requests for the authenticated extended agent card."""
742+
if not self.agent_card.supportsAuthenticatedExtendedCard:
743+
return JSONResponse(
744+
{'error': 'Extended agent card not supported or not enabled.'},
745+
status_code=404,
746+
)
747+
# If an explicit extended_agent_card is provided, serve that.
748+
if self.extended_agent_card:
749+
return JSONResponse(
750+
self.extended_agent_card.model_dump(
751+
mode='json', exclude_none=True
752+
)
753+
)
754+
# If supportsAuthenticatedExtendedCard is true, but no specific
755+
# extended_agent_card was provided during server initialization,
756+
# return a 404
757+
return JSONResponse(
758+
{
759+
'error': 'Authenticated extended agent card is supported but not configured on the server.'
760+
},
761+
status_code=404,
762+
)
763+
764+
def routes(
765+
self,
766+
agent_card_path: str = '/.well-known/agent.json',
767+
extended_agent_card_path: str = '/agent/authenticatedExtendedCard',
768+
rpc_path: str = '/',
769+
) -> list[Route]:
712770
"""Returns the Starlette Routes for handling A2A requests.
713771
772+
Args:
773+
agent_card_path: The URL path for the agent card endpoint.
774+
rpc_path: The URL path for the A2A JSON-RPC endpoint (POST requests).
775+
extended_agent_card_path: The URL path for the authenticated extended agent card endpoint.
776+
714777
Returns:
715778
A list of Starlette Route objects.
716779
"""
717-
return [
780+
routes = [
718781
Route(
719-
self.rpc_path,
782+
rpc_path,
720783
self._handle_requests,
721784
methods=['POST'],
722785
name='a2a_handler',
723786
),
724787
Route(
725-
self.agent_card_path,
788+
agent_card_path,
726789
self._handle_get_agent_card,
727790
methods=['GET'],
728791
name='agent_card',
729792
),
730793
]
794+
if self.agent_card.supportsAuthenticatedExtendedCard:
795+
routes.append(
796+
Route(
797+
extended_agent_card_path,
798+
self._handle_get_authenticated_extended_agent_card,
799+
methods=['GET'],
800+
name='authenticated_extended_agent_card',
801+
)
802+
)
803+
return routes

0 commit comments

Comments
 (0)