|
1 | 1 | import json |
2 | 2 | import logging |
| 3 | + |
3 | 4 | from collections.abc import AsyncGenerator |
4 | 5 | from typing import Any |
5 | 6 | from uuid import uuid4 |
6 | 7 |
|
7 | 8 | import httpx |
| 9 | + |
8 | 10 | from httpx_sse import SSEError, aconnect_sse |
9 | 11 | from pydantic import ValidationError |
10 | 12 |
|
11 | 13 | from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError |
12 | 14 | from a2a.client.middleware import ClientCallContext, ClientCallInterceptor |
13 | | -from a2a.types import (AgentCard, CancelTaskRequest, CancelTaskResponse, |
14 | | - GetTaskPushNotificationConfigRequest, |
15 | | - GetTaskPushNotificationConfigResponse, GetTaskRequest, |
16 | | - GetTaskResponse, SendMessageRequest, |
17 | | - SendMessageResponse, SendStreamingMessageRequest, |
18 | | - SendStreamingMessageResponse, |
19 | | - SetTaskPushNotificationConfigRequest, |
20 | | - SetTaskPushNotificationConfigResponse) |
| 15 | +from a2a.types import ( |
| 16 | + AgentCard, |
| 17 | + CancelTaskRequest, |
| 18 | + CancelTaskResponse, |
| 19 | + GetTaskPushNotificationConfigRequest, |
| 20 | + GetTaskPushNotificationConfigResponse, |
| 21 | + GetTaskRequest, |
| 22 | + GetTaskResponse, |
| 23 | + SendMessageRequest, |
| 24 | + SendMessageResponse, |
| 25 | + SendStreamingMessageRequest, |
| 26 | + SendStreamingMessageResponse, |
| 27 | + SetTaskPushNotificationConfigRequest, |
| 28 | + SetTaskPushNotificationConfigResponse, |
| 29 | +) |
21 | 30 | from a2a.utils.telemetry import SpanKind, trace_class |
22 | 31 |
|
| 32 | + |
23 | 33 | logger = logging.getLogger(__name__) |
24 | 34 |
|
25 | 35 |
|
@@ -157,15 +167,18 @@ async def _apply_interceptors( |
157 | 167 | final_request_payload = request_payload |
158 | 168 |
|
159 | 169 | for interceptor in self.interceptors: |
160 | | - final_request_payload, final_http_kwargs = await interceptor.intercept( |
| 170 | + ( |
| 171 | + final_request_payload, |
| 172 | + final_http_kwargs, |
| 173 | + ) = await interceptor.intercept( |
161 | 174 | method_name, |
162 | 175 | final_request_payload, |
163 | 176 | final_http_kwargs, |
164 | 177 | self.agent_card, |
165 | 178 | context, |
166 | 179 | ) |
167 | 180 | return final_request_payload, final_http_kwargs |
168 | | - |
| 181 | + |
169 | 182 | @staticmethod |
170 | 183 | async def get_client_from_agent_card_url( |
171 | 184 | httpx_client: httpx.AsyncClient, |
@@ -224,10 +237,13 @@ async def send_message( |
224 | 237 | """ |
225 | 238 | if not request.id: |
226 | 239 | request.id = str(uuid4()) |
227 | | - |
| 240 | + |
228 | 241 | # Apply interceptors before sending |
229 | 242 | payload, modified_kwargs = await self._apply_interceptors( |
230 | | - 'message/send', request.model_dump(mode='json', exclude_none=True), http_kwargs, context |
| 243 | + 'message/send', |
| 244 | + request.model_dump(mode='json', exclude_none=True), |
| 245 | + http_kwargs, |
| 246 | + context, |
231 | 247 | ) |
232 | 248 | response_data = await self._send_request(payload, modified_kwargs) |
233 | 249 | return SendMessageResponse(response_data) |
@@ -261,9 +277,12 @@ async def send_message_streaming( |
261 | 277 |
|
262 | 278 | # Apply interceptors before sending |
263 | 279 | payload, modified_kwargs = await self._apply_interceptors( |
264 | | - 'message/stream', request.model_dump(mode='json', exclude_none=True), http_kwargs, context |
| 280 | + 'message/stream', |
| 281 | + request.model_dump(mode='json', exclude_none=True), |
| 282 | + http_kwargs, |
| 283 | + context, |
265 | 284 | ) |
266 | | - |
| 285 | + |
267 | 286 | modified_kwargs.setdefault('timeout', None) |
268 | 287 |
|
269 | 288 | async with aconnect_sse( |
@@ -345,10 +364,13 @@ async def get_task( |
345 | 364 | """ |
346 | 365 | if not request.id: |
347 | 366 | request.id = str(uuid4()) |
348 | | - |
| 367 | + |
349 | 368 | # Apply interceptors before sending |
350 | 369 | payload, modified_kwargs = await self._apply_interceptors( |
351 | | - 'tasks/get', request.model_dump(mode='json', exclude_none=True), http_kwargs, context |
| 370 | + 'tasks/get', |
| 371 | + request.model_dump(mode='json', exclude_none=True), |
| 372 | + http_kwargs, |
| 373 | + context, |
352 | 374 | ) |
353 | 375 | response_data = await self._send_request(payload, modified_kwargs) |
354 | 376 | return GetTaskResponse(response_data) |
@@ -386,7 +408,10 @@ async def cancel_task( |
386 | 408 |
|
387 | 409 | # Apply interceptors before sending |
388 | 410 | payload, modified_kwargs = await self._apply_interceptors( |
389 | | - 'tasks/cancel', request.model_dump(mode='json', exclude_none=True), http_kwargs, context |
| 411 | + 'tasks/cancel', |
| 412 | + request.model_dump(mode='json', exclude_none=True), |
| 413 | + http_kwargs, |
| 414 | + context, |
390 | 415 | ) |
391 | 416 | response_data = await self._send_request(payload, modified_kwargs) |
392 | 417 | return CancelTaskResponse(response_data) |
@@ -417,7 +442,10 @@ async def set_task_callback( |
417 | 442 |
|
418 | 443 | # Apply interceptors before sending |
419 | 444 | payload, modified_kwargs = await self._apply_interceptors( |
420 | | - 'tasks/pushNotificationConfig/set', request.model_dump(mode='json', exclude_none=True), http_kwargs, context |
| 445 | + 'tasks/pushNotificationConfig/set', |
| 446 | + request.model_dump(mode='json', exclude_none=True), |
| 447 | + http_kwargs, |
| 448 | + context, |
421 | 449 | ) |
422 | 450 | response_data = await self._send_request(payload, modified_kwargs) |
423 | 451 | return SetTaskPushNotificationConfigResponse(response_data) |
@@ -448,7 +476,10 @@ async def get_task_callback( |
448 | 476 |
|
449 | 477 | # Apply interceptors before sending |
450 | 478 | payload, modified_kwargs = await self._apply_interceptors( |
451 | | - 'tasks/pushNotificationConfig/get', request.model_dump(mode='json', exclude_none=True), http_kwargs, context |
| 479 | + 'tasks/pushNotificationConfig/get', |
| 480 | + request.model_dump(mode='json', exclude_none=True), |
| 481 | + http_kwargs, |
| 482 | + context, |
452 | 483 | ) |
453 | 484 | response_data = await self._send_request(payload, modified_kwargs) |
454 | 485 | return GetTaskPushNotificationConfigResponse(response_data) |
0 commit comments