Skip to content

Commit 842e192

Browse files
committed
Create a wrapper exception for A2AError so that it can be raised in rest_handler.py
1 parent 4f2d121 commit 842e192

File tree

2 files changed

+61
-24
lines changed

2 files changed

+61
-24
lines changed

src/a2a/server/request_handlers/rest_handler.py

Lines changed: 45 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
TaskQueryParams,
2121
)
2222
from a2a.utils import proto_utils
23-
from a2a.utils.errors import ServerError
23+
from a2a.utils.errors import (
24+
A2AErrorWrapperError,
25+
ServerError,
26+
)
2427
from a2a.utils.helpers import validate
2528
from a2a.utils.telemetry import SpanKind, trace_class
2629

@@ -56,7 +59,7 @@ def __init__(
5659
async def on_message_send(
5760
self,
5861
request: Request,
59-
context: ServerCallContext | None = None,
62+
context: ServerCallContext,
6063
) -> str:
6164
"""Handles the 'message/send' REST method.
6265
@@ -86,7 +89,13 @@ async def on_message_send(
8689
proto_utils.ToProto.task_or_message(task_or_message)
8790
)
8891
except ServerError as e:
89-
raise A2AError(error=e.error if e.error else InternalError()) from e
92+
raise A2AErrorWrapperError(
93+
error=A2AError(
94+
root=e.error
95+
if e.error
96+
else InternalError(message='Internal error')
97+
)
98+
) from e
9099

91100
@validate(
92101
lambda self: self.agent_card.capabilities.streaming,
@@ -95,7 +104,7 @@ async def on_message_send(
95104
async def on_message_send_stream(
96105
self,
97106
request: Request,
98-
context: ServerCallContext | None = None,
107+
context: ServerCallContext,
99108
) -> AsyncIterable[str]:
100109
"""Handles the 'message/stream' REST method.
101110
@@ -125,13 +134,15 @@ async def on_message_send_stream(
125134
response = proto_utils.ToProto.stream_response(event)
126135
yield MessageToJson(response)
127136
except ServerError as e:
128-
raise A2AError(error=e.error if e.error else InternalError()) from e
137+
raise A2AErrorWrapperError(
138+
error=A2AError(root=e.error if e.error else InternalError())
139+
) from e
129140
return
130141

131142
async def on_cancel_task(
132143
self,
133144
request: Request,
134-
context: ServerCallContext | None = None,
145+
context: ServerCallContext,
135146
) -> str:
136147
"""Handles the 'tasks/cancel' REST method.
137148
@@ -153,8 +164,10 @@ async def on_cancel_task(
153164
return MessageToJson(proto_utils.ToProto.task(task))
154165
raise ServerError(error=TaskNotFoundError())
155166
except ServerError as e:
156-
raise A2AError(
157-
error=e.error if e.error else InternalError(),
167+
raise A2AErrorWrapperError(
168+
error=A2AError(
169+
root=e.error if e.error else InternalError(),
170+
)
158171
) from e
159172

160173
@validate(
@@ -164,7 +177,7 @@ async def on_cancel_task(
164177
async def on_resubscribe_to_task(
165178
self,
166179
request: Request,
167-
context: ServerCallContext | None = None,
180+
context: ServerCallContext,
168181
) -> AsyncIterable[str]:
169182
"""Handles the 'tasks/resubscribe' REST method.
170183
@@ -189,12 +202,14 @@ async def on_resubscribe_to_task(
189202
MessageToJson(proto_utils.ToProto.stream_response(event))
190203
)
191204
except ServerError as e:
192-
raise A2AError(error=e.error if e.error else InternalError()) from e
205+
raise A2AErrorWrapperError(
206+
error=A2AError(root=e.error if e.error else InternalError())
207+
) from e
193208

194209
async def get_push_notification(
195210
self,
196211
request: Request,
197-
context: ServerCallContext | None = None,
212+
context: ServerCallContext,
198213
) -> str:
199214
"""Handles the 'tasks/pushNotificationConfig/get' REST method.
200215
@@ -212,7 +227,7 @@ async def get_push_notification(
212227
push_id = request.path_params['push_id']
213228
if push_id:
214229
params = GetTaskPushNotificationConfigParams(
215-
id=task_id, push_id=push_id
230+
id=task_id, push_notification_config_id=push_id
216231
)
217232
else:
218233
params = TaskIdParams(id=task_id)
@@ -225,7 +240,9 @@ async def get_push_notification(
225240
proto_utils.ToProto.task_push_notification_config(config)
226241
)
227242
except ServerError as e:
228-
raise A2AError(error=e.error if e.error else InternalError()) from e
243+
raise A2AErrorWrapperError(
244+
error=A2AError(root=e.error if e.error else InternalError())
245+
) from e
229246

230247
@validate(
231248
lambda self: self.agent_card.capabilities.pushNotifications,
@@ -234,7 +251,7 @@ async def get_push_notification(
234251
async def set_push_notification(
235252
self,
236253
request: Request,
237-
context: ServerCallContext | None = None,
254+
context: ServerCallContext,
238255
) -> str:
239256
"""Handles the 'tasks/pushNotificationConfig/set' REST method.
240257
@@ -255,9 +272,8 @@ async def set_push_notification(
255272
try:
256273
_ = request.path_params['id']
257274
body = await request.body()
258-
params = a2a_pb2.TaskPushNotificationConfig()
275+
params = a2a_pb2.CreateTaskPushNotificationConfigRequest()
259276
Parse(body, params)
260-
params = TaskPushNotificationConfig.model_validate(body)
261277
a2a_request = (
262278
proto_utils.FromProto.task_push_notification_config_request(
263279
params,
@@ -272,12 +288,14 @@ async def set_push_notification(
272288
proto_utils.ToProto.task_push_notification_config(config)
273289
)
274290
except ServerError as e:
275-
raise A2AError(error=e.error if e.error else InternalError()) from e
291+
raise A2AErrorWrapperError(
292+
error=A2AError(root=e.error if e.error else InternalError())
293+
) from e
276294

277295
async def on_get_task(
278296
self,
279297
request: Request,
280-
context: ServerCallContext | None = None,
298+
context: ServerCallContext,
281299
) -> str:
282300
"""Handles the 'v1/tasks/{id}' REST method.
283301
@@ -293,21 +311,24 @@ async def on_get_task(
293311
"""
294312
try:
295313
task_id = request.path_params['id']
296-
history_length = request.query_params.get('historyLength', None)
297-
if history_length:
298-
history_length = int(history_length)
314+
history_length_str = request.query_params.get('historyLength')
315+
history_length = (
316+
int(history_length_str) if history_length_str else None
317+
)
299318
params = TaskQueryParams(id=task_id, history_length=history_length)
300319
task = await self.request_handler.on_get_task(params, context)
301320
if task:
302321
return MessageToJson(proto_utils.ToProto.task(task))
303322
raise ServerError(error=TaskNotFoundError())
304323
except ServerError as e:
305-
raise A2AError(error=e.error if e.error else InternalError()) from e
324+
raise A2AErrorWrapperError(
325+
error=A2AError(root=e.error if e.error else InternalError())
326+
) from e
306327

307328
async def list_push_notifications(
308329
self,
309330
request: Request,
310-
context: ServerCallContext | None = None,
331+
context: ServerCallContext,
311332
) -> list[TaskPushNotificationConfig]:
312333
"""Handles the 'tasks/pushNotificationConfig/list' REST method.
313334
@@ -328,7 +349,7 @@ async def list_push_notifications(
328349
async def list_tasks(
329350
self,
330351
request: Request,
331-
context: ServerCallContext | None = None,
352+
context: ServerCallContext,
332353
) -> list[Task]:
333354
"""Handles the 'tasks/list' REST method.
334355

src/a2a/utils/errors.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Custom exceptions for A2A server-side errors."""
22

33
from a2a.types import (
4+
A2AError,
45
ContentTypeNotSupportedError,
56
InternalError,
67
InvalidAgentResponseError,
@@ -35,6 +36,21 @@ def __init__(
3536
super().__init__(f'Not Implemented operation Error: {message}')
3637

3738

39+
class A2AErrorWrapperError(Exception):
40+
"""Wrapper exception for A2A that is discriminated union of all standard JSON-RPC and A2A-specific error types."""
41+
42+
def __init__(
43+
self,
44+
error: (A2AError),
45+
):
46+
"""Initialize the A2AError wrapper.
47+
48+
Args:
49+
error: The specific A2AError model instance.
50+
"""
51+
self.error = error
52+
53+
3854
class ServerError(Exception):
3955
"""Wrapper exception for A2A or JSON-RPC errors originating from the server's logic.
4056

0 commit comments

Comments
 (0)