Skip to content

Commit c5cea2c

Browse files
committed
Move transport tests from tests/client to tests/client/transport. Add method "update_extension_metadata" to tansports/utils.py.
1 parent 4073c0b commit c5cea2c

File tree

8 files changed

+215
-91
lines changed

8 files changed

+215
-91
lines changed

src/a2a/client/base_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ def __init__(
4040
):
4141
super().__init__(consumers, middleware, extensions)
4242
self._card = card
43+
config.extensions = extensions
4344
self._config = config
4445
self._transport = transport
45-
self._config.extensions = self._extensions
4646

4747
async def send_message(
4848
self,

src/a2a/client/client.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ def __init__(
9393
self,
9494
consumers: list[Consumer] | None = None,
9595
middleware: list[ClientCallInterceptor] | None = None,
96-
# iva todo - it can override value from the config, if it is provided
9796
extensions: list[str] | None = None,
9897
):
9998
"""Initializes the client with consumers and middleware.
@@ -118,8 +117,6 @@ async def send_message(
118117
request: Message,
119118
*,
120119
context: ClientCallContext | None = None,
121-
# iva todo add optional extensions- it can override value from the config, if it is provided
122-
# and to the other ones as well
123120
) -> AsyncIterator[ClientEvent | Message]:
124121
"""Sends a message to the server.
125122

src/a2a/client/transports/grpc.py

Lines changed: 19 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import logging
22

33
from collections.abc import AsyncGenerator
4-
from typing import Any
54

65

76
try:
@@ -13,13 +12,12 @@
1312
"'pip install a2a-sdk[grpc]'"
1413
) from e
1514

16-
from google.protobuf import struct_pb2
1715

1816
from a2a.client.client import ClientConfig
1917
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
2018
from a2a.client.optionals import Channel
2119
from a2a.client.transports.base import ClientTransport
22-
from a2a.extensions.common import HTTP_EXTENSION_HEADER
20+
from a2a.client.transports.utils import update_extension_metadata
2321
from a2a.grpc import a2a_pb2, a2a_pb2_grpc
2422
from a2a.types import (
2523
AgentCard,
@@ -87,9 +85,10 @@ async def send_message(
8785
configuration=proto_utils.ToProto.message_send_configuration(
8886
request.configuration
8987
),
90-
metadata=proto_utils.ToProto.metadata(request.metadata),
88+
metadata=update_extension_metadata(
89+
request.metadata, self.extensions
90+
),
9191
),
92-
metadata=self._update_extension_metadata(),
9392
)
9493
if response.HasField('task'):
9594
return proto_utils.FromProto.task(response.task)
@@ -110,17 +109,17 @@ async def send_message_streaming(
110109
configuration=proto_utils.ToProto.message_send_configuration(
111110
request.configuration
112111
),
113-
metadata=proto_utils.ToProto.metadata(request.metadata),
112+
metadata=update_extension_metadata(
113+
request.metadata, self.extensions
114+
),
114115
),
115-
metadata=self._update_extension_metadata(request.metadata),
116116
)
117117
while True:
118118
response = await stream.read()
119119
if response == grpc.aio.EOF: # pyright: ignore[reportAttributeAccessIssue]
120120
break
121121
yield proto_utils.FromProto.stream_response(response)
122122

123-
# iva todo TaskIdParams has metadata
124123
async def resubscribe(
125124
self, request: TaskIdParams, *, context: ClientCallContext | None = None
126125
) -> AsyncGenerator[
@@ -129,15 +128,16 @@ async def resubscribe(
129128
"""Reconnects to get task updates."""
130129
stream = self.stub.TaskSubscription(
131130
a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}'),
132-
metadata=self._update_extension_metadata(),
131+
metadata=update_extension_metadata(
132+
request.metadata, self.extensions
133+
),
133134
)
134135
while True:
135136
response = await stream.read()
136137
if response == grpc.aio.EOF: # pyright: ignore[reportAttributeAccessIssue]
137138
break
138139
yield proto_utils.FromProto.stream_response(response)
139140

140-
# iva todo TaskQueryParams has metadata
141141
async def get_task(
142142
self,
143143
request: TaskQueryParams,
@@ -150,7 +150,9 @@ async def get_task(
150150
name=f'tasks/{request.id}',
151151
history_length=request.history_length,
152152
),
153-
metadata=self._update_extension_metadata(),
153+
metadata=update_extension_metadata(
154+
request.metadata, self.extensions
155+
),
154156
)
155157
return proto_utils.FromProto.task(task)
156158

@@ -163,7 +165,6 @@ async def cancel_task(
163165
"""Requests the agent to cancel a specific task."""
164166
task = await self.stub.CancelTask(
165167
a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}'),
166-
metadata=self._update_extension_metadata(),
167168
)
168169
return proto_utils.FromProto.task(task)
169170

@@ -182,11 +183,12 @@ async def set_task_callback(
182183
request
183184
),
184185
),
185-
metadata=self._update_extension_metadata(),
186+
metadata=update_extension_metadata(
187+
request.metadata, self.extensions
188+
),
186189
)
187190
return proto_utils.FromProto.task_push_notification_config(config)
188191

