Skip to content

Commit 6e856d5

Browse files
committed
feat: enhance GrpcTransport to manage extensions in metadata and update related tests
1 parent c5cea2c commit 6e856d5

File tree

4 files changed

+90
-166
lines changed

4 files changed

+90
-166
lines changed

src/a2a/client/transports/grpc.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
1818
from a2a.client.optionals import Channel
1919
from a2a.client.transports.base import ClientTransport
20-
from a2a.client.transports.utils import update_extension_metadata
20+
from a2a.extensions.common import HTTP_EXTENSION_HEADER
2121
from a2a.grpc import a2a_pb2, a2a_pb2_grpc
2222
from a2a.types import (
2323
AgentCard,
@@ -59,6 +59,12 @@ def __init__(
5959
)
6060
self.extensions = extensions
6161

62+
def _get_grpc_metadata(self) -> list[tuple[str, str]] | None:
63+
"""Creates gRPC metadata for extensions."""
64+
if not self.extensions:
65+
return None
66+
return [(HTTP_EXTENSION_HEADER, ', '.join(self.extensions))]
67+
6268
@classmethod
6369
def create(
6470
cls,
@@ -85,10 +91,9 @@ async def send_message(
8591
configuration=proto_utils.ToProto.message_send_configuration(
8692
request.configuration
8793
),
88-
metadata=update_extension_metadata(
89-
request.metadata, self.extensions
90-
),
94+
metadata=proto_utils.ToProto.metadata(request.metadata),
9195
),
96+
metadata=self._get_grpc_metadata(),
9297
)
9398
if response.HasField('task'):
9499
return proto_utils.FromProto.task(response.task)
@@ -109,10 +114,9 @@ async def send_message_streaming(
109114
configuration=proto_utils.ToProto.message_send_configuration(
110115
request.configuration
111116
),
112-
metadata=update_extension_metadata(
113-
request.metadata, self.extensions
114-
),
117+
metadata=proto_utils.ToProto.metadata(request.metadata),
115118
),
119+
metadata=self._get_grpc_metadata(),
116120
)
117121
while True:
118122
response = await stream.read()
@@ -128,9 +132,7 @@ async def resubscribe(
128132
"""Reconnects to get task updates."""
129133
stream = self.stub.TaskSubscription(
130134
a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}'),
131-
metadata=update_extension_metadata(
132-
request.metadata, self.extensions
133-
),
135+
metadata=self._get_grpc_metadata(),
134136
)
135137
while True:
136138
response = await stream.read()
@@ -150,9 +152,7 @@ async def get_task(
150152
name=f'tasks/{request.id}',
151153
history_length=request.history_length,
152154
),
153-
metadata=update_extension_metadata(
154-
request.metadata, self.extensions
155-
),
155+
metadata=self._get_grpc_metadata(),
156156
)
157157
return proto_utils.FromProto.task(task)
158158

@@ -165,6 +165,7 @@ async def cancel_task(
165165
"""Requests the agent to cancel a specific task."""
166166
task = await self.stub.CancelTask(
167167
a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}'),
168+
metadata=self._get_grpc_metadata(),
168169
)
169170
return proto_utils.FromProto.task(task)
170171

@@ -183,9 +184,7 @@ async def set_task_callback(
183184
request
184185
),
185186
),
186-
metadata=update_extension_metadata(
187-
request.metadata, self.extensions
188-
),
187+
metadata=self._get_grpc_metadata(),
189188
)
190189
return proto_utils.FromProto.task_push_notification_config(config)
191190

@@ -200,9 +199,7 @@ async def get_task_callback(
200199
a2a_pb2.GetTaskPushNotificationConfigRequest(
201200
name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}',
202201
),
203-
metadata=update_extension_metadata(
204-
request.metadata, self.extensions
205-
),
202+
metadata=self._get_grpc_metadata(),
206203
)
207204
return proto_utils.FromProto.task_push_notification_config(config)
208205

src/a2a/client/transports/utils.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
from typing import Any
22

3-
from google.protobuf import struct_pb2
4-
53
from a2a.client.middleware import ClientCallContext
64
from a2a.extensions.common import HTTP_EXTENSION_HEADER
7-
from a2a.utils import proto_utils
85

96

