11import functools
2- import json
32import logging
4- import traceback
53
6- from collections .abc import AsyncGenerator , AsyncIterator , Awaitable , Callable
4+ from collections .abc import AsyncIterable , AsyncIterator , Awaitable , Callable
75from typing import Any
86
9- from google .protobuf import message as message_pb2
10- from pydantic import ValidationError
117from sse_starlette .sse import EventSourceResponse
128from starlette .requests import Request
139from starlette .responses import JSONResponse , Response
2218 RESTHandler ,
2319)
2420from a2a .types import (
25- A2AError ,
2621 AgentCard ,
27- InternalError ,
28- InvalidRequestError ,
29- JSONParseError ,
30- UnsupportedOperationError ,
22+ AuthenticatedExtendedCardNotConfiguredError ,
3123)
32- from a2a .utils .errors import MethodNotImplementedError
24+ from a2a .utils .error_handlers import (
25+ rest_error_handler ,
26+ rest_stream_error_handler ,
27+ )
28+ from a2a .utils .errors import ServerError
3329
3430
3531logger = logging .getLogger (__name__ )
@@ -64,92 +60,51 @@ def __init__(
6460 )
6561 self ._context_builder = context_builder or DefaultCallContextBuilder ()
6662
67- def _generate_error_response (self , error : A2AError ) -> JSONResponse :
68- """Creates a JSONResponse for an error.
69-
70- Logs the error based on its type.
71-
72- Args:
73- error: The Error object.
74-
75- Returns:
76- A `JSONResponse` object formatted as a JSON error response.
77- """
78- log_level = (
79- logging .ERROR
80- if isinstance (error , InternalError )
81- else logging .WARNING
82- )
83- logger .log (
84- log_level ,
85- 'Request Error: '
86- f"Code={ error .root .code } , Message='{ error .root .message } '"
87- f'{ ", Data=" + str (error .root .data ) if error .root .data else "" } ' ,
88- )
89- return JSONResponse (
90- f'{{"message": "{ error .root .message } "}}' ,
91- status_code = 500 ,
92- )
93-
94- def _handle_error (self , error : Exception ) -> JSONResponse :
95- traceback .print_exc ()
96- if isinstance (error , MethodNotImplementedError ):
97- return self ._generate_error_response (
98- A2AError (UnsupportedOperationError (message = error .message ))
99- )
100- if isinstance (error , json .decoder .JSONDecodeError ):
101- return self ._generate_error_response (
102- A2AError (JSONParseError (message = str (error )))
103- )
104- if isinstance (error , ValidationError ):
105- return self ._generate_error_response (
106- A2AError (InvalidRequestError (data = json .loads (error .json ()))),
107- )
108- logger .error (f'Unhandled exception: { error } ' )
109- return self ._generate_error_response (
110- A2AError (InternalError (message = str (error )))
111- )
112-
63+ @rest_error_handler
11364 async def _handle_request (
11465 self ,
11566 method : Callable [
11667 [Request , ServerCallContext ], Awaitable [dict [str , Any ]]
11768 ],
11869 request : Request ,
11970 ) -> Response :
120- try :
121- call_context = self ._context_builder .build (request )
122- response = await method (request , call_context )
123- return JSONResponse (content = response )
124- except Exception as e :
125- return self ._handle_error (e )
71+ call_context = self ._context_builder .build (request )
72+ response = await method (request , call_context )
73+ return JSONResponse (content = response )
74+
75+ @rest_error_handler
76+ async def _handle_list_request (
77+ self ,
78+ method : Callable [
79+ [Request , ServerCallContext ], Awaitable [list [dict [str , Any ]]]
80+ ],
81+ request : Request ,
82+ ) -> Response :
83+ call_context = self ._context_builder .build (request )
84+ response = await method (request , call_context )
85+ return JSONResponse (content = response )
12686
87+ @rest_stream_error_handler
12788 async def _handle_streaming_request (
12889 self ,
12990 method : Callable [
130- [Request , ServerCallContext ], AsyncIterator [ message_pb2 . Message ]
91+ [Request , ServerCallContext ], AsyncIterable [ dict [ str , Any ] ]
13192 ],
13293 request : Request ,
13394 ) -> EventSourceResponse :
134- try :
135- call_context = self ._context_builder .build (request )
95+ call_context = self ._context_builder .build (request )
13696
137- async def event_generator (
138- stream : AsyncGenerator [ str ],
139- ) -> AsyncGenerator [dict [str , str ]]:
140- async for item in stream :
141- yield {'data' : item }
97+ async def event_generator (
98+ stream : AsyncIterable [ dict [ str , Any ] ],
99+ ) -> AsyncIterator [dict [str , dict [ str , Any ] ]]:
100+ async for item in stream :
101+ yield {'data' : item }
142102
143- return EventSourceResponse (
144- event_generator (method (request , call_context ))
145- )
146- except Exception as e :
147- # Since the stream has started, we can't return a JSONResponse.
148- # Instead, we runt the error handling logic (provides logging)
149- # and reraise the error and let server framework manage
150- self ._handle_error (e )
151- raise e
103+ return EventSourceResponse (
104+ event_generator (method (request , call_context ))
105+ )
152106
107+ @rest_error_handler
153108 async def _handle_get_agent_card (self , request : Request ) -> JSONResponse :
154109 """Handles GET requests for the agent card endpoint.
155110
@@ -165,6 +120,7 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
165120 self .agent_card .model_dump (mode = 'json' , exclude_none = True )
166121 )
167122
123+ @rest_error_handler
168124 async def handle_authenticated_agent_card (
169125 self , request : Request
170126 ) -> JSONResponse :
@@ -180,9 +136,10 @@ async def handle_authenticated_agent_card(
180136 A JSONResponse containing the authenticated card.
181137 """
182138 if not self .agent_card .supports_authenticated_extended_card :
183- return JSONResponse (
184- '{"detail": "Authenticated card not supported"}' ,
185- status_code = 404 ,
139+ raise ServerError (
140+ error = AuthenticatedExtendedCardNotConfiguredError (
141+ message = 'Authenticated card not supported'
142+ )
186143 )
187144 return JSONResponse (
188145 self .agent_card .model_dump (mode = 'json' , exclude_none = True )
@@ -230,10 +187,10 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
230187 '/v1/tasks/{id}/pushNotificationConfigs' ,
231188 'GET' ,
232189 ): functools .partial (
233- self ._handle_request , self .handler .list_push_notifications
190+ self ._handle_list_request , self .handler .list_push_notifications
234191 ),
235192 ('/v1/tasks' , 'GET' ): functools .partial (
236- self ._handle_request , self .handler .list_tasks
193+ self ._handle_list_request , self .handler .list_tasks
237194 ),
238195 }
239196 if self .agent_card .supports_authenticated_extended_card :
0 commit comments