2121 Consumer ,
2222)
2323from a2a .client .client_task_manager import ClientTaskManager
24+ from a2a .client .errors import A2AClientInvalidStateError
2425from a2a .client .middleware import ClientCallInterceptor
2526from a2a .grpc import a2a_pb2 , a2a_pb2_grpc
2627from a2a .types import (
@@ -104,8 +105,7 @@ async def send_message_streaming(
104105 * ,
105106 context : ClientCallContext | None = None ,
106107 ) -> AsyncGenerator [
107- Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ,
108- None ,
108+ Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
109109 ]:
110110 """Sends a streaming message request to the agent and yields responses as they arrive.
111111
@@ -134,18 +134,36 @@ async def send_message_streaming(
134134 response = await stream .read ()
135135 if response == grpc .aio .EOF : # pyright: ignore [reportAttributeAccessIssue]
136136 break
137- if response .HasField ('msg' ):
138- yield proto_utils .FromProto .message (response .msg )
139- elif response .HasField ('task' ):
140- yield proto_utils .FromProto .task (response .task )
141- elif response .HasField ('status_update' ):
142- yield proto_utils .FromProto .task_status_update_event (
143- response .status_update
144- )
145- elif response .HasField ('artifact_update' ):
146- yield proto_utils .FromProto .task_artifact_update_event (
147- response .artifact_update
148- )
137+ yield proto_utils .FromProto .stream_response (response )
138+
139+ async def resubscribe (
140+ self , request : TaskIdParams , * , context : ClientCallContext | None = None
141+ ) -> AsyncGenerator [
142+ Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
143+ ]:
144+ """Reconnects to get task updates.
145+
146+ This method uses a unary server-side stream to receive updates.
147+
148+ Args:
149+ request: The `TaskIdParams` object containing the task information to reconnect to.
150+ context: The client call context.
151+
152+ Yields:
153+ Task update events, which can be either a Task, Message,
154+ TaskStatusUpdateEvent, or TaskArtifactUpdateEvent.
155+
156+ Raises:
157+ A2AClientInvalidStateError: If the server returns an invalid response.
158+ """
159+ stream = self .stub .TaskSubscription (
160+ a2a_pb2 .TaskSubscriptionRequest (name = f'tasks/{ request .id } ' )
161+ )
162+ while True :
163+ response = await stream .read ()
164+ if response == grpc .aio .EOF : # pyright: ignore [reportAttributeAccessIssue]
165+ break
166+ yield proto_utils .FromProto .stream_response (response )
149167
150168 async def get_task (
151169 self ,
@@ -283,9 +301,7 @@ def __init__(
283301 raise ValueError ('GRPC client requires channel factory.' )
284302 self ._card = card
285303 self ._config = config
286- # Defer init to first use.
287- self ._transport_client = None
288- channel = self ._config .grpc_channel_factory (self ._card .url )
304+ channel = config .grpc_channel_factory (self ._card .url )
289305 stub = a2a_pb2_grpc .A2AServiceStub (channel )
290306 self ._transport_client = GrpcTransportClient (stub , self ._card )
291307
@@ -331,27 +347,45 @@ async def send_message(
331347 await self .consume (result , self ._card )
332348 yield result
333349 return
334- # Get Task tracker
335350 tracker = ClientTaskManager ()
336- async for event in self ._transport_client .send_message_streaming (
351+ stream = self ._transport_client .send_message_streaming (
337352 MessageSendParams (
338353 message = request ,
339354 configuration = config ,
340355 ),
341356 context = context ,
342- ):
343- # Update task, check for errors, etc.
344- if isinstance (event , Message ):
345- await self .consume (event , self ._card )
346- yield event
347- return
348- await tracker .process (event )
349- result = (
350- tracker .get_task (),
351- None if isinstance (event , Task ) else event ,
357+ )
358+ # Only the first event may be a Message. All others must be Task
359+ # or TaskStatusUpdates. Separate this one out, which allows our core
360+ # event processing logic to ignore that case.
361+ # TODO(mikeas1): Reconcile with other transport logic.
362+ first_event = await anext (stream )
363+ if isinstance (first_event , Message ):
364+ yield first_event
365+ return
366+ yield await self ._process_response (tracker , first_event )
367+ async for result in stream :
368+ yield await self ._process_response (tracker , result )
369+
370+ async def _process_response (
371+ self ,
372+ tracker : ClientTaskManager ,
373+ event : Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ,
374+ ) -> ClientEvent :
375+ result = event .root .result
376+ # Update task, check for errors, etc.
377+ if isinstance (result , Message ):
378+ raise A2AClientInvalidStateError (
379+ 'received a streamed Message from server after first response; this'
380+ ' is not supported'
352381 )
353- await self .consume (result , self ._card )
354- yield result
382+ await tracker .process (result )
383+ result = (
384+ tracker .get_task_or_raise (),
385+ None if isinstance (result , Task ) else result ,
386+ )
387+ await self .consume (result , self ._card )
388+ return result
355389
356390 async def get_task (
357391 self ,
@@ -438,7 +472,7 @@ async def resubscribe(
438472 request : TaskIdParams ,
439473 * ,
440474 context : ClientCallContext | None = None ,
441- ) -> AsyncIterator [Task | Message ]:
475+ ) -> AsyncIterator [ClientEvent ]:
442476 """Resubscribes to a task's event stream.
443477
444478 This is only available if both the client and server support streaming.
@@ -464,12 +498,14 @@ async def resubscribe(
464498 raise NotImplementedError (
465499 'Resubscribe is not implemented on the gRPC transport client.'
466500 )
467- async for event in self ._transport_client .resubscribe (
501+ # Note: works correctly for resubscription where the first event is the
502+ # current Task state.
503+ tracker = ClientTaskManager ()
504+ async for result in self ._transport_client .resubscribe (
468505 request ,
469506 context = context ,
470507 ):
471- # Update task, check for errors, etc.
472- yield event
508+ yield await self ._process_response (tracker , result )
473509
474510 async def get_card (
475511 self ,
0 commit comments