Skip to content

Commit 4073c0b

Browse files
committed
refactor: streamline extension handling in BaseClient and GrpcTransport. Add helper __merge_extensions method in utils.py
1 parent 948d3f3 commit 4073c0b

File tree

4 files changed

+60
-338
lines changed

4 files changed

+60
-338
lines changed

src/a2a/client/base_client.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ def __init__(
4242
self._card = card
4343
self._config = config
4444
self._transport = transport
45-
if self._extensions:
46-
self._config.extensions = self._extensions
45+
self._config.extensions = self._extensions
4746

4847
async def send_message(
4948
self,

src/a2a/client/transports/grpc.py

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
"'pip install a2a-sdk[grpc]'"
1414
) from e
1515

16+
from google.protobuf import struct_pb2
17+
1618
from a2a.client.client import ClientConfig
1719
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
1820
from a2a.client.optionals import Channel
1921
from a2a.client.transports.base import ClientTransport
20-
from a2a.client.transports.utils import update_extension_header
22+
from a2a.extensions.common import HTTP_EXTENSION_HEADER
2123
from a2a.grpc import a2a_pb2, a2a_pb2_grpc
2224
from a2a.types import (
2325
AgentCard,
@@ -59,24 +61,6 @@ def __init__(
5961
)
6062
self.extensions = extensions
6163

62-
def _get_metadata(
63-
self, context: ClientCallContext | None
64-
) -> list[tuple[str, str]]:
65-
http_kwargs: dict[str, Any] = {}
66-
if context and context.state.get('grpc_metadata'):
67-
# Convert existing metadata to headers format for update_extension_header
68-
http_kwargs['headers'] = {
69-
k: v for k, v in context.state['grpc_metadata']
70-
}
71-
72-
updated_kwargs = update_extension_header(http_kwargs, self.extensions)
73-
74-
metadata = []
75-
if 'headers' in updated_kwargs:
76-
metadata.extend(updated_kwargs['headers'].items())
77-
78-
return metadata
79-
8064
@classmethod
8165
def create(
8266
cls,
@@ -105,7 +89,7 @@ async def send_message(
10589
),
10690
metadata=proto_utils.ToProto.metadata(request.metadata),
10791
),
108-
metadata=self._get_metadata(context),
92+
metadata=self._update_extension_metadata(),
10993
)
11094
if response.HasField('task'):
11195
return proto_utils.FromProto.task(response.task)
@@ -128,14 +112,15 @@ async def send_message_streaming(
128112
),
129113
metadata=proto_utils.ToProto.metadata(request.metadata),
130114
),
131-
metadata=self._get_metadata(context),
115+
metadata=self._update_extension_metadata(request.metadata),
132116
)
133117
while True:
134118
response = await stream.read()
135119
if response == grpc.aio.EOF: # pyright: ignore[reportAttributeAccessIssue]
136120
break
137121
yield proto_utils.FromProto.stream_response(response)
138122

123+
# iva todo TaskIdParams has metadata
139124
async def resubscribe(
140125
self, request: TaskIdParams, *, context: ClientCallContext | None = None
141126
) -> AsyncGenerator[
@@ -144,14 +129,15 @@ async def resubscribe(
144129
"""Reconnects to get task updates."""
145130
stream = self.stub.TaskSubscription(
146131
a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}'),
147-
metadata=self._get_metadata(context),
132+
metadata=self._update_extension_metadata(),
148133
)
149134
while True:
150135
response = await stream.read()
151136
if response == grpc.aio.EOF: # pyright: ignore[reportAttributeAccessIssue]
152137
break
153138
yield proto_utils.FromProto.stream_response(response)
154139

140+
# iva todo TaskQueryParams has metadata
155141
async def get_task(
156142
self,
157143
request: TaskQueryParams,
@@ -164,7 +150,7 @@ async def get_task(
164150
name=f'tasks/{request.id}',
165151
history_length=request.history_length,
166152
),
167-
metadata=self._get_metadata(context),
153+
metadata=self._update_extension_metadata(),
168154
)
169155
return proto_utils.FromProto.task(task)
170156

@@ -177,7 +163,7 @@ async def cancel_task(
177163
"""Requests the agent to cancel a specific task."""
178164
task = await self.stub.CancelTask(
179165
a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}'),
180-
metadata=self._get_metadata(context),
166+
metadata=self._update_extension_metadata(),
181167
)
182168
return proto_utils.FromProto.task(task)
183169

@@ -196,10 +182,11 @@ async def set_task_callback(
196182
request
197183
),
198184
),
199-
metadata=self._get_metadata(context),
185+
metadata=self._update_extension_metadata(),
200186
)
201187
return proto_utils.FromProto.task_push_notification_config(config)
202188

