Skip to content

Commit ebb394c

Browse files
Merge branch '1.0-dev' into guglielmoc/remove_rest_handler
2 parents 65d8e84 + 4586c3e commit ebb394c

File tree

12 files changed

+432
-37
lines changed

12 files changed

+432
-37
lines changed

src/a2a/client/transports/grpc.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@
4747
TaskPushNotificationConfig,
4848
)
4949
from a2a.utils.constants import PROTOCOL_VERSION_CURRENT, VERSION_HEADER
50-
from a2a.utils.errors import A2A_REASON_TO_ERROR
50+
from a2a.utils.errors import A2A_REASON_TO_ERROR, A2AError
51+
from a2a.utils.proto_utils import bad_request_to_validation_errors
5152
from a2a.utils.telemetry import SpanKind, trace_class
5253

5354

@@ -61,17 +62,23 @@ def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn:
6162

6263
# Use grpc_status to cleanly extract the rich Status from the call
6364
status = rpc_status.from_call(cast('grpc.Call', e))
65+
data = None
6466

6567
if status is not None:
68+
exception_cls: type[A2AError] | None = None
6669
for detail in status.details:
6770
if detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR):
6871
error_info = error_details_pb2.ErrorInfo()
6972
detail.Unpack(error_info)
70-
7173
if error_info.domain == 'a2a-protocol.org':
7274
exception_cls = A2A_REASON_TO_ERROR.get(error_info.reason)
73-
if exception_cls:
74-
raise exception_cls(status.message) from e
75+
elif detail.Is(error_details_pb2.BadRequest.DESCRIPTOR):
76+
bad_request = error_details_pb2.BadRequest()
77+
detail.Unpack(bad_request)
78+
data = {'errors': bad_request_to_validation_errors(bad_request)}
79+
80+
if exception_cls:
81+
raise exception_cls(status.message, data=data) from e
7582

7683
raise A2AClientError(f'gRPC Error {e.code().name}: {e.details()}') from e
7784

src/a2a/client/transports/jsonrpc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,9 +318,10 @@ def _create_jsonrpc_error(self, error_dict: dict[str, Any]) -> Exception:
318318
"""Creates the appropriate A2AError from a JSON-RPC error dictionary."""
319319
code = error_dict.get('code')
320320
message = error_dict.get('message', str(error_dict))
321+
data = error_dict.get('data')
321322

322323
if isinstance(code, int) and code in _JSON_RPC_ERROR_CODE_TO_A2A_ERROR:
323-
return _JSON_RPC_ERROR_CODE_TO_A2A_ERROR[code](message)
324+
return _JSON_RPC_ERROR_CODE_TO_A2A_ERROR[code](message, data=data)
324325

325326
# Fallback to general A2AClientError
326327
return A2AClientError(f'JSON-RPC Error {code}: {message}')

