Skip to content

Commit d1baac9

Browse files
committed
Introduce a ServerCallContext parameter
1 parent 8e96ac1 commit d1baac9

File tree

9 files changed

+219
-60
lines changed

9 files changed

+219
-60
lines changed

src/a2a/server/agent_execution/context.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import uuid
22

3+
from a2a.server.context import ServerCallContext
34
from a2a.types import (
45
InvalidParamsError,
56
Message,
@@ -26,6 +27,7 @@ def __init__(
2627
context_id: str | None = None,
2728
task: Task | None = None,
2829
related_tasks: list[Task] | None = None,
30+
call_context: ServerCallContext | None = None,
2931
):
3032
"""Initializes the RequestContext.
3133
@@ -43,6 +45,7 @@ def __init__(
4345
self._context_id = context_id
4446
self._current_task = task
4547
self._related_tasks = related_tasks
48+
self._call_context = call_context
4649
# If the task id and context id were provided, make sure they
4750
# match the request. Otherwise, create them
4851
if self._params:
@@ -125,6 +128,11 @@ def configuration(self) -> MessageSendConfiguration | None:
125128
return None
126129
return self._params.configuration
127130

131+
@property
132+
def call_context(self) -> ServerCallContext | None:
133+
"""The server call context associated with this request."""
134+
return self._call_context
135+
128136
def _check_or_generate_task_id(self) -> None:
129137
"""Ensures a task ID is present, generating one if necessary."""
130138
if not self._params:

src/a2a/server/agent_execution/request_context_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from abc import ABC, abstractmethod
22

33
from a2a.server.agent_execution import RequestContext
4+
from a2a.server.context import ServerCallContext
45
from a2a.types import MessageSendParams, Task
56

67

@@ -14,5 +15,6 @@ async def build(
1415
task_id: str | None = None,
1516
context_id: str | None = None,
1617
task: Task | None = None,
18+
context: ServerCallContext | None = None,
1719
) -> RequestContext:
1820
pass

src/a2a/server/agent_execution/simple_request_context_builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22

33
from a2a.server.agent_execution import RequestContext, RequestContextBuilder
4+
from a2a.server.context import ServerCallContext
45
from a2a.server.tasks import TaskStore
56
from a2a.types import MessageSendParams, Task
67

@@ -22,6 +23,7 @@ async def build(
2223
task_id: str | None = None,
2324
context_id: str | None = None,
2425
task: Task | None = None,
26+
context: ServerCallContext | None = None,
2527
) -> RequestContext:
2628
related_tasks: list[Task] | None = None
2729

@@ -45,4 +47,5 @@ async def build(
4547
context_id=context_id,
4648
task=task,
4749
related_tasks=related_tasks,
50+
call_context=context,
4851
)

src/a2a/server/apps/starlette_app.py

Lines changed: 64 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
import traceback
44

5+
from abc import ABC, abstractmethod
56
from collections.abc import AsyncGenerator
67
from typing import Any
78

@@ -12,9 +13,9 @@
1213
from starlette.responses import JSONResponse, Response
1314
from starlette.routing import Route
1415

15-
from a2a.server.request_handlers.request_handler import RequestHandler
16+
from a2a.server.context import ServerCallContext
1617
from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler
17-
18+
from a2a.server.request_handlers.request_handler import RequestHandler
1819
from a2a.types import (
1920
A2AError,
2021
A2ARequest,
@@ -41,6 +42,29 @@
4142
logger = logging.getLogger(__name__)
4243

4344

45+
class CallContextBuilder(ABC):
46+
"""A class for building ServerCallContexts using the Starlette Request."""
47+
48+
@abstractmethod
49+
def build(self, request: Request) -> ServerCallContext:
50+
"""Builds a ServerCallContext from a Starlette Request."""
51+
52+
53+
class DefaultCallContextBuilder(CallContextBuilder):
54+
"""A default implementation of StarletteCallContextBuilder.
55+
56+
This stores the incoming starlette request using the "starlette_request"
57+
key of the context state.
58+
"""
59+
60+
def build(self, request: Request) -> ServerCallContext:
61+
return ServerCallContext(
62+
state={
63+
'starlette_request': request,
64+
}
65+
)
66+
67+
4468
class A2AStarletteApplication:
4569
"""A Starlette application implementing the A2A protocol server endpoints.
4670
@@ -49,18 +73,25 @@ class A2AStarletteApplication:
4973
(SSE).
5074
"""
5175

52-
def __init__(self, agent_card: AgentCard, http_handler: RequestHandler):
76+
def __init__(
77+
self,
78+
agent_card: AgentCard,
79+
http_handler: RequestHandler,
80+
context_builder: CallContextBuilder | None = None,
81+
):
5382
"""Initializes the A2AStarletteApplication.
5483
5584
Args:
5685
agent_card: The AgentCard describing the agent's capabilities.
5786
http_handler: The handler instance responsible for processing A2A
5887
requests via http.
88+
context_builder:
5989
"""
6090
self.agent_card = agent_card
6191
self.handler = JSONRPCHandler(
6292
agent_card=agent_card, request_handler=http_handler
6393
)
94+
self._context_builder = context_builder or DefaultCallContextBuilder()
6495

6596
def _generate_error_response(
6697
self, request_id: str | int | None, error: JSONRPCError | A2AError
@@ -122,6 +153,7 @@ async def _handle_requests(self, request: Request) -> Response:
122153
try:
123154
body = await request.json()
124155
a2a_request = A2ARequest.model_validate(body)
156+
call_context = self._context_builder.build(request)
125157

126158
request_id = a2a_request.root.id
127159
request_obj = a2a_request.root
@@ -131,11 +163,11 @@ async def _handle_requests(self, request: Request) -> Response:
131163
TaskResubscriptionRequest | SendStreamingMessageRequest,
132164
):
133165
return await self._process_streaming_request(
134-
request_id, a2a_request
166+
request_id, a2a_request, call_context
135167
)
136168

137169
return await self._process_non_streaming_request(
138-
request_id, a2a_request
170+
request_id, a2a_request, call_context
139171
)
140172
except MethodNotImplementedError:
141173
traceback.print_exc()
@@ -161,7 +193,10 @@ async def _handle_requests(self, request: Request) -> Response:
161193
)
162194

163195
async def _process_streaming_request(
164-
self, request_id: str | int | None, a2a_request: A2ARequest
196+
self,
197+
request_id: str | int | None,
198+
a2a_request: A2ARequest,
199+
context: ServerCallContext,
165200
) -> Response:
166201
"""Processes streaming requests (message/stream or tasks/resubscribe).
167202
@@ -178,14 +213,21 @@ async def _process_streaming_request(
178213
request_obj,
179214
SendStreamingMessageRequest,
180215
):
181-
handler_result = self.handler.on_message_send_stream(request_obj)
216+
handler_result = self.handler.on_message_send_stream(
217+
request_obj, context
218+
)
182219
elif isinstance(request_obj, TaskResubscriptionRequest):
183-
handler_result = self.handler.on_resubscribe_to_task(request_obj)
220+
handler_result = self.handler.on_resubscribe_to_task(
221+
request_obj, context
222+
)
184223

185224
return self._create_response(handler_result)
186225

187226
async def _process_non_streaming_request(
188-
self, request_id: str | int | None, a2a_request: A2ARequest
227+
self,
228+
request_id: str | int | None,
229+
a2a_request: A2ARequest,
230+
context: ServerCallContext,
189231
) -> Response:
190232
"""Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*).
191233
@@ -200,18 +242,26 @@ async def _process_non_streaming_request(
200242
handler_result: Any = None
201243
match request_obj:
202244
case SendMessageRequest():
203-
handler_result = await self.handler.on_message_send(request_obj)
245+
handler_result = await self.handler.on_message_send(
246+
request_obj, context
247+
)
204248
case CancelTaskRequest():
205-
handler_result = await self.handler.on_cancel_task(request_obj)
249+
handler_result = await self.handler.on_cancel_task(
250+
request_obj, context
251+
)
206252
case GetTaskRequest():
207-
handler_result = await self.handler.on_get_task(request_obj)
253+
handler_result = await self.handler.on_get_task(
254+
request_obj, context
255+
)
208256
case SetTaskPushNotificationConfigRequest():
209257
handler_result = await self.handler.set_push_notification(
210-
request_obj
258+
request_obj,
259+
context,
211260
)
212261
case GetTaskPushNotificationConfigRequest():
213262
handler_result = await self.handler.get_push_notification(
214-
request_obj
263+
request_obj,
264+
context,
215265
)
216266
case _:
217267
logger.error(

src/a2a/server/context.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""Defines the ServerCallContext class."""
2+
3+
import collections.abc
4+
import typing
5+
6+
7+
State = collections.abc.MutableMapping[str, typing.Any]
8+
9+
10+
class ServerCallContext:
11+
"""A context passed when calling a server method.
12+
13+
This class allows storing arbitrary user data in the state attribute.
14+
"""
15+
16+
def __init__(self, state: State | None = None):
17+
if state is None:
18+
state = {}
19+
self._state = state
20+
21+
@property
22+
def state(self) -> State:
23+
"""Get the user-provided state."""
24+
return self._state

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
RequestContextBuilder,
1111
SimpleRequestContextBuilder,
1212
)
13+
from a2a.server.context import ServerCallContext
1314
from a2a.server.events import (
1415
Event,
1516
EventConsumer,
@@ -70,6 +71,8 @@ def __init__(
7071
task_store: The `TaskStore` instance to manage task persistence.
7172
queue_manager: The `QueueManager` instance to manage event queues. Defaults to `InMemoryQueueManager`.
7273
push_notifier: The `PushNotifier` instance for sending push notifications. Defaults to None.
74+
request_context_builder: The `RequestContextBuilder` instance used
75+
to build request contexts. Defaults to `SimpleRequestContextBuilder`.
7376
"""
7477
self.agent_executor = agent_executor
7578
self.task_store = task_store
@@ -85,14 +88,20 @@ def __init__(
8588
self._running_agents = {}
8689
self._running_agents_lock = asyncio.Lock()
8790

88-
async def on_get_task(self, params: TaskQueryParams) -> Task | None:
91+
async def on_get_task(
92+
self,
93+
params: TaskQueryParams,
94+
context: ServerCallContext | None = None,
95+
) -> Task | None:
8996
"""Default handler for 'tasks/get'."""
9097
task: Task | None = await self.task_store.get(params.id)
9198
if not task:
9299
raise ServerError(error=TaskNotFoundError())
93100
return task
94101

95-
async def on_cancel_task(self, params: TaskIdParams) -> Task | None:
102+
async def on_cancel_task(
103+
self, params: TaskIdParams, context: ServerCallContext | None = None
104+
) -> Task | None:
96105
"""Default handler for 'tasks/cancel'.
97106
98107
Attempts to cancel the task managed by the `AgentExecutor`.
@@ -150,7 +159,9 @@ async def _run_event_stream(
150159
await queue.close()
151160

152161
async def on_message_send(
153-
self, params: MessageSendParams
162+
self,
163+
params: MessageSendParams,
164+
context: ServerCallContext | None = None,
154165
) -> Message | Task:
155166
"""Default handler for 'message/send' interface (non-streaming).
156167
@@ -183,6 +194,7 @@ async def on_message_send(
183194
task_id=task.id if task else None,
184195
context_id=params.message.contextId,
185196
task=task,
197+
context=context,
186198
)
187199

188200
task_id = cast(str, request_context.task_id)
@@ -232,7 +244,9 @@ async def on_message_send(
232244
return result
233245

234246
async def on_message_send_stream(
235-
self, params: MessageSendParams
247+
self,
248+
params: MessageSendParams,
249+
context: ServerCallContext | None = None,
236250
) -> AsyncGenerator[Event]:
237251
"""Default handler for 'message/stream' (streaming).
238252
@@ -270,6 +284,7 @@ async def on_message_send_stream(
270284
task_id=task.id if task else None,
271285
context_id=params.message.contextId,
272286
task=task,
287+
context=context,
273288
)
274289

275290
task_id = cast(str, request_context.task_id)
@@ -334,7 +349,9 @@ async def _cleanup_producer(
334349
self._running_agents.pop(task_id, None)
335350

336351
async def on_set_task_push_notification_config(
337-
self, params: TaskPushNotificationConfig
352+
self,
353+
params: TaskPushNotificationConfig,
354+
context: ServerCallContext | None = None,
338355
) -> TaskPushNotificationConfig:
339356
"""Default handler for 'tasks/pushNotificationConfig/set'.
340357
@@ -355,7 +372,9 @@ async def on_set_task_push_notification_config(
355372
return params
356373

357374
async def on_get_task_push_notification_config(
358-
self, params: TaskIdParams
375+
self,
376+
params: TaskIdParams,
377+
context: ServerCallContext | None = None,
359378
) -> TaskPushNotificationConfig:
360379
"""Default handler for 'tasks/pushNotificationConfig/get'.
361380
@@ -377,7 +396,9 @@ async def on_get_task_push_notification_config(
377396
)
378397

379398
async def on_resubscribe_to_task(
380-
self, params: TaskIdParams
399+
self,
400+
params: TaskIdParams,
401+
context: ServerCallContext | None = None,
381402
) -> AsyncGenerator[Event]:
382403
"""Default handler for 'tasks/resubscribe'.
383404

0 commit comments

Comments
 (0)