2929 GetTaskPushNotificationConfigRequest ,
3030 GetTaskRequest ,
3131 InternalError ,
32+ InvalidParamsError ,
3233 InvalidRequestError ,
3334 JSONParseError ,
3435 JSONRPCError ,
3536 JSONRPCErrorResponse ,
3637 JSONRPCRequest ,
3738 JSONRPCResponse ,
3839 ListTaskPushNotificationConfigRequest ,
40+ MethodNotFoundError ,
3941 SendMessageRequest ,
4042 SendStreamingMessageRequest ,
4143 SendStreamingMessageResponse ,
9193 URL = Any
9294 HTTP_413_REQUEST_ENTITY_TOO_LARGE = Any
9395
96+ MAX_CONTENT_LENGTH = 1_000_000
97+
9498
9599class StarletteUserProxy (A2AUser ):
96100 """Adapts the Starlette User class to the A2A user representation."""
@@ -153,6 +157,25 @@ class JSONRPCApplication(ABC):
153157 (SSE).
154158 """
155159
160+ # Method-to-model mapping for centralized routing
161+ A2ARequestModel = (
162+ SendMessageRequest
163+ | SendStreamingMessageRequest
164+ | GetTaskRequest
165+ | CancelTaskRequest
166+ | SetTaskPushNotificationConfigRequest
167+ | GetTaskPushNotificationConfigRequest
168+ | ListTaskPushNotificationConfigRequest
169+ | DeleteTaskPushNotificationConfigRequest
170+ | TaskResubscriptionRequest
171+ | GetAuthenticatedExtendedCardRequest
172+ )
173+
174+ METHOD_TO_MODEL : dict [str , type [A2ARequestModel ]] = {
175+ model .model_fields ['method' ].default : model
176+ for model in A2ARequestModel .__args__
177+ }
178+
156179 def __init__ ( # noqa: PLR0913
157180 self ,
158181 agent_card : AgentCard ,
@@ -273,17 +296,60 @@ async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911
273296 body = await request .json ()
274297 if isinstance (body , dict ):
275298 request_id = body .get ('id' )
299+ # Ensure request_id is valid for JSON-RPC response (str/int/None only)
300+ if request_id is not None and not isinstance (
301+ request_id , str | int
302+ ):
303+ request_id = None
304+ # Treat very large payloads as invalid request (-32600) before routing
305+ with contextlib .suppress (Exception ):
306+ content_length = int (request .headers .get ('content-length' , '0' ))
307+ if content_length and content_length > MAX_CONTENT_LENGTH :
308+ return self ._generate_error_response (
309+ request_id ,
310+ A2AError (
311+ root = InvalidRequestError (
312+ message = 'Payload too large'
313+ )
314+ ),
315+ )
316+ logger .debug ('Request body: %s' , body )
317+ # 1) Validate base JSON-RPC structure only (-32600 on failure)
318+ try :
319+ base_request = JSONRPCRequest .model_validate (body )
320+ except ValidationError as e :
321+ logger .exception ('Failed to validate base JSON-RPC request' )
322+ return self ._generate_error_response (
323+ request_id ,
324+ A2AError (
325+ root = InvalidRequestError (data = json .loads (e .json ()))
326+ ),
327+ )
276328
277- # First, validate the basic JSON-RPC structure. This is crucial
278- # because the A2ARequest model is a discriminated union where some
279- # request types have default values for the 'method' field
280- JSONRPCRequest .model_validate (body )
329+ # 2) Route by method name; unknown -> -32601, known -> validate params (-32602 on failure)
330+ method = base_request .method
281331
282- a2a_request = A2ARequest .model_validate (body )
332+ model_class = self .METHOD_TO_MODEL .get (method )
333+ if not model_class :
334+ return self ._generate_error_response (
335+ request_id , A2AError (root = MethodNotFoundError ())
336+ )
337+ try :
338+ specific_request = model_class .model_validate (body )
339+ except ValidationError as e :
340+ logger .exception ('Failed to validate base JSON-RPC request' )
341+ return self ._generate_error_response (
342+ request_id ,
343+ A2AError (
344+ root = InvalidParamsError (data = json .loads (e .json ()))
345+ ),
346+ )
283347
348+ # 3) Build call context and wrap the request for downstream handling
284349 call_context = self ._context_builder .build (request )
285350
286- request_id = a2a_request .root .id
351+ request_id = specific_request .id
352+ a2a_request = A2ARequest (root = specific_request )
287353 request_obj = a2a_request .root
288354
289355 if isinstance (
@@ -307,12 +373,6 @@ async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911
307373 return self ._generate_error_response (
308374 None , A2AError (root = JSONParseError (message = str (e )))
309375 )
310- except ValidationError as e :
311- traceback .print_exc ()
312- return self ._generate_error_response (
313- request_id ,
314- A2AError (root = InvalidRequestError (data = json .loads (e .json ()))),
315- )
316376 except HTTPException as e :
317377 if e .status_code == HTTP_413_REQUEST_ENTITY_TOO_LARGE :
318378 return self ._generate_error_response (
0 commit comments