1313 "'pip install a2a-sdk[grpc]'"
1414 ) from e
1515
16+ from google .protobuf import struct_pb2
17+
1618from a2a .client .client import ClientConfig
1719from a2a .client .middleware import ClientCallContext , ClientCallInterceptor
1820from a2a .client .optionals import Channel
1921from a2a .client .transports .base import ClientTransport
20- from a2a .client . transports . utils import update_extension_header
22+ from a2a .extensions . common import HTTP_EXTENSION_HEADER
2123from a2a .grpc import a2a_pb2 , a2a_pb2_grpc
2224from a2a .types import (
2325 AgentCard ,
@@ -59,24 +61,6 @@ def __init__(
5961 )
6062 self .extensions = extensions
6163
62- def _get_metadata (
63- self , context : ClientCallContext | None
64- ) -> list [tuple [str , str ]]:
65- http_kwargs : dict [str , Any ] = {}
66- if context and context .state .get ('grpc_metadata' ):
67- # Convert existing metadata to headers format for update_extension_header
68- http_kwargs ['headers' ] = {
69- k : v for k , v in context .state ['grpc_metadata' ]
70- }
71-
72- updated_kwargs = update_extension_header (http_kwargs , self .extensions )
73-
74- metadata = []
75- if 'headers' in updated_kwargs :
76- metadata .extend (updated_kwargs ['headers' ].items ())
77-
78- return metadata
79-
8064 @classmethod
8165 def create (
8266 cls ,
@@ -105,7 +89,7 @@ async def send_message(
10589 ),
10690 metadata = proto_utils .ToProto .metadata (request .metadata ),
10791 ),
108- metadata = self ._get_metadata ( context ),
92+ metadata = self ._update_extension_metadata ( ),
10993 )
11094 if response .HasField ('task' ):
11195 return proto_utils .FromProto .task (response .task )
@@ -128,14 +112,15 @@ async def send_message_streaming(
128112 ),
129113 metadata = proto_utils .ToProto .metadata (request .metadata ),
130114 ),
131- metadata = self ._get_metadata ( context ),
115+ metadata = self ._update_extension_metadata ( request . metadata ),
132116 )
133117 while True :
134118 response = await stream .read ()
135119 if response == grpc .aio .EOF : # pyright: ignore[reportAttributeAccessIssue]
136120 break
137121 yield proto_utils .FromProto .stream_response (response )
138122
123+ # iva todo TaskIdParams has metadata
139124 async def resubscribe (
140125 self , request : TaskIdParams , * , context : ClientCallContext | None = None
141126 ) -> AsyncGenerator [
@@ -144,14 +129,15 @@ async def resubscribe(
144129 """Reconnects to get task updates."""
145130 stream = self .stub .TaskSubscription (
146131 a2a_pb2 .TaskSubscriptionRequest (name = f'tasks/{ request .id } ' ),
147- metadata = self ._get_metadata ( context ),
132+ metadata = self ._update_extension_metadata ( ),
148133 )
149134 while True :
150135 response = await stream .read ()
151136 if response == grpc .aio .EOF : # pyright: ignore[reportAttributeAccessIssue]
152137 break
153138 yield proto_utils .FromProto .stream_response (response )
154139
140+ # iva todo TaskQueryParams has metadata
155141 async def get_task (
156142 self ,
157143 request : TaskQueryParams ,
@@ -164,7 +150,7 @@ async def get_task(
164150 name = f'tasks/{ request .id } ' ,
165151 history_length = request .history_length ,
166152 ),
167- metadata = self ._get_metadata ( context ),
153+ metadata = self ._update_extension_metadata ( ),
168154 )
169155 return proto_utils .FromProto .task (task )
170156
@@ -177,7 +163,7 @@ async def cancel_task(
177163 """Requests the agent to cancel a specific task."""
178164 task = await self .stub .CancelTask (
179165 a2a_pb2 .CancelTaskRequest (name = f'tasks/{ request .id } ' ),
180- metadata = self ._get_metadata ( context ),
166+ metadata = self ._update_extension_metadata ( ),
181167 )
182168 return proto_utils .FromProto .task (task )
183169
@@ -196,10 +182,11 @@ async def set_task_callback(
196182 request
197183 ),
198184 ),
199- metadata = self ._get_metadata ( context ),
185+ metadata = self ._update_extension_metadata ( ),
200186 )
201187 return proto_utils .FromProto .task_push_notification_config (config )
202188
189+ # iva todo GetTaskPushNotificationConfigParams has metadata
203190 async def get_task_callback (
204191 self ,
205192 request : GetTaskPushNotificationConfigParams ,
@@ -211,7 +198,7 @@ async def get_task_callback(
211198 a2a_pb2 .GetTaskPushNotificationConfigRequest (
212199 name = f'tasks/{ request .id } /pushNotificationConfigs/{ request .push_notification_config_id } ' ,
213200 ),
214- metadata = self ._get_metadata ( context ),
201+ metadata = self ._update_extension_metadata ( ),
215202 )
216203 return proto_utils .FromProto .task_push_notification_config (config )
217204
@@ -235,6 +222,33 @@ async def get_card(
235222 self ._needs_extended_card = False
236223 return card
237224
225+ def _update_extension_metadata (
226+ self , metadata : dict [str , Any ] | None = None
227+ ) -> struct_pb2 .Struct | None :
228+ """Gets the metadata for the gRPC call."""
229+ if metadata is None :
230+ metadata = {}
231+
232+ if self .extensions :
233+ existing_extensions_str = str (
234+ metadata .get (HTTP_EXTENSION_HEADER , '' )
235+ )
236+ existing_extensions = {
237+ e .strip ()
238+ for e in existing_extensions_str .split (',' )
239+ if e .strip ()
240+ }
241+
242+ all_extensions = set (existing_extensions )
243+ all_extensions .update (self .extensions )
244+
245+ if all_extensions :
246+ metadata [HTTP_EXTENSION_HEADER ] = ',' .join (all_extensions )
247+ elif HTTP_EXTENSION_HEADER in metadata :
248+ del metadata [HTTP_EXTENSION_HEADER ]
249+
250+ return proto_utils .ToProto .metadata (metadata if metadata else None )
251+
238252 async def close (self ) -> None :
239253 """Closes the gRPC channel."""
240254 await self .channel .close ()
0 commit comments