33import logging
44
55from abc import ABC , abstractmethod
6- from collections .abc import AsyncIterable
6+ from collections .abc import AsyncIterable , Sequence
77
88
99try :
1010 import grpc
1111 import grpc .aio
12+
13+ from grpc .aio import Metadata
1214except ImportError as e :
1315 raise ImportError (
1416 'GrpcHandler requires grpcio and grpcio-tools to be installed. '
2022
2123from a2a import types
2224from a2a .auth .user import UnauthenticatedUser
25+ from a2a .extensions .common import (
26+ HTTP_EXTENSION_HEADER ,
27+ get_requested_extensions ,
28+ )
2329from a2a .grpc import a2a_pb2
2430from a2a .server .context import ServerCallContext
2531from a2a .server .request_handlers .request_handler import RequestHandler
@@ -42,6 +48,19 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
4248 """Builds a ServerCallContext from a gRPC Request."""
4349
4450
51+ def _get_metadata_value (
52+ context : grpc .aio .ServicerContext , key : str
53+ ) -> list [str ]:
54+ md = context .invocation_metadata
55+ raw_values : list [str | bytes ] = []
56+ if isinstance (md , Metadata ):
57+ raw_values = md .get_all (key )
58+ elif isinstance (md , Sequence ):
59+ lower_key = key .lower ()
60+ raw_values = [e for (k , e ) in md if k .lower () == lower_key ]
61+ return [e if isinstance (e , str ) else e .decode ('utf-8' ) for e in raw_values ]
62+
63+
4564class DefaultCallContextBuilder (CallContextBuilder ):
4665 """A default implementation of CallContextBuilder."""
4766
@@ -51,7 +70,13 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
5170 state = {}
5271 with contextlib .suppress (Exception ):
5372 state ['grpc_context' ] = context
54- return ServerCallContext (user = user , state = state )
73+ return ServerCallContext (
74+ user = user ,
75+ state = state ,
76+ requested_extensions = get_requested_extensions (
77+ _get_metadata_value (context , HTTP_EXTENSION_HEADER )
78+ ),
79+ )
5580
5681
5782class GrpcHandler (a2a_grpc .A2AServiceServicer ):
@@ -102,6 +127,7 @@ async def SendMessage(
102127 task_or_message = await self .request_handler .on_message_send (
103128 a2a_request , server_context
104129 )
130+ self ._set_extension_metadata (context , server_context )
105131 return proto_utils .ToProto .task_or_message (task_or_message )
106132 except ServerError as e :
107133 await self .abort_context (e , context )
@@ -140,6 +166,7 @@ async def SendStreamingMessage(
140166 a2a_request , server_context
141167 ):
142168 yield proto_utils .ToProto .stream_response (event )
169+ self ._set_extension_metadata (context , server_context )
143170 except ServerError as e :
144171 await self .abort_context (e , context )
145172 return
@@ -371,3 +398,16 @@ async def abort_context(
371398 grpc .StatusCode .UNKNOWN ,
372399 f'Unknown error type: { error .error } ' ,
373400 )
401+
402+ def _set_extension_metadata (
403+ self ,
404+ context : grpc .aio .ServicerContext ,
405+ server_context : ServerCallContext ,
406+ ) -> None :
407+ if server_context .activated_extensions :
408+ context .set_trailing_metadata (
409+ [
410+ (HTTP_EXTENSION_HEADER , e )
411+ for e in sorted (server_context .activated_extensions )
412+ ]
413+ )
0 commit comments