Skip to content

Commit 4778c65

Browse files
committed
feat: support extended_agent_card in A2AStarletteRouter
Signed-off-by: Shingo OKAWA <[email protected]>
1 parent e9d8e8f commit 4778c65

File tree

1 file changed

+127
-38
lines changed

1 file changed

+127
-38
lines changed

src/a2a/server/apps/starlette_app.py

Lines changed: 127 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import logging
33
import traceback
4+
45
from abc import ABC, abstractmethod
56
from collections.abc import AsyncGenerator
67
from typing import Any
@@ -15,16 +16,29 @@
1516
from a2a.server.context import ServerCallContext
1617
from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler
1718
from a2a.server.request_handlers.request_handler import RequestHandler
18-
from a2a.types import (A2AError, A2ARequest, AgentCard, CancelTaskRequest,
19-
GetTaskPushNotificationConfigRequest, GetTaskRequest,
20-
InternalError, InvalidRequestError, JSONParseError,
21-
JSONRPCError, JSONRPCErrorResponse, JSONRPCResponse,
22-
SendMessageRequest, SendStreamingMessageRequest,
23-
SendStreamingMessageResponse,
24-
SetTaskPushNotificationConfigRequest,
25-
TaskResubscriptionRequest, UnsupportedOperationError)
19+
from a2a.types import (
20+
A2AError,
21+
A2ARequest,
22+
AgentCard,
23+
CancelTaskRequest,
24+
GetTaskPushNotificationConfigRequest,
25+
GetTaskRequest,
26+
InternalError,
27+
InvalidRequestError,
28+
JSONParseError,
29+
JSONRPCError,
30+
JSONRPCErrorResponse,
31+
JSONRPCResponse,
32+
SendMessageRequest,
33+
SendStreamingMessageRequest,
34+
SendStreamingMessageResponse,
35+
SetTaskPushNotificationConfigRequest,
36+
TaskResubscriptionRequest,
37+
UnsupportedOperationError,
38+
)
2639
from a2a.utils.errors import MethodNotImplementedError
2740

41+
2842
logger = logging.getLogger(__name__)
2943

3044