107
def get_http_args(context: ClientCallContext | None) -> dict[str, Any] | None:
@@ -30,20 +27,14 @@ def update_extension_header(
3027
headers = http_kwargs.setdefault('headers', {})
3128
existing_extensions_str = headers.get(HTTP_EXTENSION_HEADER, '')
3229

33-
headers[HTTP_EXTENSION_HEADER] = __merge_extensions(
34-
existing_extensions_str, extensions
35-
)
36-
return http_kwargs
37-
30+
existing_extensions_list = [
31+
e.strip() for e in existing_extensions_str.split(',') if e.strip()
32+
]
33+
new_extensions = [
34+
ext for ext in extensions if ext not in existing_extensions_list
35+
]
3836

39-
def update_extension_metadata(
40-
metadata: dict[str, Any] | None, extensions: list[str] | None
41-
) -> struct_pb2.Struct | None:
42-
if metadata is None:
43-
metadata = {}
44-
if extensions:
45-
existing_extensions_str = str(metadata.get(HTTP_EXTENSION_HEADER, ''))
46-
metadata[HTTP_EXTENSION_HEADER] = __merge_extensions(
47-
existing_extensions_str, extensions
37+
headers[HTTP_EXTENSION_HEADER] = ','.join(
38+
existing_extensions_list + new_extensions
4839
)
49-
return proto_utils.ToProto.metadata(metadata if metadata else None)
40+
return http_kwargs

tests/client/transports/test_grpc_client.py

Lines changed: 64 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,14 @@ def grpc_transport(
6565
) -> GrpcTransport:
6666
"""Provides a GrpcTransport instance."""
6767
channel = AsyncMock()
68-
transport = GrpcTransport(channel=channel, agent_card=sample_agent_card)
68+
transport = GrpcTransport(
69+
channel=channel,
70+
agent_card=sample_agent_card,
71+
extensions=[
72+
'https://example.com/test-ext/v1',
73+
'https://example.com/test-ext/v2',
74+
],
75+
)
6976
transport.stub = mock_grpc_stub
7077
return transport
7178

@@ -189,6 +196,13 @@ async def test_send_message_task_response(
189196
response = await grpc_transport.send_message(sample_message_send_params)
190197

191198
mock_grpc_stub.SendMessage.assert_awaited_once()
199+
_, kwargs = mock_grpc_stub.SendMessage.call_args
200+
assert kwargs['metadata'] == [
201+
(
202+
HTTP_EXTENSION_HEADER,
203+
'https://example.com/test-ext/v1, https://example.com/test-ext/v2',
204+
)
205+
]
192206
assert isinstance(response, Task)
193207
assert response.id == sample_task.id
194208

@@ -208,6 +222,13 @@ async def test_send_message_message_response(
208222
response = await grpc_transport.send_message(sample_message_send_params)
209223

210224
mock_grpc_stub.SendMessage.assert_awaited_once()
225+
_, kwargs = mock_grpc_stub.SendMessage.call_args
226+
assert kwargs['metadata'] == [
227+
(
228+
HTTP_EXTENSION_HEADER,
229+
'https://example.com/test-ext/v1, https://example.com/test-ext/v2',
230+
)
231+
]
211232
assert isinstance(response, Message)
212233
assert response.message_id == sample_message.message_id
213234
assert get_text_parts(response.parts) == get_text_parts(
@@ -256,6 +277,13 @@ async def test_send_message_streaming( # noqa: PLR0913
256277
]
257278

258279
mock_grpc_stub.SendStreamingMessage.assert_called_once()
280+
_, kwargs = mock_grpc_stub.SendStreamingMessage.call_args
281+
assert kwargs['metadata'] == [
282+
(
283+
HTTP_EXTENSION_HEADER,
284+
'https://example.com/test-ext/v1, https://example.com/test-ext/v2',
285+
)
286+
]
259287
assert isinstance(responses[0], Message)
260288
assert responses[0].message_id == sample_message.message_id
261289
assert isinstance(responses[1], Task)
@@ -279,7 +307,13 @@ async def test_get_task(
279307
mock_grpc_stub.GetTask.assert_awaited_once_with(
280308
a2a_pb2.GetTaskRequest(
281309
name=f'tasks/{sample_task.id}', history_length=None
282-
)
310+
),
311+
metadata=[
312+
(
313+
HTTP_EXTENSION_HEADER,
314+
'https://example.com/test-ext/v1, https://example.com/test-ext/v2',
315+
)
316+
],
283317
)
284318
assert response.id == sample_task.id
285319

@@ -298,7 +332,13 @@ async def test_get_task_with_history(
298332
mock_grpc_stub.GetTask.assert_awaited_once_with(
299333
a2a_pb2.GetTaskRequest(
300334
name=f'tasks/{sample_task.id}', history_length=history_len
301-
)
335+
),
336+
metadata=[
337+
(
338+
HTTP_EXTENSION_HEADER,
339+
'https://example.com/test-ext/v1, https://example.com/test-ext/v2',
340+
)
341+
],
302342
)
303343

304344

@@ -317,7 +357,13 @@ async def test_cancel_task(
317357
response = await grpc_transport.cancel_task(params)
318358

319359
mock_grpc_stub.CancelTask.assert_awaited_once_with(
320-
a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}')
360+
a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}'),
361+
metadata=[
362+
(
363+
HTTP_EXTENSION_HEADER,
364+
'https://example.com/test-ext/v1, https://example.com/test-ext/v2',
365+
)
366+
],
321367
)
322368
assert response.status.state == TaskState.canceled
323369

@@ -346,7 +392,13 @@ async def test_set_task_callback_with_valid_task(
346392
config=proto_utils.ToProto.task_push_notification_config(
347393
sample_task_push_notification_config
348394
),
349-
)
395+
),
396+
metadata=[
397+
(
398+
HTTP_EXTENSION_HEADER,
399+
'https://example.com/test-ext/v1, https://example.com/test-ext/v2',
400+
)
401+
],
350402
)
351403
assert response.task_id == sample_task_push_notification_config.task_id
352404

@@ -403,7 +455,13 @@ async def test_get_task_callback_with_valid_task(
403455
f'tasks/{params.id}/'
404456
f'pushNotificationConfigs/{params.push_notification_config_id}'
405457
),
406-
)
458+
),
459+
metadata=[
460+
(
461+
HTTP_EXTENSION_HEADER,
462+
'https://example.com/test-ext/v1, https://example.com/test-ext/v2',
463+
)
464+
],
407465
)
408466
assert response.task_id == sample_task_push_notification_config.task_id
409467

