Skip to content

Commit c2f4f2f

Browse files
mikeas1holtskinner
andauthored
refactor: Refactor client into BaseClient + ClientTransport (#363)
A refactor of the refactor. This pulls out the common code from across transport-specific clients into one BaseClient, then shoves all the transport-specific code into implementations of `ClientTransport`. --------- Co-authored-by: Holt Skinner <[email protected]>
1 parent a063a8e commit c2f4f2f

25 files changed

+3098
-3175
lines changed

src/a2a/client/__init__.py

Lines changed: 5 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from a2a.client.client import Client, ClientConfig, ClientEvent, Consumer
1212
from a2a.client.client_factory import (
1313
ClientFactory,
14-
ClientProducer,
1514
minimal_agent_card,
1615
)
1716
from a2a.client.errors import (
@@ -21,77 +20,49 @@
2120
A2AClientTimeoutError,
2221
)
2322
from a2a.client.helpers import create_text_message_object
24-
from a2a.client.jsonrpc_client import (
25-
A2AClient,
26-
JsonRpcClient,
27-
JsonRpcTransportClient,
28-
NewJsonRpcClient,
29-
)
23+
from a2a.client.legacy import A2AClient
3024
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
31-
from a2a.client.rest_client import (
32-
NewRestfulClient,
33-
RestClient,
34-
RestTransportClient,
35-
)
3625

3726

3827
logger = logging.getLogger(__name__)
3928

4029
try:
41-
from a2a.client.grpc_client import (
42-
GrpcClient,
43-
GrpcTransportClient, # type: ignore
44-
NewGrpcClient,
45-
)
30+
from a2a.client.legacy_grpc import A2AGrpcClient
4631
except ImportError as e:
4732
_original_error = e
4833
logger.debug(
4934
'A2AGrpcClient not loaded. This is expected if gRPC dependencies are not installed. Error: %s',
5035
_original_error,
5136
)
5237

53-
class GrpcTransportClient: # type: ignore
38+
class A2AGrpcClient: # type: ignore
5439
"""Placeholder for A2AGrpcClient when dependencies are not installed."""
5540

5641
def __init__(self, *args, **kwargs):
5742
raise ImportError(
5843
'To use A2AGrpcClient, its dependencies must be installed. '
5944
'You can install them with \'pip install "a2a-sdk[grpc]"\''
6045
) from _original_error
61-
finally:
62-
# For backward compatability define this alias. This will be deprecated in
63-
# a future release.
64-
A2AGrpcClient = GrpcTransportClient # type: ignore
6546

6647

6748
__all__ = [
6849
'A2ACardResolver',
69-
'A2AClient', # for backward compatability
50+
'A2AClient',
7051
'A2AClientError',
7152
'A2AClientHTTPError',
7253
'A2AClientJSONError',
7354
'A2AClientTimeoutError',
74-
'A2AGrpcClient', # for backward compatability
55+
'A2AGrpcClient',
7556
'AuthInterceptor',
7657
'Client',
7758
'ClientCallContext',
7859
'ClientCallInterceptor',
7960
'ClientConfig',
8061
'ClientEvent',
8162
'ClientFactory',
82-
'ClientProducer',
8363
'Consumer',
8464
'CredentialService',
85-
'GrpcClient',
86-
'GrpcTransportClient',
8765
'InMemoryContextCredentialStore',
88-
'JsonRpcClient',
89-
'JsonRpcTransportClient',
90-
'NewGrpcClient',
91-
'NewJsonRpcClient',
92-
'NewRestfulClient',
93-
'RestClient',
94-
'RestTransportClient',
9566
'create_text_message_object',
9667
'minimal_agent_card',
9768
]

src/a2a/client/base_client.py

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
from collections.abc import AsyncIterator
2+
3+
from a2a.client.client import (
4+
Client,
5+
ClientCallContext,
6+
ClientConfig,
7+
ClientEvent,
8+
Consumer,
9+
)
10+
from a2a.client.client_task_manager import ClientTaskManager
11+
from a2a.client.errors import A2AClientInvalidStateError
12+
from a2a.client.middleware import ClientCallInterceptor
13+
from a2a.client.transports.base import ClientTransport
14+
from a2a.types import (
15+
AgentCard,
16+
GetTaskPushNotificationConfigParams,
17+
Message,
18+
MessageSendConfiguration,
19+
MessageSendParams,
20+
Task,
21+
TaskArtifactUpdateEvent,
22+
TaskIdParams,
23+
TaskPushNotificationConfig,
24+
TaskQueryParams,
25+
TaskStatusUpdateEvent,
26+
)
27+
28+
29+
class BaseClient(Client):
30+
"""Base implementation of the A2A client, containing transport-independent logic."""
31+
32+
def __init__(
33+
self,
34+
card: AgentCard,
35+
config: ClientConfig,
36+
transport: ClientTransport,
37+
consumers: list[Consumer],
38+
middleware: list[ClientCallInterceptor],
39+
):
40+
super().__init__(consumers, middleware)
41+
self._card = card
42+
self._config = config
43+
self._transport = transport
44+
45+
async def send_message(
46+
self,
47+
request: Message,
48+
*,
49+
context: ClientCallContext | None = None,
50+
) -> AsyncIterator[ClientEvent | Message]:
51+
"""Sends a message to the agent.
52+
53+
This method handles both streaming and non-streaming (polling) interactions
54+
based on the client configuration and agent capabilities. It will yield
55+
events as they are received from the agent.
56+
57+
Args:
58+
request: The message to send to the agent.
59+
context: The client call context.
60+
61+
Yields:
62+
An async iterator of `ClientEvent` or a final `Message` response.
63+
"""
64+
config = MessageSendConfiguration(
65+
accepted_output_modes=self._config.accepted_output_modes,
66+
blocking=not self._config.polling,
67+
push_notification_config=(
68+
self._config.push_notification_configs[0]
69+
if self._config.push_notification_configs
70+
else None
71+
),
72+
)
73+
params = MessageSendParams(message=request, configuration=config)
74+
75+
if not self._config.streaming or not self._card.capabilities.streaming:
76+
response = await self._transport.send_message(
77+
params, context=context
78+
)
79+
result = (
80+
(response, None) if isinstance(response, Task) else response
81+
)
82+
await self.consume(result, self._card)
83+
yield result
84+
return
85+
86+
tracker = ClientTaskManager()
87+
stream = self._transport.send_message_streaming(params, context=context)
88+
89+
first_event = await anext(stream)
90+
# The response from a server may be either exactly one Message or a
91+
# series of Task updates. Separate out the first message for special
92+
# case handling, which allows us to simplify further stream processing.
93+
if isinstance(first_event, Message):
94+
await self.consume(first_event, self._card)
95+
yield first_event
96+
return
97+
98+
yield await self._process_response(tracker, first_event)
99+
100+
async for event in stream:
101+
yield await self._process_response(tracker, event)
102+
103+
async def _process_response(
104+
self,
105+
tracker: ClientTaskManager,
106+
event: Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent,
107+
) -> ClientEvent:
108+
if isinstance(event, Message):
109+
raise A2AClientInvalidStateError(
110+
'received a streamed Message from server after first response; this is not supported'
111+
)
112+
await tracker.process(event)
113+
task = tracker.get_task_or_raise()
114+
update = None if isinstance(event, Task) else event
115+
client_event = (task, update)
116+
await self.consume(client_event, self._card)
117+
return client_event
118+
119+
async def get_task(
120+
self,
121+
request: TaskQueryParams,
122+
*,
123+
context: ClientCallContext | None = None,
124+
) -> Task:
125+
"""Retrieves the current state and history of a specific task.
126+
127+
Args:
128+
request: The `TaskQueryParams` object specifying the task ID.
129+
context: The client call context.
130+
131+
Returns:
132+
A `Task` object representing the current state of the task.
133+
"""
134+
return await self._transport.get_task(request, context=context)
135+
136+
async def cancel_task(
137+
self,
138+
request: TaskIdParams,
139+
*,
140+
context: ClientCallContext | None = None,
141+
) -> Task:
142+
"""Requests the agent to cancel a specific task.
143+
144+
Args:
145+
request: The `TaskIdParams` object specifying the task ID.
146+
context: The client call context.
147+
148+
Returns:
149+
A `Task` object containing the updated task status.
150+
"""
151+
return await self._transport.cancel_task(request, context=context)
152+
153+
async def set_task_callback(
154+
self,
155+
request: TaskPushNotificationConfig,
156+
*,
157+
context: ClientCallContext | None = None,
158+
) -> TaskPushNotificationConfig:
159+
"""Sets or updates the push notification configuration for a specific task.
160+
161+
Args:
162+
request: The `TaskPushNotificationConfig` object with the new configuration.
163+
context: The client call context.
164+
165+
Returns:
166+
The created or updated `TaskPushNotificationConfig` object.
167+
"""
168+
return await self._transport.set_task_callback(request, context=context)
169+
170+
async def get_task_callback(
171+
self,
172+
request: GetTaskPushNotificationConfigParams,
173+
*,
174+
context: ClientCallContext | None = None,
175+
) -> TaskPushNotificationConfig:
176+
"""Retrieves the push notification configuration for a specific task.
177+
178+
Args:
179+
request: The `GetTaskPushNotificationConfigParams` object specifying the task.
180+
context: The client call context.
181+
182+
Returns:
183+
A `TaskPushNotificationConfig` object containing the configuration.
184+
"""
185+
return await self._transport.get_task_callback(request, context=context)
186+
187+
async def resubscribe(
188+
self,
189+
request: TaskIdParams,
190+
*,
191+
context: ClientCallContext | None = None,
192+
) -> AsyncIterator[ClientEvent]:
193+
"""Resubscribes to a task's event stream.
194+
195+
This is only available if both the client and server support streaming.
196+
197+
Args:
198+
request: Parameters to identify the task to resubscribe to.
199+
context: The client call context.
200+
201+
Yields:
202+
An async iterator of `ClientEvent` objects.
203+
204+
Raises:
205+
NotImplementedError: If streaming is not supported by the client or server.
206+
"""
207+
if not self._config.streaming or not self._card.capabilities.streaming:
208+
raise NotImplementedError(
209+
'client and/or server do not support resubscription.'
210+
)
211+
212+
tracker = ClientTaskManager()
213+
# Note: resubscribe can only be called on an existing task. As such,
214+
# we should never see Message updates, despite the typing of the service
215+
# definition indicating it may be possible.
216+
async for event in self._transport.resubscribe(
217+
request, context=context
218+
):
219+
yield await self._process_response(tracker, event)
220+
221+
async def get_card(
222+
self, *, context: ClientCallContext | None = None
223+
) -> AgentCard:
224+
"""Retrieves the agent's card.
225+
226+
This will fetch the authenticated card if necessary and update the
227+
client's internal state with the new card.
228+
229+
Args:
230+
context: The client call context.
231+
232+
Returns:
233+
The `AgentCard` for the agent.
234+
"""
235+
card = await self._transport.get_card(context=context)
236+
self._card = card
237+
return card
238+
239+
async def close(self) -> None:
240+
"""Closes the underlying transport."""
241+
await self._transport.close()

src/a2a/client/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ async def send_message(
119119
pairs, or a `Message`. Client will also send these values to any
120120
configured `Consumer`s in the client.
121121
"""
122+
return
123+
yield
122124

123125
@abstractmethod
124126
async def get_task(
@@ -164,6 +166,8 @@ async def resubscribe(
164166
context: ClientCallContext | None = None,
165167
) -> AsyncIterator[ClientEvent]:
166168
"""Resubscribes to a task's event stream."""
169+
return
170+
yield
167171

168172
@abstractmethod
169173
async def get_card(

0 commit comments

Comments
 (0)