Skip to content

Commit a063a8e

Browse files
committed
Centralize error handling and status code mapping for rest interface
1 parent 6f592bb commit a063a8e

File tree

5 files changed

+255
-263
lines changed

5 files changed

+255
-263
lines changed

error_handlers.py

Whitespace-only changes.

src/a2a/server/apps/rest/rest_app.py

Lines changed: 42 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
import functools
2-
import json
32
import logging
4-
import traceback
53

6-
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
4+
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
75
from typing import Any
86

9-
from google.protobuf import message as message_pb2
10-
from pydantic import ValidationError
117
from sse_starlette.sse import EventSourceResponse
128
from starlette.requests import Request
139
from starlette.responses import JSONResponse, Response
@@ -22,14 +18,14 @@
2218
RESTHandler,
2319
)
2420
from a2a.types import (
25-
A2AError,
2621
AgentCard,
27-
InternalError,
28-
InvalidRequestError,
29-
JSONParseError,
30-
UnsupportedOperationError,
22+
AuthenticatedExtendedCardNotConfiguredError,
3123
)
32-
from a2a.utils.errors import MethodNotImplementedError
24+
from a2a.utils.error_handlers import (
25+
rest_error_handler,
26+
rest_stream_error_handler,
27+
)
28+
from a2a.utils.errors import ServerError
3329

3430

3531
logger = logging.getLogger(__name__)
@@ -64,92 +60,51 @@ def __init__(
6460
)
6561
self._context_builder = context_builder or DefaultCallContextBuilder()
6662

67-
def _generate_error_response(self, error: A2AError) -> JSONResponse:
68-
"""Creates a JSONResponse for an error.
69-
70-
Logs the error based on its type.
71-
72-
Args:
73-
error: The Error object.
74-
75-
Returns:
76-
A `JSONResponse` object formatted as a JSON error response.
77-
"""
78-
log_level = (
79-
logging.ERROR
80-
if isinstance(error, InternalError)
81-
else logging.WARNING
82-
)
83-
logger.log(
84-
log_level,
85-
'Request Error: '
86-
f"Code={error.root.code}, Message='{error.root.message}'"
87-
f'{", Data=" + str(error.root.data) if error.root.data else ""}',
88-
)
89-
return JSONResponse(
90-
f'{{"message": "{error.root.message}"}}',
91-
status_code=500,
92-
)
93-
94-
def _handle_error(self, error: Exception) -> JSONResponse:
95-
traceback.print_exc()
96-
if isinstance(error, MethodNotImplementedError):
97-
return self._generate_error_response(
98-
A2AError(UnsupportedOperationError(message=error.message))
99-
)
100-
if isinstance(error, json.decoder.JSONDecodeError):
101-
return self._generate_error_response(
102-
A2AError(JSONParseError(message=str(error)))
103-
)
104-
if isinstance(error, ValidationError):
105-
return self._generate_error_response(
106-
A2AError(InvalidRequestError(data=json.loads(error.json()))),
107-
)
108-
logger.error(f'Unhandled exception: {error}')
109-
return self._generate_error_response(
110-
A2AError(InternalError(message=str(error)))
111-
)
112-
63+
@rest_error_handler
11364
async def _handle_request(
11465
self,
11566
method: Callable[
11667
[Request, ServerCallContext], Awaitable[dict[str, Any]]
11768
],
11869
request: Request,
11970
) -> Response:
120-
try:
121-
call_context = self._context_builder.build(request)
122-
response = await method(request, call_context)
123-
return JSONResponse(content=response)
124-
except Exception as e:
125-
return self._handle_error(e)
71+
call_context = self._context_builder.build(request)
72+
response = await method(request, call_context)
73+
return JSONResponse(content=response)
74+
75+
@rest_error_handler
76+
async def _handle_list_request(
77+
self,
78+
method: Callable[
79+
[Request, ServerCallContext], Awaitable[list[dict[str, Any]]]
80+
],
81+
request: Request,
82+
) -> Response:
83+
call_context = self._context_builder.build(request)
84+
response = await method(request, call_context)
85+
return JSONResponse(content=response)
12686

87+
@rest_stream_error_handler
12788
async def _handle_streaming_request(
12889
self,
12990
method: Callable[
130-
[Request, ServerCallContext], AsyncIterator[message_pb2.Message]
91+
[Request, ServerCallContext], AsyncIterable[dict[str, Any]]
13192
],
13293
request: Request,
13394
) -> EventSourceResponse:
134-
try:
135-
call_context = self._context_builder.build(request)
95+
call_context = self._context_builder.build(request)
13696

137-
async def event_generator(
138-
stream: AsyncGenerator[str],
139-
) -> AsyncGenerator[dict[str, str]]:
140-
async for item in stream:
141-
yield {'data': item}
97+
async def event_generator(
98+
stream: AsyncIterable[dict[str, Any]],
99+
) -> AsyncIterator[dict[str, dict[str, Any]]]:
100+
async for item in stream:
101+
yield {'data': item}
142102

143-
return EventSourceResponse(
144-
event_generator(method(request, call_context))
145-
)
146-
except Exception as e:
147-
# Since the stream has started, we can't return a JSONResponse.
148-
# Instead, we runt the error handling logic (provides logging)
149-
# and reraise the error and let server framework manage
150-
self._handle_error(e)
151-
raise e
103+
return EventSourceResponse(
104+
event_generator(method(request, call_context))
105+
)
152106

107+
@rest_error_handler
153108
async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
154109
"""Handles GET requests for the agent card endpoint.
155110
@@ -165,6 +120,7 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
165120
self.agent_card.model_dump(mode='json', exclude_none=True)
166121
)
167122

123+
@rest_error_handler
168124
async def handle_authenticated_agent_card(
169125
self, request: Request
170126
) -> JSONResponse:
@@ -180,9 +136,10 @@ async def handle_authenticated_agent_card(
180136
A JSONResponse containing the authenticated card.
181137
"""
182138
if not self.agent_card.supports_authenticated_extended_card:
183-
return JSONResponse(
184-
'{"detail": "Authenticated card not supported"}',
185-
status_code=404,
139+
raise ServerError(
140+
error=AuthenticatedExtendedCardNotConfiguredError(
141+
message='Authenticated card not supported'
142+
)
186143
)
187144
return JSONResponse(
188145
self.agent_card.model_dump(mode='json', exclude_none=True)
@@ -230,10 +187,10 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
230187
'/v1/tasks/{id}/pushNotificationConfigs',
231188
'GET',
232189
): functools.partial(
233-
self._handle_request, self.handler.list_push_notifications
190+
self._handle_list_request, self.handler.list_push_notifications
234191
),
235192
('/v1/tasks', 'GET'): functools.partial(
236-
self._handle_request, self.handler.list_tasks
193+
self._handle_list_request, self.handler.list_tasks
237194
),
238195
}
239196
if self.agent_card.supports_authenticated_extended_card:

0 commit comments

Comments
 (0)