Skip to content

Commit 1a6bee9

Browse files
committed
Implement resubscribe across transport clients
1 parent d3e1027 commit 1a6bee9

File tree

7 files changed

+217
-95
lines changed

7 files changed

+217
-95
lines changed

src/a2a/client/__init__.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,7 @@
88
InMemoryContextCredentialStore,
99
)
1010
from a2a.client.card_resolver import A2ACardResolver
11-
from a2a.client.client import (
12-
Client,
13-
ClientConfig,
14-
ClientEvent,
15-
Consumer,
16-
)
11+
from a2a.client.client import Client, ClientConfig, ClientEvent, Consumer
1712
from a2a.client.client_factory import (
1813
ClientFactory,
1914
ClientProducer,

src/a2a/client/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ async def resubscribe(
162162
request: TaskIdParams,
163163
*,
164164
context: ClientCallContext | None = None,
165-
) -> AsyncIterator[Task | Message]:
165+
) -> AsyncIterator[ClientEvent]:
166166
"""Resubscribes to a task's event stream."""
167167

168168
@abstractmethod

src/a2a/client/client_task_manager.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import logging
22

3-
from a2a.client.errors import A2AClientInvalidArgsError
3+
from a2a.client.errors import (
4+
A2AClientInvalidArgsError,
5+
A2AClientInvalidStateError,
6+
)
47
from a2a.server.events.event_queue import Event
58
from a2a.types import (
69
Message,
@@ -12,7 +15,6 @@
1215
)
1316
from a2a.utils import append_artifact_to_task
1417

15-
1618
logger = logging.getLogger(__name__)
1719

1820

@@ -45,6 +47,24 @@ def get_task(self) -> Task | None:
4547

4648
return self._current_task
4749

50+
def get_task_or_raise(self) -> Task:
51+
"""Retrieves the current task object.
52+
53+
Returns:
54+
The `Task` object.
55+
56+
Raises:
57+
A2AClientInvalidStateError: If there is no current known Task.
58+
"""
59+
if not (task := self.get_task()):
60+
# Note: The source of this error is either from bad client usage
61+
# or from the server sending invalid updates. It indicates that this
62+
# task manager has not consumed any information about a task, yet
63+
# the caller is attempting to retrieve the current state of the task
64+
# it expects to be present.
65+
raise A2AClientInvalidStateError('no current Task')
66+
return task
67+
4868
async def save_task_event(
4969
self, event: Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
5070
) -> Task | None:

src/a2a/client/errors.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Custom exceptions for the A2A client."""
22

3+
from a2a.types import JSONRPCErrorResponse
4+
35

46
class A2AClientError(Exception):
57
"""Base exception for A2A Client errors."""
@@ -57,3 +59,31 @@ def __init__(self, message: str):
5759
"""
5860
self.message = message
5961
super().__init__(f'Invalid arguments error: {message}')
62+
63+
64+
class A2AClientInvalidStateError(A2AClientError):
65+
"""Client exception for an invalid client state."""
66+
67+
def __init__(self, message: str):
68+
"""Initializes the A2AClientInvalidStateError.
69+
70+
Args:
71+
message: A descriptive error message.
72+
"""
73+
self.message = message
74+
super().__init__(f'Invalid state error: {message}')
75+
76+
77+
class A2AClientJSONRPCError(A2AClientError):
78+
"""Client exception for JSON-RPC errors returned by the server."""
79+
80+
def __init__(self, error: JSONRPCErrorResponse):
81+
"""Initializes the A2AClientJsonRPCError.
82+
83+
Args:
84+
code: The error code.
85+
message: A descriptive error message.
86+
data: Optional additional error data.
87+
"""
88+
self.error = error.error
89+
super().__init__(f'JSON-RPC Error {error.error}')

src/a2a/client/grpc_client.py