@@ -435,61 +493,3 @@ async def test_get_task_callback_with_invalid_task(
435493
'Bad TaskPushNotificationConfig resource name'
436494
in exc_info.value.error.message
437495
)
438-
439-
440-
@pytest.mark.asyncio
441-
async def test_send_message_with_extensions(
442-
mock_grpc_stub: AsyncMock,
443-
sample_agent_card: AgentCard,
444-
sample_message_send_params: MessageSendParams,
445-
sample_task: Task,
446-
) -> None:
447-
"""Test send_message with extensions."""
448-
extensions = ['test_extension_1', 'test_extension_2']
449-
channel = AsyncMock()
450-
transport = GrpcTransport(
451-
channel=channel, agent_card=sample_agent_card, extensions=extensions
452-
)
453-
transport.stub = mock_grpc_stub
454-
455-
mock_grpc_stub.SendMessage.return_value = a2a_pb2.SendMessageResponse(
456-
task=proto_utils.ToProto.task(sample_task)
457-
)
458-
459-
await transport.send_message(sample_message_send_params)
460-
461-
mock_grpc_stub.SendMessage.assert_awaited_once()
462-
args, _ = mock_grpc_stub.SendMessage.call_args
463-
request = args[0]
464-
metadata = proto_utils.FromProto.metadata(request.metadata)
465-
assert HTTP_EXTENSION_HEADER in metadata
466-
assert metadata[HTTP_EXTENSION_HEADER] == 'test_extension_1,test_extension_2'
467-
468-
469-
@pytest.mark.asyncio
470-
async def test_send_message_streaming_with_extensions(
471-
mock_grpc_stub: AsyncMock,
472-
sample_agent_card: AgentCard,
473-
sample_message_send_params: MessageSendParams,
474-
) -> None:
475-
"""Test send_message_streaming with extensions."""
476-
extensions = ['test_extension_1', 'test_extension_2']
477-
channel = AsyncMock()
478-
transport = GrpcTransport(
479-
channel=channel, agent_card=sample_agent_card, extensions=extensions
480-
)
481-
transport.stub = mock_grpc_stub
482-
483-
stream = MagicMock()
484-
stream.read = AsyncMock(side_effect=[grpc.aio.EOF])
485-
mock_grpc_stub.SendStreamingMessage.return_value = stream
486-
487-
async for _ in transport.send_message_streaming(sample_message_send_params):
488-
pass
489-
490-
mock_grpc_stub.SendStreamingMessage.assert_called_once()
491-
args, _ = mock_grpc_stub.SendStreamingMessage.call_args
492-
request = args[0]
493-
metadata = proto_utils.FromProto.metadata(request.metadata)
494-
assert HTTP_EXTENSION_HEADER in metadata
495-
assert metadata[HTTP_EXTENSION_HEADER] == 'test_extension_1,test_extension_2'

0 commit comments

Comments
 (0)