Skip to content

Commit 1337dcf

Browse files
committed
refactor: streamline extension handling in transport classes and update related tests
1 parent 0746541 commit 1337dcf

File tree

8 files changed

+110
-126
lines changed

8 files changed

+110
-126
lines changed

src/a2a/client/transports/grpc.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,11 @@ def _get_grpc_metadata(
6464
extensions: list[str] | None = None,
6565
) -> list[tuple[str, str]] | None:
6666
"""Creates gRPC metadata for extensions."""
67-
if extensions:
68-
self.extensions = extensions
69-
if not self.extensions:
70-
return None
71-
return [(HTTP_EXTENSION_HEADER, ', '.join(self.extensions))]
67+
if extensions is not None:
68+
return [(HTTP_EXTENSION_HEADER, ','.join(extensions))]
69+
if self.extensions is not None:
70+
return [(HTTP_EXTENSION_HEADER, ','.join(self.extensions))]
71+
return None
7272

7373
@classmethod
7474
def create(
@@ -233,6 +233,7 @@ async def get_card(
233233

234234
card_pb = await self.stub.GetAgentCard(
235235
a2a_pb2.GetAgentCardRequest(),
236+
metadata=self._get_grpc_metadata(extensions),
236237
)
237238
card = proto_utils.FromProto.agent_card(card_pb)
238239
self.agent_card = card

src/a2a/client/transports/jsonrpc.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,9 @@ async def send_message(
126126
self._get_http_args(context),
127127
context,
128128
)
129-
modified_kwargs, self.extensions = update_extension_header(
130-
modified_kwargs, self.extensions, extensions
129+
modified_kwargs = update_extension_header(
130+
modified_kwargs,
131+
extensions if extensions is not None else self.extensions,
131132
)
132133
response_data = await self._send_request(payload, modified_kwargs)
133134
response = SendMessageResponse.model_validate(response_data)
@@ -155,8 +156,9 @@ async def send_message_streaming(
155156
context,
156157
)
157158

158-
modified_kwargs, self.extensions = update_extension_header(
159-
modified_kwargs, self.extensions, extensions
159+
modified_kwargs = update_extension_header(
160+
modified_kwargs,
161+
extensions if extensions is not None else self.extensions,
160162
)
161163
modified_kwargs.setdefault(
162164
'timeout', self.httpx_client.timeout.as_dict().get('read', None)
@@ -228,8 +230,9 @@ async def get_task(
228230
self._get_http_args(context),
229231
context,
230232
)
231-
modified_kwargs, self.extensions = update_extension_header(
232-
modified_kwargs, self.extensions, extensions
233+
modified_kwargs = update_extension_header(
234+
modified_kwargs,
235+
extensions if extensions is not None else self.extensions,
233236
)
234237
response_data = await self._send_request(payload, modified_kwargs)
235238
response = GetTaskResponse.model_validate(response_data)
@@ -252,8 +255,9 @@ async def cancel_task(
252255
self._get_http_args(context),
253256
context,
254257
)
255-
modified_kwargs, self.extensions = update_extension_header(
256-
modified_kwargs, self.extensions, extensions
258+
modified_kwargs = update_extension_header(
259+
modified_kwargs,
260+
extensions if extensions is not None else self.extensions,
257261
)
258262
response_data = await self._send_request(payload, modified_kwargs)
259263
response = CancelTaskResponse.model_validate(response_data)
@@ -278,8 +282,9 @@ async def set_task_callback(
278282
self._get_http_args(context),
279283
context,
280284
)
281-
modified_kwargs, self.extensions = update_extension_header(
282-
modified_kwargs, self.extensions, extensions
285+
modified_kwargs = update_extension_header(
286+
modified_kwargs,
287+
extensions if extensions is not None else self.extensions,
283288
)
284289
response_data = await self._send_request(payload, modified_kwargs)
285290
response = SetTaskPushNotificationConfigResponse.model_validate(
@@ -306,8 +311,9 @@ async def get_task_callback(
306311
self._get_http_args(context),
307312
context,
308313
)
309-
modified_kwargs, self.extensions = update_extension_header(
310-
modified_kwargs, self.extensions, extensions
314+
modified_kwargs = update_extension_header(
315+
modified_kwargs,
316+
extensions if extensions is not None else self.extensions,
311317
)
312318
response_data = await self._send_request(payload, modified_kwargs)
313319
response = GetTaskPushNotificationConfigResponse.model_validate(
@@ -334,8 +340,9 @@ async def resubscribe(
334340
self._get_http_args(context),
335341
context,
336342
)
337-
modified_kwargs, self.extensions = update_extension_header(
338-
modified_kwargs, self.extensions, extensions
343+
modified_kwargs = update_extension_header(
344+
modified_kwargs,
345+
extensions if extensions is not None else self.extensions,
339346
)
340347
modified_kwargs.setdefault('timeout', None)
341348

@@ -393,8 +400,9 @@ async def get_card(
393400
self._get_http_args(context),
394401
context,
395402
)
396-
modified_kwargs, self.extensions = update_extension_header(
397-
modified_kwargs, self.extensions, extensions
403+
modified_kwargs = update_extension_header(
404+
modified_kwargs,
405+
extensions if extensions is not None else self.extensions,
398406
)
399407

400408
response_data = await self._send_request(

src/a2a/client/transports/rest.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,9 @@ async def _prepare_send_message(
104104
self._get_http_args(context),
105105
context,
106106
)
107-
modified_kwargs, self.extensions = update_extension_header(
108-
modified_kwargs, self.extensions, extensions
107+
modified_kwargs = update_extension_header(
108+
modified_kwargs,
109+
extensions if extensions is not None else self.extensions,
109110
)
110111
return payload, modified_kwargs
111112

@@ -223,8 +224,9 @@ async def get_task(
223224
self._get_http_args(context),
224225
context,
225226
)
226-
modified_kwargs, self.extensions = update_extension_header(
227-
modified_kwargs, self.extensions, extensions
227+
modified_kwargs = update_extension_header(
228+
modified_kwargs,
229+
extensions if extensions is not None else self.extensions,
228230
)
229231
response_data = await self._send_get_request(
230232
f'/v1/tasks/{request.id}',
@@ -252,8 +254,9 @@ async def cancel_task(
252254
self._get_http_args(context),
253255
context,
254256
)
255-
modified_kwargs, self.extensions = update_extension_header(
256-
modified_kwargs, self.extensions, extensions
257+
modified_kwargs = update_extension_header(
258+
modified_kwargs,
259+
extensions if extensions is not None else self.extensions,
257260
)
258261
response_data = await self._send_post_request(
259262
f'/v1/tasks/{request.id}:cancel', payload, modified_kwargs
@@ -279,8 +282,9 @@ async def set_task_callback(
279282
payload, modified_kwargs = await self._apply_interceptors(
280283
payload, self._get_http_args(context), context
281284
)
282-
modified_kwargs, self.extensions = update_extension_header(
283-
modified_kwargs, self.extensions, extensions
285+
modified_kwargs = update_extension_header(
286+
modified_kwargs,
287+
extensions if extensions is not None else self.extensions,
284288
)
285289
response_data = await self._send_post_request(
286290
f'/v1/tasks/{request.task_id}/pushNotificationConfigs',
@@ -308,8 +312,9 @@ async def get_task_callback(
308312
self._get_http_args(context),
309313
context,
310314
)
311-
modified_kwargs, self.extensions = update_extension_header(
312-
modified_kwargs, self.extensions, extensions
315+
modified_kwargs = update_extension_header(
316+
modified_kwargs,
317+
extensions if extensions is not None else self.extensions,
313318
)
314319
response_data = await self._send_get_request(
315320
f'/v1/tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}',
@@ -332,8 +337,9 @@ async def resubscribe(
332337
"""Reconnects to get task updates."""
333338
http_kwargs = self._get_http_args(context) or {}
334339
http_kwargs.setdefault('timeout', None)
335-
modified_kwargs, self.extensions = update_extension_header(
336-
http_kwargs, self.extensions, extensions
340+
modified_kwargs = update_extension_header(
341+
http_kwargs,
342+
extensions if extensions is not None else self.extensions,
337343
)
338344

339345
async with aconnect_sse(
@@ -384,8 +390,9 @@ async def get_card(
384390
self._get_http_args(context),
385391
context,
386392
)
387-
modified_kwargs, self.extensions = update_extension_header(
388-
modified_kwargs, self.extensions, extensions
393+
modified_kwargs = update_extension_header(
394+
modified_kwargs,
395+
extensions if extensions is not None else self.extensions,
389396
)
390397
response_data = await self._send_get_request(
391398
'/v1/card', {}, modified_kwargs

src/a2a/extensions/common.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,10 @@ def find_extension_by_uri(card: AgentCard, uri: str) -> AgentExtension | None:
3131

3232
def update_extension_header(
3333
http_kwargs: dict[str, Any],
34-
active_extensions: list[str] | None,
35-
new_extensions: list[str] | None,
36-
) -> tuple[dict[str, Any], list[str] | None]:
34+
extensions: list[str] | None,
35+
) -> dict[str, Any]:
3736
"""Update the X-A2A-Extensions header and update active extensions."""
38-
if new_extensions:
39-
active_extensions = new_extensions
40-
if active_extensions:
37+
if extensions is not None:
4138
headers = http_kwargs.setdefault('headers', {})
42-
existing_extensions_str = headers.get(HTTP_EXTENSION_HEADER, '')
43-
44-
existing_extensions = get_requested_extensions(
45-
[existing_extensions_str]
46-
)
47-
all_extensions = existing_extensions.union(active_extensions)
48-
headers[HTTP_EXTENSION_HEADER] = ','.join(all_extensions)
49-
return http_kwargs, active_extensions
39+
headers[HTTP_EXTENSION_HEADER] = ','.join(extensions)
40+
return http_kwargs

tests/client/transports/test_grpc_client.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -193,14 +193,17 @@ async def test_send_message_task_response(
193193
task=proto_utils.ToProto.task(sample_task)
194194
)
195195

196-
response = await grpc_transport.send_message(sample_message_send_params)
196+
response = await grpc_transport.send_message(
197+
sample_message_send_params,
198+
extensions=['https://example.com/test-ext/v3'],
199+
)
197200

198201
mock_grpc_stub.SendMessage.assert_awaited_once()
199202
_, kwargs = mock_grpc_stub.SendMessage.call_args
200203
assert kwargs['metadata'] == [
201204
(
202205
HTTP_EXTENSION_HEADER,
203-
'https://example.com/test-ext/v1, https://example.com/test-ext/v2',
206+
'https://example.com/test-ext/v3',
204207
)
205208
]
206209
assert isinstance(response, Task)
@@ -226,7 +229,7 @@ async def test_send_message_message_response(
226229
assert kwargs['metadata'] == [
227230
(
228231
HTTP_EXTENSION_HEADER,
229-
'https://example.com/test-ext/v1, https://example.com/test-ext/v2',
232+
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
230233
)
231234
]
232235
assert isinstance(response, Message)
@@ -281,7 +284,7 @@ async def test_send_message_streaming( # noqa: PLR0913
281284
assert kwargs['metadata'] == [
282285
(
283286
HTTP_EXTENSION_HEADER,
284-
'https://example.com/test-ext/v1, https://example.com/test-ext/v2',
287+
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
285288
)
286289
]
287290
assert isinstance(responses[0], Message)
@@ -311,7 +314,7 @@ async def test_get_task(
311314
metadata=[
312315
(
313316
HTTP_EXTENSION_HEADER,
314-
'https://example.com/test-ext/v1, https://example.com/test-ext/v2',
317+
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
315318
)
316319
],
317320
)
@@ -336,7 +339,7 @@ async def test_get_task_with_history(
336339
metadata=[
337340
(
338341
HTTP_EXTENSION_HEADER,
339-
'https://example.com/test-ext/v1, https://example.com/test-ext/v2',
342+
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
340343
)
341344
],
342345
)
@@ -393,7 +396,7 @@ async def test_set_task_callback_with_valid_task(
393396
metadata=[
394397
(
395398
HTTP_EXTENSION_HEADER,
396-
'https://example.com/test-ext/v1, https://example.com/test-ext/v2',
399+
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
397400
)
398401
],
399402
)
@@ -456,7 +459,7 @@ async def test_get_task_callback_with_valid_task(
456459
metadata=[
457460
(
458461
HTTP_EXTENSION_HEADER,
459-
'https://example.com/test-ext/v1, https://example.com/test-ext/v2',
462+
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
460463
)
461464
],
462465
)
@@ -493,43 +496,37 @@ async def test_get_task_callback_with_invalid_task(
493496

494497

495498
@pytest.mark.parametrize(
496-
'initial_extensions, input_extensions, expected_metadata, expected_extensions',
499+
'initial_extensions, input_extensions, expected_metadata',
497500
[
498501
(
499502
None,
500503
None,
501504
None,
502-
None,
503505
), # Case 1: No initial, No input
504506
(
505507
['ext1'],
506508
None,
507509
[(HTTP_EXTENSION_HEADER, 'ext1')],
508-
['ext1'],
509510
), # Case 2: Initial, No input
510511
(
511512
None,
512513
['ext2'],
513514
[(HTTP_EXTENSION_HEADER, 'ext2')],
514-
['ext2'],
515515
), # Case 3: No initial, Input
516516
(
517517
['ext1'],
518518
['ext2'],
519519
[(HTTP_EXTENSION_HEADER, 'ext2')],
520-
['ext2'],
521520
), # Case 4: Initial, Input (override)
522521
(
523522
['ext1'],
524523
['ext2', 'ext3'],
525-
[(HTTP_EXTENSION_HEADER, 'ext2, ext3')],
526-
['ext2', 'ext3'],
524+
[(HTTP_EXTENSION_HEADER, 'ext2,ext3')],
527525
), # Case 5: Initial, Multiple inputs (override)
528526
(
529527
['ext1', 'ext2'],
530528
['ext3'],
531529
[(HTTP_EXTENSION_HEADER, 'ext3')],
532-
['ext3'],
533530
), # Case 6: Multiple initial, Single input (override)
534531
],
535532
)
@@ -538,12 +535,8 @@ def test_get_grpc_metadata(
538535
initial_extensions: list[str] | None,
539536
input_extensions: list[str] | None,
540537
expected_metadata: list[tuple[str, str]] | None,
541-
expected_extensions: list[str] | None,
542538
) -> None:
543539
"""Tests _get_grpc_metadata for correct metadata generation and self.extensions update."""
544540
grpc_transport.extensions = initial_extensions
545-
546-
metadata = grpc_transport._get_grpc_metadata(extensions=input_extensions)
547-
541+
metadata = grpc_transport._get_grpc_metadata(input_extensions)
548542
assert metadata == expected_metadata
549-
assert grpc_transport.extensions == expected_extensions

0 commit comments

Comments
 (0)