22import logging
33import traceback
44
5+ from abc import ABC , abstractmethod
56from collections .abc import AsyncGenerator
67from typing import Any
78
1213from starlette .responses import JSONResponse , Response
1314from starlette .routing import Route
1415
15- from a2a .server .request_handlers . request_handler import RequestHandler
16+ from a2a .server .context import ServerCallContext
1617from a2a .server .request_handlers .jsonrpc_handler import JSONRPCHandler
17-
18+ from a2a . server . request_handlers . request_handler import RequestHandler
1819from a2a .types import (
1920 A2AError ,
2021 A2ARequest ,
4142logger = logging .getLogger (__name__ )
4243
4344
45+ class CallContextBuilder (ABC ):
46+ """A class for building ServerCallContexts using the Starlette Request."""
47+
48+ @abstractmethod
49+ def build (self , request : Request ) -> ServerCallContext :
50+ """Builds a ServerCallContext from a Starlette Request."""
51+
52+
4453class A2AStarletteApplication :
4554 """A Starlette application implementing the A2A protocol server endpoints.
4655
@@ -49,18 +58,27 @@ class A2AStarletteApplication:
4958 (SSE).
5059 """
5160
52- def __init__ (self , agent_card : AgentCard , http_handler : RequestHandler ):
61+ def __init__ (
62+ self ,
63+ agent_card : AgentCard ,
64+ http_handler : RequestHandler ,
65+ context_builder : CallContextBuilder | None = None ,
66+ ):
5367 """Initializes the A2AStarletteApplication.
5468
5569 Args:
5670 agent_card: The AgentCard describing the agent's capabilities.
5771 http_handler: The handler instance responsible for processing A2A
5872 requests via http.
73+ context_builder: The CallContextBuilder used to construct the
74+ ServerCallContext passed to the http_handler. If None, no
75+ ServerCallContext is passed.
5976 """
6077 self .agent_card = agent_card
6178 self .handler = JSONRPCHandler (
6279 agent_card = agent_card , request_handler = http_handler
6380 )
81+ self ._context_builder = context_builder
6482
6583 def _generate_error_response (
6684 self , request_id : str | int | None , error : JSONRPCError | A2AError
@@ -122,6 +140,11 @@ async def _handle_requests(self, request: Request) -> Response:
122140 try :
123141 body = await request .json ()
124142 a2a_request = A2ARequest .model_validate (body )
143+ call_context = (
144+ self ._context_builder .build (request )
145+ if self ._context_builder
146+ else None
147+ )
125148
126149 request_id = a2a_request .root .id
127150 request_obj = a2a_request .root
@@ -131,11 +154,11 @@ async def _handle_requests(self, request: Request) -> Response:
131154 TaskResubscriptionRequest | SendStreamingMessageRequest ,
132155 ):
133156 return await self ._process_streaming_request (
134- request_id , a2a_request
157+ request_id , a2a_request , call_context
135158 )
136159
137160 return await self ._process_non_streaming_request (
138- request_id , a2a_request
161+ request_id , a2a_request , call_context
139162 )
140163 except MethodNotImplementedError :
141164 traceback .print_exc ()
@@ -161,7 +184,10 @@ async def _handle_requests(self, request: Request) -> Response:
161184 )
162185
163186 async def _process_streaming_request (
164- self , request_id : str | int | None , a2a_request : A2ARequest
187+ self ,
188+ request_id : str | int | None ,
189+ a2a_request : A2ARequest ,
190+ context : ServerCallContext ,
165191 ) -> Response :
166192 """Processes streaming requests (message/stream or tasks/resubscribe).
167193
@@ -178,14 +204,21 @@ async def _process_streaming_request(
178204 request_obj ,
179205 SendStreamingMessageRequest ,
180206 ):
181- handler_result = self .handler .on_message_send_stream (request_obj )
207+ handler_result = self .handler .on_message_send_stream (
208+ request_obj , context
209+ )
182210 elif isinstance (request_obj , TaskResubscriptionRequest ):
183- handler_result = self .handler .on_resubscribe_to_task (request_obj )
211+ handler_result = self .handler .on_resubscribe_to_task (
212+ request_obj , context
213+ )
184214
185215 return self ._create_response (handler_result )
186216
187217 async def _process_non_streaming_request (
188- self , request_id : str | int | None , a2a_request : A2ARequest
218+ self ,
219+ request_id : str | int | None ,
220+ a2a_request : A2ARequest ,
221+ context : ServerCallContext ,
189222 ) -> Response :
190223 """Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*).
191224
@@ -200,18 +233,26 @@ async def _process_non_streaming_request(
200233 handler_result : Any = None
201234 match request_obj :
202235 case SendMessageRequest ():
203- handler_result = await self .handler .on_message_send (request_obj )
236+ handler_result = await self .handler .on_message_send (
237+ request_obj , context
238+ )
204239 case CancelTaskRequest ():
205- handler_result = await self .handler .on_cancel_task (request_obj )
240+ handler_result = await self .handler .on_cancel_task (
241+ request_obj , context
242+ )
206243 case GetTaskRequest ():
207- handler_result = await self .handler .on_get_task (request_obj )
244+ handler_result = await self .handler .on_get_task (
245+ request_obj , context
246+ )
208247 case SetTaskPushNotificationConfigRequest ():
209248 handler_result = await self .handler .set_push_notification (
210- request_obj
249+ request_obj ,
250+ context ,
211251 )
212252 case GetTaskPushNotificationConfigRequest ():
213253 handler_result = await self .handler .get_push_notification (
214- request_obj
254+ request_obj ,
255+ context ,
215256 )
216257 case _:
217258 logger .error (
0 commit comments