189-
# iva todo GetTaskPushNotificationConfigParams has metadata
190192
async def get_task_callback(
191193
self,
192194
request: GetTaskPushNotificationConfigParams,
@@ -198,7 +200,9 @@ async def get_task_callback(
198200
a2a_pb2.GetTaskPushNotificationConfigRequest(
199201
name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}',
200202
),
201-
metadata=self._update_extension_metadata(),
203+
metadata=update_extension_metadata(
204+
request.metadata, self.extensions
205+
),
202206
)
203207
return proto_utils.FromProto.task_push_notification_config(config)
204208

@@ -222,33 +226,6 @@ async def get_card(
222226
self._needs_extended_card = False
223227
return card
224228

225-
def _update_extension_metadata(
226-
self, metadata: dict[str, Any] | None = None
227-
) -> struct_pb2.Struct | None:
228-
"""Gets the metadata for the gRPC call."""
229-
if metadata is None:
230-
metadata = {}
231-
232-
if self.extensions:
233-
existing_extensions_str = str(
234-
metadata.get(HTTP_EXTENSION_HEADER, '')
235-
)
236-
existing_extensions = {
237-
e.strip()
238-
for e in existing_extensions_str.split(',')
239-
if e.strip()
240-
}
241-
242-
all_extensions = set(existing_extensions)
243-
all_extensions.update(self.extensions)
244-
245-
if all_extensions:
246-
metadata[HTTP_EXTENSION_HEADER] = ','.join(all_extensions)
247-
elif HTTP_EXTENSION_HEADER in metadata:
248-
del metadata[HTTP_EXTENSION_HEADER]
249-
250-
return proto_utils.ToProto.metadata(metadata if metadata else None)
251-
252229
async def close(self) -> None:
253230
"""Closes the gRPC channel."""
254231
await self.channel.close()

src/a2a/client/transports/utils.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from typing import Any
22

3+
from google.protobuf import struct_pb2
4+
35
from a2a.client.middleware import ClientCallContext
46
from a2a.extensions.common import HTTP_EXTENSION_HEADER
7+
from a2a.utils import proto_utils
58

69

710
def get_http_args(context: ClientCallContext | None) -> dict[str, Any] | None:
@@ -14,27 +17,33 @@ def __merge_extensions(
1417
existing_extensions_list = [
1518
e.strip() for e in existing_extensions.split(',') if e.strip()
1619
]
17-
existing_extensions_set = set(existing_extensions_list)
1820
new_extensions = [
19-
ext for ext in new_extensions if ext not in existing_extensions_set
21+
ext for ext in new_extensions if ext not in existing_extensions_list
2022
]
21-
2223
return ','.join(existing_extensions_list + new_extensions)
2324

2425

2526
def update_extension_header(
2627
http_kwargs: dict[str, Any], extensions: list[str] | None
2728
) -> dict[str, Any]:
28-
if not extensions:
29-
return http_kwargs
30-
headers = http_kwargs.setdefault('headers', {})
31-
existing_extensions_str = headers.get(HTTP_EXTENSION_HEADER, '')
32-
"""existing_extensions = [
33-
e.strip() for e in existing_extensions_str.split(',') if e.strip()
34-
]
35-
all_extensions = set(existing_extensions)
36-
all_extensions.update(extensions)"""
37-
headers[HTTP_EXTENSION_HEADER] = __merge_extensions(
38-
existing_extensions_str, extensions
39-
)
29+
if extensions:
30+
headers = http_kwargs.setdefault('headers', {})
31+
existing_extensions_str = headers.get(HTTP_EXTENSION_HEADER, '')
32+
33+
headers[HTTP_EXTENSION_HEADER] = __merge_extensions(
34+
existing_extensions_str, extensions
35+
)
4036
return http_kwargs
37+
38+
39+
def update_extension_metadata(
40+
metadata: dict[str, Any] | None, extensions: list[str] | None
41+
) -> struct_pb2.Struct | None:
42+
if metadata is None:
43+
metadata = {}
44+
if extensions:
45+
existing_extensions_str = str(metadata.get(HTTP_EXTENSION_HEADER, ''))
46+
metadata[HTTP_EXTENSION_HEADER] = __merge_extensions(
47+
existing_extensions_str, extensions
48+
)
49+
return proto_utils.ToProto.metadata(metadata if metadata else None)

0 commit comments

Comments
 (0)