Skip to content

Commit 3aad9f1

Browse files
committed
feat: support extended_agent_card in A2AStarletteRouter
Signed-off-by: Shingo OKAWA <[email protected]>
1 parent 5176ac2 commit 3aad9f1

File tree

1 file changed

+102
-29
lines changed

1 file changed

+102
-29
lines changed

src/a2a/server/apps/starlette_app.py

Lines changed: 102 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -480,24 +480,34 @@ def __init__(
480480
self,
481481
agent_card: AgentCard,
482482
http_handler: RequestHandler,
483-
agent_card_path: str = '/agent.json',
484-
rpc_path: str = '/',
483+
extended_agent_card: AgentCard | None = None,
484+
context_builder: CallContextBuilder | None = None,
485485
):
486486
"""Initializes the A2AStarletteRouter.
487487
488488
Args:
489489
agent_card: The AgentCard describing the agent's capabilities.
490490
http_handler: The handler instance responsible for processing A2A
491491
requests via http.
492-
agent_card_path: The URL path for the agent card endpoint.
493-
rpc_path: The URL path for the A2A JSON-RPC endpoint (POST requests).
492+
extended_agent_card: An optional, distinct AgentCard to be served
493+
at the authenticated extended card endpoint.
494+
context_builder: The CallContextBuilder used to construct the
495+
ServerCallContext passed to the http_handler. If None, no
496+
ServerCallContext is passed.
494497
"""
495498
self.agent_card = agent_card
499+
self.extended_agent_card = extended_agent_card
496500
self.handler = JSONRPCHandler(
497501
agent_card=agent_card, request_handler=http_handler
498502
)
499-
self.agent_card_path = agent_card_path
500-
self.rpc_path = rpc_path
503+
if (
504+
self.agent_card.supportsAuthenticatedExtendedCard
505+
and self.extended_agent_card is None
506+
):
507+
logger.error(
508+
'AgentCard.supportsAuthenticatedExtendedCard is True, but no extended_agent_card was provided. The /agent/authenticatedExtendedCard endpoint will return 404.'
509+
)
510+
self._context_builder = context_builder
501511

