11from collections .abc import AsyncIterator
2+ from typing import cast
23
34from a2a .client .client import (
45 Client ,
1213from a2a .client .middleware import ClientCallInterceptor
1314from a2a .client .transports .base import ClientTransport
1415from a2a .extensions .base import Extension
16+ from a2a .extensions .trace import (
17+ AgentInvocation ,
18+ CallTypeEnum ,
19+ StepAction ,
20+ TraceExtension ,
21+ )
1522from a2a .types import (
1623 AgentCard ,
1724 GetTaskPushNotificationConfigParams ,
@@ -68,9 +75,31 @@ async def send_message(
6875 Yields:
6976 An async iterator of `ClientEvent` or a final `Message` response.
7077 """
78+ trace_extension : TraceExtension | None = None
7179 for extension in self ._extensions :
80+ if isinstance (extension , TraceExtension ):
81+ trace_extension = cast (TraceExtension , extension )
7282 extension .on_client_message (request )
7383
84+ step = None
85+ if trace_extension :
86+ trace_id = request .metadata .get ('trace' , {}).get ('trace_id' )
87+ parent_step_id = request .metadata .get ('trace' , {}).get (
88+ 'parent_step_id'
89+ )
90+ step = trace_extension .start_step (
91+ trace_id = trace_id ,
92+ parent_step_id = parent_step_id ,
93+ call_type = CallTypeEnum .AGENT ,
94+ step_action = StepAction (
95+ agent_invocation = AgentInvocation (
96+ agent_url = self ._card .url ,
97+ agent_name = self ._card .name ,
98+ requests = request .model_dump (mode = 'json' ),
99+ )
100+ ),
101+ )
102+
74103 config = MessageSendConfiguration (
75104 accepted_output_modes = self ._config .accepted_output_modes ,
76105 blocking = not self ._config .polling ,
@@ -82,33 +111,44 @@ async def send_message(
82111 )
83112 params = MessageSendParams (message = request , configuration = config )
84113
85- if not self ._config .streaming or not self ._card .capabilities .streaming :
86- response = await self ._transport .send_message (
114+ try :
115+ if (
116+ not self ._config .streaming
117+ or not self ._card .capabilities .streaming
118+ ):
119+ response = await self ._transport .send_message (
120+ params , context = context
121+ )
122+ result = (
123+ (response , None )
124+ if isinstance (response , Task )
125+ else response
126+ )
127+ await self .consume (result , self ._card )
128+ yield result
129+ return
130+
131+ tracker = ClientTaskManager ()
132+ stream = self ._transport .send_message_streaming (
87133 params , context = context
88134 )
89- result = (
90- (response , None ) if isinstance (response , Task ) else response
91- )
92- await self .consume (result , self ._card )
93- yield result
94- return
95135
96- tracker = ClientTaskManager ( )
97- stream = self . _transport . send_message_streaming ( params , context = context )
98-
99- first_event = await anext ( stream )
100- # The response from a server may be either exactly one Message or a
101- # series of Task updates. Separate out the first message for special
102- # case handling, which allows us to simplify further stream processing.
103- if isinstance ( first_event , Message ):
104- await self . consume ( first_event , self . _card )
105- yield first_event
106- return
107-
108- yield await self ._process_response (tracker , first_event )
109-
110- async for event in stream :
111- yield await self . _process_response ( tracker , event )
136+ first_event = await anext ( stream )
137+ # The response from a server may be either exactly one Message or a
138+ # series of Task updates. Separate out the first message for special
139+ # case handling, which allows us to simplify further stream processing.
140+ if isinstance ( first_event , Message ):
141+ await self . consume ( first_event , self . _card )
142+ yield first_event
143+ return
144+
145+ yield await self . _process_response ( tracker , first_event )
146+
147+ async for event in stream :
148+ yield await self ._process_response (tracker , event )
149+ finally :
150+ if trace_extension and step :
151+ trace_extension . end_step ( step . step_id )
112152
113153 async def _process_response (
114154 self ,
0 commit comments