src/a2a/server/request_handlers/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
from a2a.server.request_handlers.default_request_handler import (
66
DefaultRequestHandler,
77
)
8-
from a2a.server.request_handlers.request_handler import RequestHandler
8+
from a2a.server.request_handlers.request_handler import (
9+
RequestHandler,
10+
validate_request_params,
11+
)
912
from a2a.server.request_handlers.response_helpers import (
1013
build_error_response,
1114
prepare_response_object,
@@ -42,4 +45,5 @@ def __init__(self, *args, **kwargs):
4245
'RequestHandler',
4346
'build_error_response',
4447
'prepare_response_object',
48+
'validate_request_params',
4549
]

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
InMemoryQueueManager,
1919
QueueManager,
2020
)
21-
from a2a.server.request_handlers.request_handler import RequestHandler
21+
from a2a.server.request_handlers.request_handler import (
22+
RequestHandler,
23+
validate_request_params,
24+
)
2225
from a2a.server.tasks import (
2326
PushNotificationConfigStore,
2427
PushNotificationEvent,
@@ -118,6 +121,7 @@ def __init__( # noqa: PLR0913
118121
# asyncio tasks and to surface unexpected exceptions.
119122
self._background_tasks = set()
120123

124+
@validate_request_params
121125
async def on_get_task(
122126
self,
123127
params: GetTaskRequest,
@@ -133,6 +137,7 @@ async def on_get_task(
133137

134138
return apply_history_length(task, params)
135139

140+
@validate_request_params
136141
async def on_list_tasks(
137142
self,
138143
params: ListTasksRequest,
@@ -154,6 +159,7 @@ async def on_list_tasks(
154159

155160
return page
156161

162+
@validate_request_params
157163
async def on_cancel_task(
158164
self,
159165
params: CancelTaskRequest,
@@ -317,6 +323,7 @@ async def _send_push_notification_if_needed(
317323
):
318324
await self._push_sender.send_notification(task_id, event)
319325

326+
@validate_request_params
320327
async def on_message_send(
321328
self,
322329
params: SendMessageRequest,
@@ -386,6 +393,7 @@ async def push_notification_callback(event: Event) -> None:
386393

387394
return result
388395

396+
@validate_request_params
389397
async def on_message_send_stream(
390398
self,
391399
params: SendMessageRequest,
@@ -474,6 +482,7 @@ async def _cleanup_producer(
474482
async with self._running_agents_lock:
475483
self._running_agents.pop(task_id, None)
476484

485+
@validate_request_params
477486
async def on_create_task_push_notification_config(
478487
self,
479488
params: TaskPushNotificationConfig,
@@ -499,6 +508,7 @@ async def on_create_task_push_notification_config(
499508

500509
return params
501510

511+
@validate_request_params
502512
async def on_get_task_push_notification_config(
503513
self,
504514
params: GetTaskPushNotificationConfigRequest,
@@ -530,6 +540,7 @@ async def on_get_task_push_notification_config(
530540

531541
raise InternalError(message='Push notification config not found')
532542

543+
@validate_request_params
533544
async def on_subscribe_to_task(
534545
self,
535546
params: SubscribeToTaskRequest,
@@ -572,6 +583,7 @@ async def on_subscribe_to_task(
572583
async for event in result_aggregator.consume_and_emit(consumer):
573584
yield event
574585

586+
@validate_request_params
575587
async def on_list_task_push_notification_configs(
576588
self,
577589
params: ListTaskPushNotificationConfigsRequest,
@@ -597,6 +609,7 @@ async def on_list_task_push_notification_configs(
597609
configs=push_notification_config_list
598610
)
599611

612+
@validate_request_params
600613
async def on_delete_task_push_notification_config(
601614
self,
602615
params: DeleteTaskPushNotificationConfigRequest,

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,9 @@
3535
from a2a.types import a2a_pb2
3636
from a2a.types.a2a_pb2 import AgentCard
3737
from a2a.utils import proto_utils
38-
from a2a.utils.errors import (
39-
A2A_ERROR_REASONS,
40-
A2AError,
41-
TaskNotFoundError,
42-
)
38+
from a2a.utils.errors import A2A_ERROR_REASONS, A2AError, TaskNotFoundError
4339
from a2a.utils.helpers import maybe_await, validate
40+
from a2a.utils.proto_utils import validation_errors_to_bad_request
4441

4542

4643
logger = logging.getLogger(__name__)
@@ -403,11 +400,23 @@ async def abort_context(
403400
error.message if hasattr(error, 'message') else str(error)
404401
)
405402

406-
# Create standard Status and pack the ErrorInfo
403+
# Create standard Status with ErrorInfo for all A2A errors
407404
status = status_pb2.Status(code=status_code, message=error_msg)
408-
detail = any_pb2.Any()
409-
detail.Pack(error_info)
410-
status.details.append(detail)
405+
error_info_detail = any_pb2.Any()
406+
error_info_detail.Pack(error_info)
407+
status.details.append(error_info_detail)
408+
409+
# Append structured field violations for validation errors
410+
if (
411+
isinstance(error, types.InvalidParamsError)
412+
and error.data
413+
and error.data.get('errors')
414+
):
415+
bad_request_detail = any_pb2.Any()
416+
bad_request_detail.Pack(
417+
validation_errors_to_bad_request(error.data['errors'])
418+
)
419+
status.details.append(bad_request_detail)
411420

412421
# Use grpc_status to safely generate standard trailing metadata
413422
rich_status = rpc_status.to_status(status)

src/a2a/server/request_handlers/request_handler.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
1+
import functools
2+
import inspect
3+
14
from abc import ABC, abstractmethod
2-
from collections.abc import AsyncGenerator
5+
from collections.abc import AsyncGenerator, Callable
6+
from typing import Any
7+
8+
from google.protobuf.message import Message as ProtoMessage
39

410
from a2a.server.context import ServerCallContext
511
from a2a.server.events.event_queue import Event
@@ -19,6 +25,7 @@
1925
TaskPushNotificationConfig,
2026
)
2127
from a2a.utils.errors import UnsupportedOperationError
28+
from a2a.utils.proto_utils import validate_proto_required_fields
2229

2330

2431
class RequestHandler(ABC):
@@ -218,3 +225,46 @@ async def on_delete_task_push_notification_config(
218225
Returns:
219226
None
220227
"""
228+
229+
230+
def validate_request_params(method: Callable) -> Callable:
231+
"""Decorator for RequestHandler methods to validate required fields on incoming requests."""
232+
if inspect.isasyncgenfunction(method):
233+
234+
@functools.wraps(method)
235+
async def async_gen_wrapper(
236+
self: RequestHandler,
237+
params: ProtoMessage,
238+
context: ServerCallContext,
239+
*args: Any,
240+
**kwargs: Any,
241+
) -> Any:
242+
if params is not None:
243+
validate_proto_required_fields(params)
244+
# Ensure the inner async generator is closed explicitly;
245+
# bare async-for does not call aclose() on GeneratorExit,
246+
# which on Python 3.12+ prevents the except/finally blocks
247+
# in on_message_send_stream from running on client disconnect
248+
# (background_consume and cleanup_producer tasks are never created).
249+
inner = method(self, params, context, *args, **kwargs)
250+
try:
251+
async for item in inner:
252+
yield item
253+
finally:
254+
await inner.aclose()
255+
256+
return async_gen_wrapper
257+
258+
@functools.wraps(method)
259+
async def async_wrapper(
260+
self: RequestHandler,
261+
params: ProtoMessage,
262+
context: ServerCallContext,
263+
*args: Any,
264+
**kwargs: Any,
265+
) -> Any:
266+
if params is not None:
267+
validate_proto_required_fields(params)
268+
return await method(self, params, context, *args, **kwargs)
269+
270+
return async_wrapper

src/a2a/server/request_handlers/response_helpers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def build_error_response(
135135
jsonrpc_error = model_class(
136136
code=code,
137137
message=str(error),
138+
data=error.data,
138139
)
139140
else:
140141
jsonrpc_error = JSONRPCInternalError(message=str(error))

src/a2a/utils/errors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ class A2AError(Exception):
2121
message: str = 'A2A Error'
2222
data: dict | None = None
2323

24-
def __init__(self, message: str | None = None):
24+
def __init__(self, message: str | None = None, data: dict | None = None):
2525
if message:
2626
self.message = message
27+
self.data = data
2728
super().__init__(self.message)
2829

2930

0 commit comments

Comments
 (0)