Skip to content

Commit a2eeb7b

Browse files
committed
feat: enhance extension handling across client and transport layers
1 parent 5b47562 commit a2eeb7b

File tree

12 files changed

+261
-103
lines changed

12 files changed

+261
-103
lines changed

src/a2a/client/base_client.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,8 @@ def __init__(
3737
transport: ClientTransport,
3838
consumers: list[Consumer],
3939
middleware: list[ClientCallInterceptor],
40-
extensions: list[str],
4140
):
42-
super().__init__(consumers, middleware, extensions)
41+
super().__init__(consumers, middleware)
4342
self._card = card
4443
self._config = config
4544
self._transport = transport
@@ -50,6 +49,7 @@ async def send_message(
5049
*,
5150
context: ClientCallContext | None = None,
5251
request_metadata: dict[str, Any] | None = None,
52+
extensions: list[str] | None = None,
5353
) -> AsyncIterator[ClientEvent | Message]:
5454
"""Sends a message to the agent.
5555
@@ -61,6 +61,7 @@ async def send_message(
6161
request: The message to send to the agent.
6262
context: The client call context.
6363
request_metadata: Extensions Metadata attached to the request.
64+
extensions: List of extensions to be activated.
6465
6566
Yields:
6667
An async iterator of `ClientEvent` or a final `Message` response.
@@ -80,7 +81,7 @@ async def send_message(
8081

8182
if not self._config.streaming or not self._card.capabilities.streaming:
8283
response = await self._transport.send_message(
83-
params, context=context
84+
params, context=context, extensions=extensions
8485
)
8586
result = (
8687
(response, None) if isinstance(response, Task) else response
@@ -90,7 +91,9 @@ async def send_message(
9091
return
9192

9293
tracker = ClientTaskManager()
93-
stream = self._transport.send_message_streaming(params, context=context)
94+
stream = self._transport.send_message_streaming(
95+
params, context=context, extensions=extensions
96+
)
9497

9598
first_event = await anext(stream)
9699
# The response from a server may be either exactly one Message or a
@@ -127,74 +130,91 @@ async def get_task(
127130
request: TaskQueryParams,
128131
*,
129132
context: ClientCallContext | None = None,
133+
extensions: list[str] | None = None,
130134
) -> Task:
131135
"""Retrieves the current state and history of a specific task.
132136
133137
Args:
134138
request: The `TaskQueryParams` object specifying the task ID.
135139
context: The client call context.
140+
extensions: List of extensions to be activated.
136141
137142
Returns:
138143
A `Task` object representing the current state of the task.
139144
"""
140-
return await self._transport.get_task(request, context=context)
145+
return await self._transport.get_task(
146+
request, context=context, extensions=extensions
147+
)
141148

142149
async def cancel_task(
143150
self,
144151
request: TaskIdParams,
145152
*,
146153
context: ClientCallContext | None = None,
154+
extensions: list[str] | None = None,
147155
) -> Task:
148156
"""Requests the agent to cancel a specific task.
149157
150158
Args:
151159
request: The `TaskIdParams` object specifying the task ID.
152160
context: The client call context.
161+
extensions: List of extensions to be activated.
153162
154163
Returns:
155164
A `Task` object containing the updated task status.
156165
"""
157-
return await self._transport.cancel_task(request, context=context)
166+
return await self._transport.cancel_task(
167+
request, context=context, extensions=extensions
168+
)
158169

159170
async def set_task_callback(
160171
self,
161172
request: TaskPushNotificationConfig,
162173
*,
163174
context: ClientCallContext | None = None,
175+
extensions: list[str] | None = None,
164176
) -> TaskPushNotificationConfig:
165177
"""Sets or updates the push notification configuration for a specific task.
166178
167179
Args:
168180
request: The `TaskPushNotificationConfig` object with the new configuration.
169181
context: The client call context.
182+
extensions: List of extensions to be activated.
170183
171184
Returns:
172185
The created or updated `TaskPushNotificationConfig` object.
173186
"""
174-
return await self._transport.set_task_callback(request, context=context)
187+
return await self._transport.set_task_callback(
188+
request, context=context, extensions=extensions
189+
)
175190

176191
async def get_task_callback(
177192
self,
178193
request: GetTaskPushNotificationConfigParams,
179194
*,
180195
context: ClientCallContext | None = None,
196+
extensions: list[str] | None = None,
181197
) -> TaskPushNotificationConfig:
182198
"""Retrieves the push notification configuration for a specific task.
183199
184200
Args:
185201
request: The `GetTaskPushNotificationConfigParams` object specifying the task.
186202
context: The client call context.
203+
extensions: List of extensions to be activated.
187204
188205
Returns:
189206
A `TaskPushNotificationConfig` object containing the configuration.
190207
"""
191-
return await self._transport.get_task_callback(request, context=context)
208+
return await self._transport.get_task_callback(
209+
request, context=context, extensions=extensions
210+
)
192211

193212
async def resubscribe(
194213
self,
195214
request: TaskIdParams,
196215
*,
197216
context: ClientCallContext | None = None,
217+
extensions: list[str] | None = None,
198218
) -> AsyncIterator[ClientEvent]:
199219
"""Resubscribes to a task's event stream.
200220
@@ -203,6 +223,7 @@ async def resubscribe(
203223
Args:
204224
request: Parameters to identify the task to resubscribe to.
205225
context: The client call context.
226+
extensions: List of extensions to be activated.
206227
207228
Yields:
208229
An async iterator of `ClientEvent` objects.
@@ -220,12 +241,15 @@ async def resubscribe(
220241
# we should never see Message updates, despite the typing of the service
221242
# definition indicating it may be possible.
222243
async for event in self._transport.resubscribe(
223-
request, context=context
244+
request, context=context, extensions=extensions
224245
):
225246
yield await self._process_response(tracker, event)
226247

227248
async def get_card(
228-
self, *, context: ClientCallContext | None = None
249+
self,
250+
*,
251+
context: ClientCallContext | None = None,
252+
extensions: list[str] | None = None,
229253
) -> AgentCard:
230254
"""Retrieves the agent's card.
231255
@@ -234,11 +258,14 @@ async def get_card(
234258
235259
Args:
236260
context: The client call context.
261+
extensions: List of extensions to be activated.
237262
238263
Returns:
239264
The `AgentCard` for the agent.
240265
"""
241-
card = await self._transport.get_card(context=context)
266+
card = await self._transport.get_card(
267+
context=context, extensions=extensions
268+
)
242269
self._card = card
243270
return card
244271

src/a2a/client/client.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,24 +93,19 @@ def __init__(
9393
self,
9494
consumers: list[Consumer] | None = None,
9595
middleware: list[ClientCallInterceptor] | None = None,
96-
extensions: list[str] | None = None,
9796
):
9897
"""Initializes the client with consumers and middleware.
9998
10099
Args:
101100
consumers: A list of callables to process events from the agent.
102101
middleware: A list of interceptors to process requests and responses.
103-
extensions: A list of extension URIs the client supports.
104102
"""
105103
if middleware is None:
106104
middleware = []
107105
if consumers is None:
108106
consumers = []
109-
if extensions is None:
110-
extensions = []
111107
self._consumers = consumers
112108
self._middleware = middleware
113-
self._extensions = extensions
114109

115110
@abstractmethod
116111
async def send_message(
@@ -119,6 +114,7 @@ async def send_message(
119114
*,
120115
context: ClientCallContext | None = None,
121116
request_metadata: dict[str, Any] | None = None,
117+
extensions: list[str] | None = None,
122118
) -> AsyncIterator[ClientEvent | Message]:
123119
"""Sends a message to the server.
124120
@@ -137,6 +133,7 @@ async def get_task(
137133
request: TaskQueryParams,
138134
*,
139135
context: ClientCallContext | None = None,
136+
extensions: list[str] | None = None,
140137
) -> Task:
141138
"""Retrieves the current state and history of a specific task."""
142139

@@ -146,6 +143,7 @@ async def cancel_task(
146143
request: TaskIdParams,
147144
*,
148145
context: ClientCallContext | None = None,
146+
extensions: list[str] | None = None,
149147
) -> Task:
150148
"""Requests the agent to cancel a specific task."""
151149

@@ -155,6 +153,7 @@ async def set_task_callback(
155153
request: TaskPushNotificationConfig,
156154
*,
157155
context: ClientCallContext | None = None,
156+
extensions: list[str] | None = None,
158157
) -> TaskPushNotificationConfig:
159158
"""Sets or updates the push notification configuration for a specific task."""
160159

@@ -164,6 +163,7 @@ async def get_task_callback(
164163
request: GetTaskPushNotificationConfigParams,
165164
*,
166165
context: ClientCallContext | None = None,
166+
extensions: list[str] | None = None,
167167
) -> TaskPushNotificationConfig:
168168
"""Retrieves the push notification configuration for a specific task."""
169169

@@ -173,14 +173,18 @@ async def resubscribe(
173173
request: TaskIdParams,
174174
*,
175175
context: ClientCallContext | None = None,
176+
extensions: list[str] | None = None,
176177
) -> AsyncIterator[ClientEvent]:
177178
"""Resubscribes to a task's event stream."""
178179
return
179180
yield
180181

181182
@abstractmethod
182183
async def get_card(
183-
self, *, context: ClientCallContext | None = None
184+
self,
185+
*,
186+
context: ClientCallContext | None = None,
187+
extensions: list[str] | None = None,
184188
) -> AgentCard:
185189
"""Retrieves the agent's card."""
186190

src/a2a/client/client_factory.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,6 @@ def create(
245245
transport,
246246
all_consumers,
247247
interceptors or [],
248-
all_extensions,
249248
)
250249

251250

src/a2a/client/transports/base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ async def send_message(
2525
request: MessageSendParams,
2626
*,
2727
context: ClientCallContext | None = None,
28+
extensions: list[str] | None = None,
2829
) -> Task | Message:
2930
"""Sends a non-streaming message request to the agent."""
3031

@@ -34,6 +35,7 @@ async def send_message_streaming(
3435
request: MessageSendParams,
3536
*,
3637
context: ClientCallContext | None = None,
38+
extensions: list[str] | None = None,
3739
) -> AsyncGenerator[
3840
Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
3941
]:
@@ -47,6 +49,7 @@ async def get_task(
4749
request: TaskQueryParams,
4850
*,
4951
context: ClientCallContext | None = None,
52+
extensions: list[str] | None = None,
5053
) -> Task:
5154
"""Retrieves the current state and history of a specific task."""
5255

@@ -56,6 +59,7 @@ async def cancel_task(
5659
request: TaskIdParams,
5760
*,
5861
context: ClientCallContext | None = None,
62+
extensions: list[str] | None = None,
5963
) -> Task:
6064
"""Requests the agent to cancel a specific task."""
6165

@@ -65,6 +69,7 @@ async def set_task_callback(
6569
request: TaskPushNotificationConfig,
6670
*,
6771
context: ClientCallContext | None = None,
72+
extensions: list[str] | None = None,
6873
) -> TaskPushNotificationConfig:
6974
"""Sets or updates the push notification configuration for a specific task."""
7075

@@ -74,6 +79,7 @@ async def get_task_callback(
7479
request: GetTaskPushNotificationConfigParams,
7580
*,
7681
context: ClientCallContext | None = None,
82+
extensions: list[str] | None = None,
7783
) -> TaskPushNotificationConfig:
7884
"""Retrieves the push notification configuration for a specific task."""
7985

@@ -83,6 +89,7 @@ async def resubscribe(
8389
request: TaskIdParams,
8490
*,
8591
context: ClientCallContext | None = None,
92+
extensions: list[str] | None = None,
8693
) -> AsyncGenerator[
8794
Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
8895
]:
@@ -95,6 +102,7 @@ async def get_card(
95102
self,
96103
*,
97104
context: ClientCallContext | None = None,
105+
extensions: list[str] | None = None,
98106
) -> AgentCard:
99107
"""Retrieves the AgentCard."""
100108

0 commit comments

Comments
 (0)