@@ -344,7 +358,9 @@ async def _handle_get_authenticated_extended_agent_card(
344358
# extended_agent_card was provided during server initialization,
345359
# return a 404
346360
return JSONResponse(
347-
{'error': 'Authenticated extended agent card is supported but not configured on the server.'},
361+
{
362+
'error': 'Authenticated extended agent card is supported but not configured on the server.'
363+
},
348364
status_code=404,
349365
)
350366

@@ -437,24 +453,34 @@ def __init__(
437453
self,
438454
agent_card: AgentCard,
439455
http_handler: RequestHandler,
440-
agent_card_path: str = '/agent.json',
441-
rpc_path: str = '/',
456+
extended_agent_card: AgentCard | None = None,
457+
context_builder: CallContextBuilder | None = None,
442458
):
443459
"""Initializes the A2AStarletteRouter.
444460
445461
Args:
446462
agent_card: The AgentCard describing the agent's capabilities.
447463
http_handler: The handler instance responsible for processing A2A
448464
requests via http.
449-
agent_card_path: The URL path for the agent card endpoint.
450-
rpc_path: The URL path for the A2A JSON-RPC endpoint (POST requests).
465+
extended_agent_card: An optional, distinct AgentCard to be served
466+
at the authenticated extended card endpoint.
467+
context_builder: The CallContextBuilder used to construct the
468+
ServerCallContext passed to the http_handler. If None, no
469+
ServerCallContext is passed.
451470
"""
452471
self.agent_card = agent_card
472+
self.extended_agent_card = extended_agent_card
453473
self.handler = JSONRPCHandler(
454474
agent_card=agent_card, request_handler=http_handler
455475
)
456-
self.agent_card_path = agent_card_path
457-
self.rpc_path = rpc_path
476+
if (
477+
self.agent_card.supportsAuthenticatedExtendedCard
478+
and self.extended_agent_card is None
479+
):
480+
logger.error(
481+
'AgentCard.supportsAuthenticatedExtendedCard is True, but no extended_agent_card was provided. The /agent/authenticatedExtendedCard endpoint will return 404.'
482+
)
483+
self._context_builder = context_builder
458484

459485
def _generate_error_response(
460486
self, request_id: str | int | None, error: JSONRPCError | A2AError
@@ -474,7 +500,6 @@ def _generate_error_response(
474500
id=request_id,
475501
error=error if isinstance(error, JSONRPCError) else error.root,
476502
)
477-
478503
log_level = (
479504
logging.ERROR
480505
if not isinstance(error, A2AError)
@@ -512,24 +537,25 @@ async def _handle_requests(self, request: Request) -> Response:
512537
"""
513538
request_id = None
514539
body = None
515-
516540
try:
517541
body = await request.json()
518542
a2a_request = A2ARequest.model_validate(body)
519-
543+
call_context = (
544+
self._context_builder.build(request)
545+
if self._context_builder
546+
else None
547+
)
520548
request_id = a2a_request.root.id
521549
request_obj = a2a_request.root
522-
523550
if isinstance(
524551
request_obj,
525552
TaskResubscriptionRequest | SendStreamingMessageRequest,
526553
):
527554
return await self._process_streaming_request(
528-
request_id, a2a_request
555+
request_id, a2a_request, call_context
529556
)
530-
531557
return await self._process_non_streaming_request(
532-
request_id, a2a_request
558+
request_id, a2a_request, call_context
533559
)
534560
except MethodNotImplementedError:
535561
traceback.print_exc()
@@ -555,7 +581,10 @@ async def _handle_requests(self, request: Request) -> Response:
555581
)
556582

557583
async def _process_streaming_request(
558-
self, request_id: str | int | None, a2a_request: A2ARequest
584+
self,
585+
request_id: str | int | None,
586+
a2a_request: A2ARequest,
587+
context: ServerCallContext,
559588
) -> Response:
560589
"""Processes streaming requests (message/stream or tasks/resubscribe).
561590
@@ -572,14 +601,20 @@ async def _process_streaming_request(
572601
request_obj,
573602
SendStreamingMessageRequest,
574603
):
575-
handler_result = self.handler.on_message_send_stream(request_obj)
604+
handler_result = self.handler.on_message_send_stream(
605+
request_obj, context
606+
)
576607
elif isinstance(request_obj, TaskResubscriptionRequest):
577-
handler_result = self.handler.on_resubscribe_to_task(request_obj)
578-
608+
handler_result = self.handler.on_resubscribe_to_task(
609+
request_obj, context
610+
)
579611
return self._create_response(handler_result)
580612

581613
async def _process_non_streaming_request(
582-
self, request_id: str | int | None, a2a_request: A2ARequest
614+
self,
615+
request_id: str | int | None,
616+
a2a_request: A2ARequest,
617+
context: ServerCallContext,
583618
) -> Response:
584619
"""Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*).
585620
@@ -594,18 +629,26 @@ async def _process_non_streaming_request(
594629
handler_result: Any = None
595630
match request_obj:
596631
case SendMessageRequest():
597-
handler_result = await self.handler.on_message_send(request_obj)
632+
handler_result = await self.handler.on_message_send(
633+
request_obj, context
634+
)
598635
case CancelTaskRequest():
599-
handler_result = await self.handler.on_cancel_task(request_obj)
636+
handler_result = await self.handler.on_cancel_task(
637+
request_obj, context
638+
)
600639
case GetTaskRequest():
601-
handler_result = await self.handler.on_get_task(request_obj)
640+
handler_result = await self.handler.on_get_task(
641+
request_obj, context
642+
)
602643
case SetTaskPushNotificationConfigRequest():
603644
handler_result = await self.handler.set_push_notification(
604-
request_obj
645+
request_obj,
646+
context,
605647
)
606648
case GetTaskPushNotificationConfigRequest():
607649
handler_result = await self.handler.get_push_notification(
608-
request_obj
650+
request_obj,
651+
context,
609652
)
610653
case _:
611654
logger.error(
@@ -617,7 +660,6 @@ async def _process_non_streaming_request(
617660
handler_result = JSONRPCErrorResponse(
618661
id=request_id, error=error
619662
)
620-
621663
return self._create_response(handler_result)
622664

623665
def _create_response(
@@ -659,7 +701,6 @@ async def event_generator(
659701
exclude_none=True,
660702
)
661703
)
662-
663704
return JSONResponse(
664705
handler_result.root.model_dump(mode='json', exclude_none=True)
665706
)
@@ -673,27 +714,75 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
673714
Returns:
674715
A JSONResponse containing the agent card data.
675716
"""
717+
# The public agent card is a direct serialization of the agent_card
718+
# provided at initialization.
676719
return JSONResponse(
677720
self.agent_card.model_dump(mode='json', exclude_none=True)
678721
)
679722

680-
def routes(self) -> list[Route]:
723+
async def _handle_get_authenticated_extended_agent_card(
724+
self, request: Request
725+
) -> JSONResponse:
726+
"""Handles GET requests for the authenticated extended agent card."""
727+
if not self.agent_card.supportsAuthenticatedExtendedCard:
728+
return JSONResponse(
729+
{'error': 'Extended agent card not supported or not enabled.'},
730+
status_code=404,
731+
)
732+
# If an explicit extended_agent_card is provided, serve that.
733+
if self.extended_agent_card:
734+
return JSONResponse(
735+
self.extended_agent_card.model_dump(
736+
mode='json', exclude_none=True
737+
)
738+
)
739+
# If supportsAuthenticatedExtendedCard is true, but no specific
740+
# extended_agent_card was provided during server initialization,
741+
# return a 404
742+
return JSONResponse(
743+
{
744+
'error': 'Authenticated extended agent card is supported but not configured on the server.'
745+
},
746+
status_code=404,
747+
)
748+
749+
def routes(
750+
self,
751+
agent_card_path: str = '/.well-known/agent.json',
752+
extended_agent_card_path: str = '/agent/authenticatedExtendedCard',
753+
rpc_path: str = '/',
754+
) -> list[Route]:
681755
"""Returns the Starlette Routes for handling A2A requests.
682756
757+
Args:
758+
agent_card_path: The URL path for the agent card endpoint.
759+
rpc_path: The URL path for the A2A JSON-RPC endpoint (POST requests).
760+
extended_agent_card_path: The URL path for the authenticated extended agent card endpoint.
761+
683762
Returns:
684763
A list of Starlette Route objects.
685764
"""
686-
return [
765+
routes = [
687766
Route(
688-
self.rpc_path,
767+
rpc_path,
689768
self._handle_requests,
690769
methods=['POST'],
691770
name='a2a_handler',
692771
),
693772
Route(
694-
self.agent_card_path,
773+
agent_card_path,
695774
self._handle_get_agent_card,
696775
methods=['GET'],
697776
name='agent_card',
698777
),
699778
]
779+
if self.agent_card.supportsAuthenticatedExtendedCard:
780+
routes.append(
781+
Route(
782+
extended_agent_card_path,
783+
self._handle_get_authenticated_extended_agent_card,
784+
methods=['GET'],
785+
name='authenticated_extended_agent_card',
786+
)
787+
)
788+
return routes

0 commit comments

Comments
 (0)