11import logging
22
33from collections .abc import AsyncGenerator
4- from typing import Any
54
65
76try :
1312 "'pip install a2a-sdk[grpc]'"
1413 ) from e
1514
16- from google .protobuf import struct_pb2
1715
1816from a2a .client .client import ClientConfig
1917from a2a .client .middleware import ClientCallContext , ClientCallInterceptor
2018from a2a .client .optionals import Channel
2119from a2a .client .transports .base import ClientTransport
22- from a2a .extensions . common import HTTP_EXTENSION_HEADER
20+ from a2a .client . transports . utils import update_extension_metadata
2321from a2a .grpc import a2a_pb2 , a2a_pb2_grpc
2422from a2a .types import (
2523 AgentCard ,
@@ -87,9 +85,10 @@ async def send_message(
8785 configuration = proto_utils .ToProto .message_send_configuration (
8886 request .configuration
8987 ),
90- metadata = proto_utils .ToProto .metadata (request .metadata ),
88+ metadata = update_extension_metadata (
89+ request .metadata , self .extensions
90+ ),
9191 ),
92- metadata = self ._update_extension_metadata (),
9392 )
9493 if response .HasField ('task' ):
9594 return proto_utils .FromProto .task (response .task )
@@ -110,17 +109,17 @@ async def send_message_streaming(
110109 configuration = proto_utils .ToProto .message_send_configuration (
111110 request .configuration
112111 ),
113- metadata = proto_utils .ToProto .metadata (request .metadata ),
112+ metadata = update_extension_metadata (
113+ request .metadata , self .extensions
114+ ),
114115 ),
115- metadata = self ._update_extension_metadata (request .metadata ),
116116 )
117117 while True :
118118 response = await stream .read ()
119119 if response == grpc .aio .EOF : # pyright: ignore[reportAttributeAccessIssue]
120120 break
121121 yield proto_utils .FromProto .stream_response (response )
122122
123- # iva todo TaskIdParams has metadata
124123 async def resubscribe (
125124 self , request : TaskIdParams , * , context : ClientCallContext | None = None
126125 ) -> AsyncGenerator [
@@ -129,15 +128,16 @@ async def resubscribe(
129128 """Reconnects to get task updates."""
130129 stream = self .stub .TaskSubscription (
131130 a2a_pb2 .TaskSubscriptionRequest (name = f'tasks/{ request .id } ' ),
132- metadata = self ._update_extension_metadata (),
131+ metadata = update_extension_metadata (
132+ request .metadata , self .extensions
133+ ),
133134 )
134135 while True :
135136 response = await stream .read ()
136137 if response == grpc .aio .EOF : # pyright: ignore[reportAttributeAccessIssue]
137138 break
138139 yield proto_utils .FromProto .stream_response (response )
139140
140- # iva todo TaskQueryParams has metadata
141141 async def get_task (
142142 self ,
143143 request : TaskQueryParams ,
@@ -150,7 +150,9 @@ async def get_task(
150150 name = f'tasks/{ request .id } ' ,
151151 history_length = request .history_length ,
152152 ),
153- metadata = self ._update_extension_metadata (),
153+ metadata = update_extension_metadata (
154+ request .metadata , self .extensions
155+ ),
154156 )
155157 return proto_utils .FromProto .task (task )
156158
@@ -163,7 +165,6 @@ async def cancel_task(
163165 """Requests the agent to cancel a specific task."""
164166 task = await self .stub .CancelTask (
165167 a2a_pb2 .CancelTaskRequest (name = f'tasks/{ request .id } ' ),
166- metadata = self ._update_extension_metadata (),
167168 )
168169 return proto_utils .FromProto .task (task )
169170
@@ -182,11 +183,12 @@ async def set_task_callback(
182183 request
183184 ),
184185 ),
185- metadata = self ._update_extension_metadata (),
186+ metadata = update_extension_metadata (
187+ request .metadata , self .extensions
188+ ),
186189 )
187190 return proto_utils .FromProto .task_push_notification_config (config )
188191
189- # iva todo GetTaskPushNotificationConfigParams has metadata
190192 async def get_task_callback (
191193 self ,
192194 request : GetTaskPushNotificationConfigParams ,
@@ -198,7 +200,9 @@ async def get_task_callback(
198200 a2a_pb2 .GetTaskPushNotificationConfigRequest (
199201 name = f'tasks/{ request .id } /pushNotificationConfigs/{ request .push_notification_config_id } ' ,
200202 ),
201- metadata = self ._update_extension_metadata (),
203+ metadata = update_extension_metadata (
204+ request .metadata , self .extensions
205+ ),
202206 )
203207 return proto_utils .FromProto .task_push_notification_config (config )
204208
@@ -222,33 +226,6 @@ async def get_card(
222226 self ._needs_extended_card = False
223227 return card
224228
225- def _update_extension_metadata (
226- self , metadata : dict [str , Any ] | None = None
227- ) -> struct_pb2 .Struct | None :
228- """Gets the metadata for the gRPC call."""
229- if metadata is None :
230- metadata = {}
231-
232- if self .extensions :
233- existing_extensions_str = str (
234- metadata .get (HTTP_EXTENSION_HEADER , '' )
235- )
236- existing_extensions = {
237- e .strip ()
238- for e in existing_extensions_str .split (',' )
239- if e .strip ()
240- }
241-
242- all_extensions = set (existing_extensions )
243- all_extensions .update (self .extensions )
244-
245- if all_extensions :
246- metadata [HTTP_EXTENSION_HEADER ] = ',' .join (all_extensions )
247- elif HTTP_EXTENSION_HEADER in metadata :
248- del metadata [HTTP_EXTENSION_HEADER ]
249-
250- return proto_utils .ToProto .metadata (metadata if metadata else None )
251-
252229 async def close (self ) -> None :
253230 """Closes the gRPC channel."""
254231 await self .channel .close ()
0 commit comments