189+
# iva todo GetTaskPushNotificationConfigParams has metadata
203190
async def get_task_callback(
204191
self,
205192
request: GetTaskPushNotificationConfigParams,
@@ -211,7 +198,7 @@ async def get_task_callback(
211198
a2a_pb2.GetTaskPushNotificationConfigRequest(
212199
name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}',
213200
),
214-
metadata=self._get_metadata(context),
201+
metadata=self._update_extension_metadata(),
215202
)
216203
return proto_utils.FromProto.task_push_notification_config(config)
217204

@@ -235,6 +222,33 @@ async def get_card(
235222
self._needs_extended_card = False
236223
return card
237224

225+
def _update_extension_metadata(
226+
self, metadata: dict[str, Any] | None = None
227+
) -> struct_pb2.Struct | None:
228+
"""Gets the metadata for the gRPC call."""
229+
if metadata is None:
230+
metadata = {}
231+
232+
if self.extensions:
233+
existing_extensions_str = str(
234+
metadata.get(HTTP_EXTENSION_HEADER, '')
235+
)
236+
existing_extensions = {
237+
e.strip()
238+
for e in existing_extensions_str.split(',')
239+
if e.strip()
240+
}
241+
242+
all_extensions = set(existing_extensions)
243+
all_extensions.update(self.extensions)
244+
245+
if all_extensions:
246+
metadata[HTTP_EXTENSION_HEADER] = ','.join(all_extensions)
247+
elif HTTP_EXTENSION_HEADER in metadata:
248+
del metadata[HTTP_EXTENSION_HEADER]
249+
250+
return proto_utils.ToProto.metadata(metadata if metadata else None)
251+
238252
async def close(self) -> None:
239253
"""Closes the gRPC channel."""
240254
await self.channel.close()

src/a2a/client/transports/utils.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,33 @@ def get_http_args(context: ClientCallContext | None) -> dict[str, Any] | None:
88
return context.state.get('http_kwargs') if context else None
99

1010

11+
def __merge_extensions(
12+
existing_extensions: str, new_extensions: list[str]
13+
) -> str:
14+
existing_extensions_list = [
15+
e.strip() for e in existing_extensions.split(',') if e.strip()
16+
]
17+
existing_extensions_set = set(existing_extensions_list)
18+
new_extensions = [
19+
ext for ext in new_extensions if ext not in existing_extensions_set
20+
]
21+
22+
return ','.join(existing_extensions_list + new_extensions)
23+
24+
1125
def update_extension_header(
1226
http_kwargs: dict[str, Any], extensions: list[str] | None
1327
) -> dict[str, Any]:
1428
if not extensions:
1529
return http_kwargs
1630
headers = http_kwargs.setdefault('headers', {})
1731
existing_extensions_str = headers.get(HTTP_EXTENSION_HEADER, '')
18-
existing_extensions = [
32+
"""existing_extensions = [
1933
e.strip() for e in existing_extensions_str.split(',') if e.strip()
2034
]
2135
all_extensions = set(existing_extensions)
22-
all_extensions.update(extensions)
23-
headers[HTTP_EXTENSION_HEADER] = ','.join(all_extensions)
36+
all_extensions.update(extensions)"""
37+
headers[HTTP_EXTENSION_HEADER] = __merge_extensions(
38+
existing_extensions_str, extensions
39+
)
2440
return http_kwargs

0 commit comments

Comments
 (0)