Skip to content

Commit 4d5b92c

Browse files
mikeas1gemini-code-assist[bot]holtskinner
authored
feat: Add server-side support for plumbing requested and activated extensions (#333)
# Description This adds PR adds support for reading requested extensions from a client, via the `X-A2A-Extensions` header, and plumbing that through the ServerCallContext. Here's my rough thinking on the design choices: 1. Requests to activate extensions are specified in a transport-specific manner. For HTTP, it's using headers. For gRPC, it's using side-channel metadata (which are also HTTP headers). 2. However, indicating that an extension was requested is a transport-independent concept. So, it should be passed to the transport-independent layers of the SDK. 3. We already have a means of extracting relevant information from the transport layer (and others) and passing down to the independent layers: the `ServerCallContext`. 4. So, we can put the list of requested extensions here. The transport layers are responsible for pulling out the list of requested extensions in whatever transport specific means they are communicated, and passing these down via the ServerCallContext. 5. Returning the list of extensions that were activated is slightly more annoying. I didn't want to have to refactor every layer to plumb both the request-specific response type (such as SendMessageResponse) AND the list of extensions activated. That's just too intrusive. 6. So, I need some way to communicate the list of activated extension from the independent layers back "up" to the transport layers via something other than the return value. I'm not thrilled with the decision here, but the answer is to use interior mutability: I put a mutable set in the ServerCallContext, and allow lower levels to modify it to add the activated extensions. 7. The transport layers then need to read the list of activated extensions from the ServerCallContext and pass those back in the response in the transport-defined method. The result is that we have full stack plumbing of requested and activated extensions by way of the ServerCallContext. --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Holt Skinner <[email protected]> Co-authored-by: Holt Skinner <[email protected]>
1 parent 2a7f7c1 commit 4d5b92c

File tree

9 files changed

+495
-13
lines changed

9 files changed

+495
-13
lines changed

src/a2a/extensions/common.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from a2a.types import AgentCard, AgentExtension
2+
3+
4+
HTTP_EXTENSION_HEADER = 'X-A2A-Extensions'
5+
6+
7+
def get_requested_extensions(values: list[str]) -> set[str]:
8+
"""Get the set of requested extensions from an input list.
9+
10+
This handles the list containing potentially comma-separated values, as
11+
occurs when using a list in an HTTP header.
12+
"""
13+
return {
14+
stripped
15+
for v in values
16+
for ext in v.split(',')
17+
if (stripped := ext.strip())
18+
}
19+
20+
21+
def find_extension_by_uri(card: AgentCard, uri: str) -> AgentExtension | None:
22+
"""Find an AgentExtension in an AgentCard given a uri."""
23+
for ext in card.capabilities.extensions or []:
24+
if ext.uri == uri:
25+
return ext
26+
27+
return None

src/a2a/server/agent_execution/context.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,24 @@ def metadata(self) -> dict[str, Any]:
143143
return {}
144144
return self._params.metadata or {}
145145

146+
def add_activated_extension(self, uri: str) -> None:
147+
"""Add an extension to the set of activated extensions for this request.
148+
149+
This causes the extension to be indicated back to the client in the
150+
response.
151+
"""
152+
if self._call_context:
153+
self._call_context.activated_extensions.add(uri)
154+
155+
@property
156+
def requested_extensions(self) -> set[str]:
157+
"""Extensions that the client requested to activate."""
158+
return (
159+
self._call_context.requested_extensions
160+
if self._call_context
161+
else set()
162+
)
163+
146164
def _check_or_generate_task_id(self) -> None:
147165
"""Ensures a task ID is present, generating one if necessary."""
148166
if not self._params:

src/a2a/server/apps/jsonrpc/jsonrpc_app.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919

2020
from a2a.auth.user import UnauthenticatedUser
2121
from a2a.auth.user import User as A2AUser
22+
from a2a.extensions.common import (
23+
HTTP_EXTENSION_HEADER,
24+
get_requested_extensions,
25+
)
2226
from a2a.server.context import ServerCallContext
2327
from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler
2428
from a2a.server.request_handlers.request_handler import RequestHandler
@@ -99,7 +103,13 @@ def build(self, request: Request) -> ServerCallContext:
99103
user = StarletteUserProxy(request.user)
100104
state['auth'] = request.auth
101105
state['headers'] = dict(request.headers)
102-
return ServerCallContext(user=user, state=state)
106+
return ServerCallContext(
107+
user=user,
108+
state=state,
109+
requested_extensions=get_requested_extensions(
110+
request.headers.getlist(HTTP_EXTENSION_HEADER)
111+
),
112+
)
103113

104114

105115
class JSONRPCApplication(ABC):
@@ -281,7 +291,7 @@ async def _process_streaming_request(
281291
request_obj, context
282292
)
283293

284-
return self._create_response(handler_result)
294+
return self._create_response(context, handler_result)
285295

286296
async def _process_non_streaming_request(
287297
self,
@@ -353,10 +363,11 @@ async def _process_non_streaming_request(
353363
id=request_id, error=error
354364
)
355365

356-
return self._create_response(handler_result)
366+
return self._create_response(context, handler_result)
357367

358368
def _create_response(
359369
self,
370+
context: ServerCallContext,
360371
handler_result: (
361372
AsyncGenerator[SendStreamingMessageResponse]
362373
| JSONRPCErrorResponse
@@ -372,12 +383,16 @@ def _create_response(
372383
payloads.
373384
374385
Args:
386+
context: The ServerCallContext provided to the request handler.
375387
handler_result: The result from a request handler method. Can be an
376388
async generator for streaming or a Pydantic model for non-streaming.
377389
378390
Returns:
379391
A Starlette JSONResponse or EventSourceResponse.
380392
"""
393+
headers = {}
394+
if exts := context.activated_extensions:
395+
headers[HTTP_EXTENSION_HEADER] = ', '.join(sorted(exts))
381396
if isinstance(handler_result, AsyncGenerator):
382397
# Result is a stream of SendStreamingMessageResponse objects
383398
async def event_generator(
@@ -386,17 +401,21 @@ async def event_generator(
386401
async for item in stream:
387402
yield {'data': item.root.model_dump_json(exclude_none=True)}
388403

389-
return EventSourceResponse(event_generator(handler_result))
404+
return EventSourceResponse(
405+
event_generator(handler_result), headers=headers
406+
)
390407
if isinstance(handler_result, JSONRPCErrorResponse):
391408
return JSONResponse(
392409
handler_result.model_dump(
393410
mode='json',
394411
exclude_none=True,
395-
)
412+
),
413+
headers=headers,
396414
)
397415

398416
return JSONResponse(
399-
handler_result.root.model_dump(mode='json', exclude_none=True)
417+
handler_result.root.model_dump(mode='json', exclude_none=True),
418+
headers=headers,
400419
)
401420

402421
async def _handle_get_agent_card(self, request: Request) -> JSONResponse:

src/a2a/server/context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,5 @@ class ServerCallContext(BaseModel):
2121

2222
state: State = Field(default={})
2323
user: User = Field(default=UnauthenticatedUser())
24+
requested_extensions: set[str] = Field(default_factory=set)
25+
activated_extensions: set[str] = Field(default_factory=set)

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
import logging
44

55
from abc import ABC, abstractmethod
6-
from collections.abc import AsyncIterable
6+
from collections.abc import AsyncIterable, Sequence
77

88

99
try:
1010
import grpc
1111
import grpc.aio
12+
13+
from grpc.aio import Metadata
1214
except ImportError as e:
1315
raise ImportError(
1416
'GrpcHandler requires grpcio and grpcio-tools to be installed. '
@@ -20,6 +22,10 @@
2022

2123
from a2a import types
2224
from a2a.auth.user import UnauthenticatedUser
25+
from a2a.extensions.common import (
26+
HTTP_EXTENSION_HEADER,
27+
get_requested_extensions,
28+
)
2329
from a2a.grpc import a2a_pb2
2430
from a2a.server.context import ServerCallContext
2531
from a2a.server.request_handlers.request_handler import RequestHandler
@@ -42,6 +48,19 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
4248
"""Builds a ServerCallContext from a gRPC Request."""
4349

4450

51+
def _get_metadata_value(
52+
context: grpc.aio.ServicerContext, key: str
53+
) -> list[str]:
54+
md = context.invocation_metadata
55+
raw_values: list[str | bytes] = []
56+
if isinstance(md, Metadata):
57+
raw_values = md.get_all(key)
58+
elif isinstance(md, Sequence):
59+
lower_key = key.lower()
60+
raw_values = [e for (k, e) in md if k.lower() == lower_key]
61+
return [e if isinstance(e, str) else e.decode('utf-8') for e in raw_values]
62+
63+
4564
class DefaultCallContextBuilder(CallContextBuilder):
4665
"""A default implementation of CallContextBuilder."""
4766

@@ -51,7 +70,13 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
5170
state = {}
5271
with contextlib.suppress(Exception):
5372
state['grpc_context'] = context
54-
return ServerCallContext(user=user, state=state)
73+
return ServerCallContext(
74+
user=user,
75+
state=state,
76+
requested_extensions=get_requested_extensions(
77+
_get_metadata_value(context, HTTP_EXTENSION_HEADER)
78+
),
79+
)
5580

5681

5782
class GrpcHandler(a2a_grpc.A2AServiceServicer):
@@ -102,6 +127,7 @@ async def SendMessage(
102127
task_or_message = await self.request_handler.on_message_send(
103128
a2a_request, server_context
104129
)
130+
self._set_extension_metadata(context, server_context)
105131
return proto_utils.ToProto.task_or_message(task_or_message)
106132
except ServerError as e:
107133
await self.abort_context(e, context)
@@ -140,6 +166,7 @@ async def SendStreamingMessage(
140166
a2a_request, server_context
141167
):
142168
yield proto_utils.ToProto.stream_response(event)
169+
self._set_extension_metadata(context, server_context)
143170
except ServerError as e:
144171
await self.abort_context(e, context)
145172
return
@@ -371,3 +398,16 @@ async def abort_context(
371398
grpc.StatusCode.UNKNOWN,
372399
f'Unknown error type: {error.error}',
373400
)
401+
402+
def _set_extension_metadata(
403+
self,
404+
context: grpc.aio.ServicerContext,
405+
server_context: ServerCallContext,
406+
) -> None:
407+
if server_context.activated_extensions:
408+
context.set_trailing_metadata(
409+
[
410+
(HTTP_EXTENSION_HEADER, e)
411+
for e in sorted(server_context.activated_extensions)
412+
]
413+
)

tests/extensions/test_common.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from a2a.extensions.common import (
2+
find_extension_by_uri,
3+
get_requested_extensions,
4+
)
5+
from a2a.types import AgentCapabilities, AgentCard, AgentExtension
6+
7+
8+
def test_get_requested_extensions():
9+
assert get_requested_extensions([]) == set()
10+
assert get_requested_extensions(['foo']) == {'foo'}
11+
assert get_requested_extensions(['foo', 'bar']) == {'foo', 'bar'}
12+
assert get_requested_extensions(['foo, bar']) == {'foo', 'bar'}
13+
assert get_requested_extensions(['foo,bar']) == {'foo', 'bar'}
14+
assert get_requested_extensions(['foo', 'bar,baz']) == {'foo', 'bar', 'baz'}
15+
assert get_requested_extensions(['foo,, bar', 'baz']) == {
16+
'foo',
17+
'bar',
18+
'baz',
19+
}
20+
assert get_requested_extensions([' foo , bar ', 'baz']) == {
21+
'foo',
22+
'bar',
23+
'baz',
24+
}
25+
26+
27+
def test_find_extension_by_uri():
28+
ext1 = AgentExtension(uri='foo', description='The Foo extension')
29+
ext2 = AgentExtension(uri='bar', description='The Bar extension')
30+
card = AgentCard(
31+
name='Test Agent',
32+
description='Test Agent Description',
33+
version='1.0',
34+
url='http://test.com',
35+
skills=[],
36+
default_input_modes=['text/plain'],
37+
default_output_modes=['text/plain'],
38+
capabilities=AgentCapabilities(extensions=[ext1, ext2]),
39+
)
40+
41+
assert find_extension_by_uri(card, 'foo') == ext1
42+
assert find_extension_by_uri(card, 'bar') == ext2
43+
assert find_extension_by_uri(card, 'baz') is None
44+
45+
46+
def test_find_extension_by_uri_no_extensions():
47+
card = AgentCard(
48+
name='Test Agent',
49+
description='Test Agent Description',
50+
version='1.0',
51+
url='http://test.com',
52+
skills=[],
53+
default_input_modes=['text/plain'],
54+
default_output_modes=['text/plain'],
55+
capabilities=AgentCapabilities(extensions=None),
56+
)
57+
58+
assert find_extension_by_uri(card, 'foo') is None

tests/server/agent_execution/test_context.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66

77
from a2a.server.agent_execution import RequestContext
8+
from a2a.server.context import ServerCallContext
89
from a2a.types import (
910
Message,
1011
MessageSendParams,
@@ -263,3 +264,16 @@ def test_init_with_context_id_and_existing_context_id_match(
263264

264265
assert context.context_id == mock_task.context_id
265266
assert context.current_task == mock_task
267+
268+
def test_extension_handling(self):
269+
"""Test extension handling in RequestContext."""
270+
call_context = ServerCallContext(requested_extensions={'foo', 'bar'})
271+
context = RequestContext(call_context=call_context)
272+
273+
assert context.requested_extensions == {'foo', 'bar'}
274+
275+
context.add_activated_extension('foo')
276+
assert call_context.activated_extensions == {'foo'}
277+
278+
context.add_activated_extension('baz')
279+
assert call_context.activated_extensions == {'foo', 'baz'}

0 commit comments

Comments
 (0)