Lines changed: 71 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
Consumer,
2222
)
2323
from a2a.client.client_task_manager import ClientTaskManager
24+
from a2a.client.errors import A2AClientInvalidStateError
2425
from a2a.client.middleware import ClientCallInterceptor
2526
from a2a.grpc import a2a_pb2, a2a_pb2_grpc
2627
from a2a.types import (
@@ -104,8 +105,7 @@ async def send_message_streaming(
104105
*,
105106
context: ClientCallContext | None = None,
106107
) -> AsyncGenerator[
107-
Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent,
108-
None,
108+
Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
109109
]:
110110
"""Sends a streaming message request to the agent and yields responses as they arrive.
111111
@@ -134,18 +134,36 @@ async def send_message_streaming(
134134
response = await stream.read()
135135
if response == grpc.aio.EOF: # pyright: ignore [reportAttributeAccessIssue]
136136
break
137-
if response.HasField('msg'):
138-
yield proto_utils.FromProto.message(response.msg)
139-
elif response.HasField('task'):
140-
yield proto_utils.FromProto.task(response.task)
141-
elif response.HasField('status_update'):
142-
yield proto_utils.FromProto.task_status_update_event(
143-
response.status_update
144-
)
145-
elif response.HasField('artifact_update'):
146-
yield proto_utils.FromProto.task_artifact_update_event(
147-
response.artifact_update
148-
)
137+
yield proto_utils.FromProto.stream_response(response)
138+
139+
async def resubscribe(
140+
self, request: TaskIdParams, *, context: ClientCallContext | None = None
141+
) -> AsyncGenerator[
142+
Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
143+
]:
144+
"""Reconnects to get task updates.
145+
146+
This method uses a unary server-side stream to receive updates.
147+
148+
Args:
149+
request: The `TaskIdParams` object containing the task information to reconnect to.
150+
context: The client call context.
151+
152+
Yields:
153+
Task update events, which can be either a Task, Message,
154+
TaskStatusUpdateEvent, or TaskArtifactUpdateEvent.
155+
156+
Raises:
157+
A2AClientInvalidStateError: If the server returns an invalid response.
158+
"""
159+
stream = self.stub.TaskSubscription(
160+
a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}')
161+
)
162+
while True:
163+
response = await stream.read()
164+
if response == grpc.aio.EOF: # pyright: ignore [reportAttributeAccessIssue]
165+
break
166+
yield proto_utils.FromProto.stream_response(response)
149167

150168
async def get_task(
151169
self,
@@ -283,9 +301,7 @@ def __init__(
283301
raise ValueError('GRPC client requires channel factory.')
284302
self._card = card
285303
self._config = config
286-
# Defer init to first use.
287-
self._transport_client = None
288-
channel = self._config.grpc_channel_factory(self._card.url)
304+
channel = config.grpc_channel_factory(self._card.url)
289305
stub = a2a_pb2_grpc.A2AServiceStub(channel)
290306
self._transport_client = GrpcTransportClient(stub, self._card)
291307

@@ -331,27 +347,45 @@ async def send_message(
331347
await self.consume(result, self._card)
332348
yield result
333349
return
334-
# Get Task tracker
335350
tracker = ClientTaskManager()
336-
async for event in self._transport_client.send_message_streaming(
351+
stream = self._transport_client.send_message_streaming(
337352
MessageSendParams(
338353
message=request,
339354
configuration=config,
340355
),
341356
context=context,
342-
):
343-
# Update task, check for errors, etc.
344-
if isinstance(event, Message):
345-
await self.consume(event, self._card)
346-
yield event
347-
return
348-
await tracker.process(event)
349-
result = (
350-
tracker.get_task(),
351-
None if isinstance(event, Task) else event,
357+
)
358+
# Only the first event may be a Message. All others must be Task
359+
# or TaskStatusUpdates. Separate this one out, which allows our core
360+
# event processing logic to ignore that case.
361+
# TODO(mikeas1): Reconcile with other transport logic.
362+
first_event = await anext(stream)
363+
if isinstance(first_event, Message):
364+
yield first_event
365+
return
366+
yield await self._process_response(tracker, first_event)
367+
async for result in stream:
368+
yield await self._process_response(tracker, result)
369+
370+
async def _process_response(
371+
self,
372+
tracker: ClientTaskManager,
373+
event: Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent,
374+
) -> ClientEvent:
375+
result = event.root.result
376+
# Update task, check for errors, etc.
377+
if isinstance(result, Message):
378+
raise A2AClientInvalidStateError(
379+
'received a streamed Message from server after first response; this'
380+
' is not supported'
352381
)
353-
await self.consume(result, self._card)
354-
yield result
382+
await tracker.process(result)
383+
result = (
384+
tracker.get_task_or_raise(),
385+
None if isinstance(result, Task) else result,
386+
)
387+
await self.consume(result, self._card)
388+
return result
355389

356390
async def get_task(
357391
self,
@@ -438,7 +472,7 @@ async def resubscribe(
438472
request: TaskIdParams,
439473
*,
440474
context: ClientCallContext | None = None,
441-
) -> AsyncIterator[Task | Message]:
475+
) -> AsyncIterator[ClientEvent]:
442476
"""Resubscribes to a task's event stream.
443477
444478
This is only available if both the client and server support streaming.
@@ -464,12 +498,14 @@ async def resubscribe(
464498
raise NotImplementedError(
465499
'Resubscribe is not implemented on the gRPC transport client.'
466500
)
467-
async for event in self._transport_client.resubscribe(
501+
# Note: works correctly for resubscription where the first event is the
502+
# current Task state.
503+
tracker = ClientTaskManager()
504+
async for result in self._transport_client.resubscribe(
468505
request,
469506
context=context,
470507
):
471-
# Update task, check for errors, etc.
472-
yield event
508+
yield await self._process_response(tracker, result)
473509

474510
async def get_card(
475511
self,

0 commit comments

Comments
 (0)