diff --git a/src/a2a/server/apps/jsonrpc/fastapi_app.py b/src/a2a/server/apps/jsonrpc/fastapi_app.py index 4ba7fdce..ace2c6ae 100644 --- a/src/a2a/server/apps/jsonrpc/fastapi_app.py +++ b/src/a2a/server/apps/jsonrpc/fastapi_app.py @@ -77,6 +77,7 @@ def __init__( # noqa: PLR0913 [AgentCard, ServerCallContext], AgentCard ] | None = None, + max_content_length: int | None = 10 * 1024 * 1024, # 10MB ) -> None: """Initializes the A2AFastAPIApplication. @@ -94,6 +95,8 @@ def __init__( # noqa: PLR0913 extended_card_modifier: An optional callback to dynamically modify the extended agent card before it is served. It receives the call context. + max_content_length: The maximum allowed content length for incoming + requests. Defaults to 10MB. Set to None for unbounded maximum. """ if not _package_fastapi_installed: raise ImportError( @@ -108,6 +111,7 @@ def __init__( # noqa: PLR0913 context_builder=context_builder, card_modifier=card_modifier, extended_card_modifier=extended_card_modifier, + max_content_length=max_content_length, ) def add_routes_to_app( diff --git a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py index d258916c..3e7c2854 100644 --- a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py +++ b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py @@ -91,8 +91,6 @@ Response = Any HTTP_413_REQUEST_ENTITY_TOO_LARGE = Any -MAX_CONTENT_LENGTH = 10_000_000 - class StarletteUserProxy(A2AUser): """Adapts the Starlette User class to the A2A user representation.""" @@ -185,6 +183,7 @@ def __init__( # noqa: PLR0913 [AgentCard, ServerCallContext], AgentCard ] | None = None, + max_content_length: int | None = 10 * 1024 * 1024, # 10MB ) -> None: """Initializes the JSONRPCApplication. @@ -202,6 +201,8 @@ def __init__( # noqa: PLR0913 extended_card_modifier: An optional callback to dynamically modify the extended agent card before it is served. It receives the call context. + max_content_length: The maximum allowed content length for incoming + requests. Defaults to 10MB. Set to None for unbounded maximum. """ if not _package_starlette_installed: raise ImportError( @@ -220,6 +221,7 @@ def __init__( # noqa: PLR0913 extended_card_modifier=extended_card_modifier, ) self._context_builder = context_builder or DefaultCallContextBuilder() + self._max_content_length = max_content_length def _generate_error_response( self, request_id: str | int | None, error: JSONRPCError | A2AError @@ -261,6 +263,22 @@ def _generate_error_response( status_code=200, ) + def _allowed_content_length(self, request: Request) -> bool: + """Checks if the request content length is within the allowed maximum. + + Args: + request: The incoming Starlette Request object. + + Returns: + False if the content length is larger than the allowed maximum, True otherwise. + """ + if self._max_content_length is not None: + with contextlib.suppress(ValueError): + content_length = int(request.headers.get('content-length', '0')) + if content_length and content_length > self._max_content_length: + return False + return True + async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911 """Handles incoming POST requests to the main A2A endpoint. @@ -291,18 +309,14 @@ async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911 request_id, str | int ): request_id = None - # Treat very large payloads as invalid request (-32600) before routing - with contextlib.suppress(Exception): - content_length = int(request.headers.get('content-length', '0')) - if content_length and content_length > MAX_CONTENT_LENGTH: - return self._generate_error_response( - request_id, - A2AError( - root=InvalidRequestError( - message='Payload too large' - ) - ), - ) + # Treat payloads lager than allowed as invalid request (-32600) before routing + if not self._allowed_content_length(request): + return self._generate_error_response( + request_id, + A2AError( + root=InvalidRequestError(message='Payload too large') + ), + ) logger.debug('Request body: %s', body) # 1) Validate base JSON-RPC structure only (-32600 on failure) try: diff --git a/src/a2a/server/apps/jsonrpc/starlette_app.py b/src/a2a/server/apps/jsonrpc/starlette_app.py index b268d043..1effa9d5 100644 --- a/src/a2a/server/apps/jsonrpc/starlette_app.py +++ b/src/a2a/server/apps/jsonrpc/starlette_app.py @@ -59,6 +59,7 @@ def __init__( # noqa: PLR0913 [AgentCard, ServerCallContext], AgentCard ] | None = None, + max_content_length: int | None = 10 * 1024 * 1024, # 10MB ) -> None: """Initializes the A2AStarletteApplication. @@ -76,6 +77,8 @@ def __init__( # noqa: PLR0913 extended_card_modifier: An optional callback to dynamically modify the extended agent card before it is served. It receives the call context. + max_content_length: The maximum allowed content length for incoming + requests. Defaults to 10MB. Set to None for unbounded maximum. """ if not _package_starlette_installed: raise ImportError( @@ -90,6 +93,7 @@ def __init__( # noqa: PLR0913 context_builder=context_builder, card_modifier=card_modifier, extended_card_modifier=extended_card_modifier, + max_content_length=max_content_length, ) def routes( diff --git a/tests/server/apps/jsonrpc/test_serialization.py b/tests/server/apps/jsonrpc/test_serialization.py index 9365017b..f6778046 100644 --- a/tests/server/apps/jsonrpc/test_serialization.py +++ b/tests/server/apps/jsonrpc/test_serialization.py @@ -136,6 +136,42 @@ def test_handle_oversized_payload(agent_card_with_api_key: AgentCard): assert data['error']['code'] == InvalidRequestError().code +@pytest.mark.parametrize( + 'max_content_length', + [ + None, + 11 * 1024 * 1024, + 30 * 1024 * 1024, + ], +) +def test_handle_oversized_payload_with_max_content_length( + agent_card_with_api_key: AgentCard, + max_content_length: int | None, +): + """Test handling of JSON payloads with sizes within custom max_content_length.""" + handler = mock.AsyncMock() + app_instance = A2AStarletteApplication( + agent_card_with_api_key, handler, max_content_length=max_content_length + ) + client = TestClient(app_instance.build()) + + large_string = 'a' * 11 * 1_000_000 # 11MB string + payload = { + 'jsonrpc': '2.0', + 'method': 'test', + 'id': 1, + 'params': {'data': large_string}, + } + + response = client.post('/', json=payload) + assert response.status_code == 200 + data = response.json() + # When max_content_length is set, requests up to that size should not be + # rejected due to payload size. The request might fail for other reasons, + # but it shouldn't be an InvalidRequestError related to the content length. + assert data['error']['code'] != InvalidRequestError().code + + def test_handle_unicode_characters(agent_card_with_api_key: AgentCard): """Test handling of unicode characters in JSON payload.""" handler = mock.AsyncMock()