502512
def _generate_error_response(
503513
self, request_id: str | int | None, error: JSONRPCError | A2AError
@@ -517,7 +527,6 @@ def _generate_error_response(
517527
id=request_id,
518528
error=error if isinstance(error, JSONRPCError) else error.root,
519529
)
520-
521530
log_level = (
522531
logging.ERROR
523532
if not isinstance(error, A2AError)
@@ -555,24 +564,25 @@ async def _handle_requests(self, request: Request) -> Response:
555564
"""
556565
request_id = None
557566
body = None
558-
559567
try:
560568
body = await request.json()
561569
a2a_request = A2ARequest.model_validate(body)
562-
570+
call_context = (
571+
self._context_builder.build(request)
572+
if self._context_builder
573+
else None
574+
)
563575
request_id = a2a_request.root.id
564576
request_obj = a2a_request.root
565-
566577
if isinstance(
567578
request_obj,
568579
TaskResubscriptionRequest | SendStreamingMessageRequest,
569580
):
570581
return await self._process_streaming_request(
571-
request_id, a2a_request
582+
request_id, a2a_request, call_context
572583
)
573-
574584
return await self._process_non_streaming_request(
575-
request_id, a2a_request
585+
request_id, a2a_request, call_context
576586
)
577587
except MethodNotImplementedError:
578588
traceback.print_exc()
@@ -598,7 +608,10 @@ async def _handle_requests(self, request: Request) -> Response:
598608
)
599609

600610
async def _process_streaming_request(
601-
self, request_id: str | int | None, a2a_request: A2ARequest
611+
self,
612+
request_id: str | int | None,
613+
a2a_request: A2ARequest,
614+
context: ServerCallContext,
602615
) -> Response:
603616
"""Processes streaming requests (message/stream or tasks/resubscribe).
604617
@@ -615,14 +628,20 @@ async def _process_streaming_request(
615628
request_obj,
616629
SendStreamingMessageRequest,
617630
):
618-
handler_result = self.handler.on_message_send_stream(request_obj)
631+
handler_result = self.handler.on_message_send_stream(
632+
request_obj, context
633+
)
619634
elif isinstance(request_obj, TaskResubscriptionRequest):
620-
handler_result = self.handler.on_resubscribe_to_task(request_obj)
621-
635+
handler_result = self.handler.on_resubscribe_to_task(
636+
request_obj, context
637+
)
622638
return self._create_response(handler_result)
623639

624640
async def _process_non_streaming_request(
625-
self, request_id: str | int | None, a2a_request: A2ARequest
641+
self,
642+
request_id: str | int | None,
643+
a2a_request: A2ARequest,
644+
context: ServerCallContext,
626645
) -> Response:
627646
"""Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*).
628647
@@ -637,18 +656,26 @@ async def _process_non_streaming_request(
637656
handler_result: Any = None
638657
match request_obj:
639658
case SendMessageRequest():
640-
handler_result = await self.handler.on_message_send(request_obj)
659+
handler_result = await self.handler.on_message_send(
660+
request_obj, context
661+
)
641662
case CancelTaskRequest():
642-
handler_result = await self.handler.on_cancel_task(request_obj)
663+
handler_result = await self.handler.on_cancel_task(
664+
request_obj, context
665+
)
643666
case GetTaskRequest():
644-
handler_result = await self.handler.on_get_task(request_obj)
667+
handler_result = await self.handler.on_get_task(
668+
request_obj, context
669+
)
645670
case SetTaskPushNotificationConfigRequest():
646671
handler_result = await self.handler.set_push_notification(
647-
request_obj
672+
request_obj,
673+
context,
648674
)
649675
case GetTaskPushNotificationConfigRequest():
650676
handler_result = await self.handler.get_push_notification(
651-
request_obj
677+
request_obj,
678+
context,
652679
)
653680
case _:
654681
logger.error(
@@ -660,7 +687,6 @@ async def _process_non_streaming_request(
660687
handler_result = JSONRPCErrorResponse(
661688
id=request_id, error=error
662689
)
663-
664690
return self._create_response(handler_result)
665691

666692
def _create_response(
@@ -702,7 +728,6 @@ async def event_generator(
702728
exclude_none=True,
703729
)
704730
)
705-
706731
return JSONResponse(
707732
handler_result.root.model_dump(mode='json', exclude_none=True)
708733
)
@@ -716,27 +741,75 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
716741
Returns:
717742
A JSONResponse containing the agent card data.
718743
"""
744+
# The public agent card is a direct serialization of the agent_card
745+
# provided at initialization.
719746
return JSONResponse(
720747
self.agent_card.model_dump(mode='json', exclude_none=True)
721748
)
722749

723-
def routes(self) -> list[Route]:
750+
async def _handle_get_authenticated_extended_agent_card(
751+
self, request: Request
752+
) -> JSONResponse:
753+
"""Handles GET requests for the authenticated extended agent card."""
754+
if not self.agent_card.supportsAuthenticatedExtendedCard:
755+
return JSONResponse(
756+
{'error': 'Extended agent card not supported or not enabled.'},
757+
status_code=404,
758+
)
759+
# If an explicit extended_agent_card is provided, serve that.
760+
if self.extended_agent_card:
761+
return JSONResponse(
762+
self.extended_agent_card.model_dump(
763+
mode='json', exclude_none=True
764+
)
765+
)
766+
# If supportsAuthenticatedExtendedCard is true, but no specific
767+
# extended_agent_card was provided during server initialization,
768+
# return a 404
769+
return JSONResponse(
770+
{
771+
'error': 'Authenticated extended agent card is supported but not configured on the server.'
772+
},
773+
status_code=404,
774+
)
775+
776+
def routes(
777+
self,
778+
agent_card_path: str = '/.well-known/agent.json',
779+
extended_agent_card_path: str = '/agent/authenticatedExtendedCard',
780+
rpc_path: str = '/',
781+
) -> list[Route]:
724782
"""Returns the Starlette Routes for handling A2A requests.
725783
784+
Args:
785+
agent_card_path: The URL path for the agent card endpoint.
786+
rpc_path: The URL path for the A2A JSON-RPC endpoint (POST requests).
787+
extended_agent_card_path: The URL path for the authenticated extended agent card endpoint.
788+
726789
Returns:
727790
A list of Starlette Route objects.
728791
"""
729-
return [
792+
routes = [
730793
Route(
731-
self.rpc_path,
794+
rpc_path,
732795
self._handle_requests,
733796
methods=['POST'],
734797
name='a2a_handler',
735798
),
736799
Route(
737-
self.agent_card_path,
800+
agent_card_path,
738801
self._handle_get_agent_card,
739802
methods=['GET'],
740803
name='agent_card',
741804
),
742805
]
806+
if self.agent_card.supportsAuthenticatedExtendedCard:
807+
routes.append(
808+
Route(
809+
extended_agent_card_path,
810+
self._handle_get_authenticated_extended_agent_card,
811+
methods=['GET'],
812+
name='authenticated_extended_agent_card',
813+
)
814+
)
815+
return routes

0 commit comments

Comments
 (0)