2020 TaskQueryParams ,
2121)
2222from a2a .utils import proto_utils
23- from a2a .utils .errors import ServerError
23+ from a2a .utils .errors import (
24+ A2AErrorWrapperError ,
25+ ServerError ,
26+ )
2427from a2a .utils .helpers import validate
2528from a2a .utils .telemetry import SpanKind , trace_class
2629
@@ -56,7 +59,7 @@ def __init__(
5659 async def on_message_send (
5760 self ,
5861 request : Request ,
59- context : ServerCallContext | None = None ,
62+ context : ServerCallContext ,
6063 ) -> str :
6164 """Handles the 'message/send' REST method.
6265
@@ -86,7 +89,13 @@ async def on_message_send(
8689 proto_utils .ToProto .task_or_message (task_or_message )
8790 )
8891 except ServerError as e :
89- raise A2AError (error = e .error if e .error else InternalError ()) from e
92+ raise A2AErrorWrapperError (
93+ error = A2AError (
94+ root = e .error
95+ if e .error
96+ else InternalError (message = 'Internal error' )
97+ )
98+ ) from e
9099
91100 @validate (
92101 lambda self : self .agent_card .capabilities .streaming ,
@@ -95,7 +104,7 @@ async def on_message_send(
95104 async def on_message_send_stream (
96105 self ,
97106 request : Request ,
98- context : ServerCallContext | None = None ,
107+ context : ServerCallContext ,
99108 ) -> AsyncIterable [str ]:
100109 """Handles the 'message/stream' REST method.
101110
@@ -125,13 +134,15 @@ async def on_message_send_stream(
125134 response = proto_utils .ToProto .stream_response (event )
126135 yield MessageToJson (response )
127136 except ServerError as e :
128- raise A2AError (error = e .error if e .error else InternalError ()) from e
137+ raise A2AErrorWrapperError (
138+ error = A2AError (root = e .error if e .error else InternalError ())
139+ ) from e
129140 return
130141
131142 async def on_cancel_task (
132143 self ,
133144 request : Request ,
134- context : ServerCallContext | None = None ,
145+ context : ServerCallContext ,
135146 ) -> str :
136147 """Handles the 'tasks/cancel' REST method.
137148
@@ -153,8 +164,10 @@ async def on_cancel_task(
153164 return MessageToJson (proto_utils .ToProto .task (task ))
154165 raise ServerError (error = TaskNotFoundError ())
155166 except ServerError as e :
156- raise A2AError (
157- error = e .error if e .error else InternalError (),
167+ raise A2AErrorWrapperError (
168+ error = A2AError (
169+ root = e .error if e .error else InternalError (),
170+ )
158171 ) from e
159172
160173 @validate (
@@ -164,7 +177,7 @@ async def on_cancel_task(
164177 async def on_resubscribe_to_task (
165178 self ,
166179 request : Request ,
167- context : ServerCallContext | None = None ,
180+ context : ServerCallContext ,
168181 ) -> AsyncIterable [str ]:
169182 """Handles the 'tasks/resubscribe' REST method.
170183
@@ -189,12 +202,14 @@ async def on_resubscribe_to_task(
189202 MessageToJson (proto_utils .ToProto .stream_response (event ))
190203 )
191204 except ServerError as e :
192- raise A2AError (error = e .error if e .error else InternalError ()) from e
205+ raise A2AErrorWrapperError (
206+ error = A2AError (root = e .error if e .error else InternalError ())
207+ ) from e
193208
194209 async def get_push_notification (
195210 self ,
196211 request : Request ,
197- context : ServerCallContext | None = None ,
212+ context : ServerCallContext ,
198213 ) -> str :
199214 """Handles the 'tasks/pushNotificationConfig/get' REST method.
200215
@@ -212,7 +227,7 @@ async def get_push_notification(
212227 push_id = request .path_params ['push_id' ]
213228 if push_id :
214229 params = GetTaskPushNotificationConfigParams (
215- id = task_id , push_id = push_id
230+ id = task_id , push_notification_config_id = push_id
216231 )
217232 else :
218233 params = TaskIdParams (id = task_id )
@@ -225,7 +240,9 @@ async def get_push_notification(
225240 proto_utils .ToProto .task_push_notification_config (config )
226241 )
227242 except ServerError as e :
228- raise A2AError (error = e .error if e .error else InternalError ()) from e
243+ raise A2AErrorWrapperError (
244+ error = A2AError (root = e .error if e .error else InternalError ())
245+ ) from e
229246
230247 @validate (
231248 lambda self : self .agent_card .capabilities .pushNotifications ,
@@ -234,7 +251,7 @@ async def get_push_notification(
234251 async def set_push_notification (
235252 self ,
236253 request : Request ,
237- context : ServerCallContext | None = None ,
254+ context : ServerCallContext ,
238255 ) -> str :
239256 """Handles the 'tasks/pushNotificationConfig/set' REST method.
240257
@@ -255,9 +272,8 @@ async def set_push_notification(
255272 try :
256273 _ = request .path_params ['id' ]
257274 body = await request .body ()
258- params = a2a_pb2 .TaskPushNotificationConfig ()
275+ params = a2a_pb2 .CreateTaskPushNotificationConfigRequest ()
259276 Parse (body , params )
260- params = TaskPushNotificationConfig .model_validate (body )
261277 a2a_request = (
262278 proto_utils .FromProto .task_push_notification_config_request (
263279 params ,
@@ -272,12 +288,14 @@ async def set_push_notification(
272288 proto_utils .ToProto .task_push_notification_config (config )
273289 )
274290 except ServerError as e :
275- raise A2AError (error = e .error if e .error else InternalError ()) from e
291+ raise A2AErrorWrapperError (
292+ error = A2AError (root = e .error if e .error else InternalError ())
293+ ) from e
276294
277295 async def on_get_task (
278296 self ,
279297 request : Request ,
280- context : ServerCallContext | None = None ,
298+ context : ServerCallContext ,
281299 ) -> str :
282300 """Handles the 'v1/tasks/{id}' REST method.
283301
@@ -293,21 +311,24 @@ async def on_get_task(
293311 """
294312 try :
295313 task_id = request .path_params ['id' ]
296- history_length = request .query_params .get ('historyLength' , None )
297- if history_length :
298- history_length = int (history_length )
314+ history_length_str = request .query_params .get ('historyLength' )
315+ history_length = (
316+ int (history_length_str ) if history_length_str else None
317+ )
299318 params = TaskQueryParams (id = task_id , history_length = history_length )
300319 task = await self .request_handler .on_get_task (params , context )
301320 if task :
302321 return MessageToJson (proto_utils .ToProto .task (task ))
303322 raise ServerError (error = TaskNotFoundError ())
304323 except ServerError as e :
305- raise A2AError (error = e .error if e .error else InternalError ()) from e
324+ raise A2AErrorWrapperError (
325+ error = A2AError (root = e .error if e .error else InternalError ())
326+ ) from e
306327
307328 async def list_push_notifications (
308329 self ,
309330 request : Request ,
310- context : ServerCallContext | None = None ,
331+ context : ServerCallContext ,
311332 ) -> list [TaskPushNotificationConfig ]:
312333 """Handles the 'tasks/pushNotificationConfig/list' REST method.
313334
@@ -328,7 +349,7 @@ async def list_push_notifications(
328349 async def list_tasks (
329350 self ,
330351 request : Request ,
331- context : ServerCallContext | None = None ,
352+ context : ServerCallContext ,
332353 ) -> list [Task ]:
333354 """Handles the 'tasks/list' REST method.
334355
0 commit comments