@@ -78,13 +78,9 @@ def _get_http_args(
7878 ) -> dict [str , Any ] | None :
7979 return context .state .get ('http_kwargs' ) if context else None
8080
81- async def send_message (
82- self ,
83- request : MessageSendParams ,
84- * ,
85- context : ClientCallContext | None = None ,
86- ) -> Task | Message :
87- """Sends a non-streaming message request to the agent."""
81+ async def _prepare_send_message (
82+ self , request : MessageSendParams , context : ClientCallContext | None
83+ ) -> tuple [dict [str , Any ], dict [str , Any ]]:
8884 pb = a2a_pb2 .SendMessageRequest (
8985 request = proto_utils .ToProto .message (request .message ),
9086 configuration = proto_utils .ToProto .message_send_configuration (
@@ -102,6 +98,18 @@ async def send_message(
10298 self ._get_http_args (context ),
10399 context ,
104100 )
101+ return payload , modified_kwargs
102+
103+ async def send_message (
104+ self ,
105+ request : MessageSendParams ,
106+ * ,
107+ context : ClientCallContext | None = None ,
108+ ) -> Task | Message :
109+ """Sends a non-streaming message request to the agent."""
110+ payload , modified_kwargs = await self ._prepare_send_message (
111+ request , context
112+ )
105113 response_data = await self ._send_post_request (
106114 '/v1/message:send' , payload , modified_kwargs
107115 )
@@ -118,22 +126,8 @@ async def send_message_streaming(
118126 Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message
119127 ]:
120128 """Sends a streaming message request to the agent and yields responses as they arrive."""
121- pb = a2a_pb2 .SendMessageRequest (
122- request = proto_utils .ToProto .message (request .message ),
123- configuration = proto_utils .ToProto .message_send_configuration (
124- request .configuration
125- ),
126- metadata = (
127- proto_utils .ToProto .metadata (request .metadata )
128- if request .metadata
129- else None
130- ),
131- )
132- payload = MessageToDict (pb )
133- payload , modified_kwargs = await self ._apply_interceptors (
134- payload ,
135- self ._get_http_args (context ),
136- context ,
129+ payload , modified_kwargs = await self ._prepare_send_message (
130+ request , context
137131 )
138132
139133 modified_kwargs .setdefault ('timeout' , None )
@@ -161,18 +155,9 @@ async def send_message_streaming(
161155 503 , f'Network communication error: { e } '
162156 ) from e
163157
164- async def _send_post_request (
165- self ,
166- target : str ,
167- rpc_request_payload : dict [str , Any ],
168- http_kwargs : dict [str , Any ] | None = None ,
169- ) -> dict [str , Any ]:
158+ async def _send_request (self , request : httpx .Request ) -> dict [str , Any ]:
170159 try :
171- response = await self .httpx_client .post (
172- f'{ self .url } { target } ' ,
173- json = rpc_request_payload ,
174- ** (http_kwargs or {}),
175- )
160+ response = await self .httpx_client .send (request )
176161 response .raise_for_status ()
177162 return response .json ()
178163 except httpx .HTTPStatusError as e :
@@ -184,28 +169,35 @@ async def _send_post_request(
184169 503 , f'Network communication error: { e } '
185170 ) from e
186171
172+ async def _send_post_request (
173+ self ,
174+ target : str ,
175+ rpc_request_payload : dict [str , Any ],
176+ http_kwargs : dict [str , Any ] | None = None ,
177+ ) -> dict [str , Any ]:
178+ return await self ._send_request (
179+ self .httpx_client .build_request (
180+ 'POST' ,
181+ f'{ self .url } { target } ' ,
182+ json = rpc_request_payload ,
183+ ** (http_kwargs or {}),
184+ )
185+ )
186+
187187 async def _send_get_request (
188188 self ,
189189 target : str ,
190190 query_params : dict [str , str ],
191191 http_kwargs : dict [str , Any ] | None = None ,
192192 ) -> dict [str , Any ]:
193- try :
194- response = await self .httpx_client .get (
193+ return await self ._send_request (
194+ self .httpx_client .build_request (
195+ 'GET' ,
195196 f'{ self .url } { target } ' ,
196197 params = query_params ,
197198 ** (http_kwargs or {}),
198199 )
199- response .raise_for_status ()
200- return response .json ()
201- except httpx .HTTPStatusError as e :
202- raise A2AClientHTTPError (e .response .status_code , str (e )) from e
203- except json .JSONDecodeError as e :
204- raise A2AClientJSONError (str (e )) from e
205- except httpx .RequestError as e :
206- raise A2AClientHTTPError (
207- 503 , f'Network communication error: { e } '
208- ) from e
200+ )
209201
210202 async def get_task (
211203 self ,
0 commit comments