diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 63baf9fe644..e5745ebddf3 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -5,6 +5,7 @@ import logging from copy import deepcopy from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence +from urllib.parse import parse_qs from pydantic import BaseModel @@ -30,6 +31,11 @@ logger = logging.getLogger(__name__) +# Constants +CONTENT_DISPOSITION_NAME_PARAM = "name=" +APPLICATION_JSON_CONTENT_TYPE = "application/json" +APPLICATION_FORM_CONTENT_TYPE = "application/x-www-form-urlencoded" + class OpenAPIValidationMiddleware(BaseMiddlewareHandler): """ @@ -246,28 +252,61 @@ def _prepare_response_content( def _get_body(self, app: EventHandlerInstance) -> dict[str, Any]: """ - Get the request body from the event, and parse it as JSON. + Get the request body from the event, and parse it according to content type. 
""" + content_type = app.current_event.headers.get("content-type", "").strip() + + # Handle JSON content + if not content_type or content_type.startswith(APPLICATION_JSON_CONTENT_TYPE): + return self._parse_json_data(app) + + # Handle URL-encoded form data + elif content_type.startswith(APPLICATION_FORM_CONTENT_TYPE): + return self._parse_form_data(app) - content_type = app.current_event.headers.get("content-type") - if not content_type or content_type.strip().startswith("application/json"): - try: - return app.current_event.json_body - except json.JSONDecodeError as e: - raise RequestValidationError( - [ - { - "type": "json_invalid", - "loc": ("body", e.pos), - "msg": "JSON decode error", - "input": {}, - "ctx": {"error": e.msg}, - }, - ], - body=e.doc, - ) from e else: - raise NotImplementedError("Only JSON body is supported") + raise NotImplementedError("Only JSON body or Form() are supported") + + def _parse_json_data(self, app: EventHandlerInstance) -> dict[str, Any]: + """Parse JSON data from the request body.""" + try: + return app.current_event.json_body + except json.JSONDecodeError as e: + raise RequestValidationError( + [ + { + "type": "json_invalid", + "loc": ("body", e.pos), + "msg": "JSON decode error", + "input": {}, + "ctx": {"error": e.msg}, + }, + ], + body=e.doc, + ) from e + + def _parse_form_data(self, app: EventHandlerInstance) -> dict[str, Any]: + """Parse URL-encoded form data from the request body.""" + try: + body = app.current_event.decoded_body or "" + # parse_qs returns dict[str, list[str]], but we want dict[str, str] for single values + parsed = parse_qs(body, keep_blank_values=True) + + result: dict[str, Any] = {key: values[0] if len(values) == 1 else values for key, values in parsed.items()} + return result + + except Exception as e: # pragma: no cover + raise RequestValidationError( # pragma: no cover + [ + { + "type": "form_invalid", + "loc": ("body",), + "msg": "Form data parsing error", + "input": {}, + "ctx": {"error": str(e)}, 
+ }, + ], + ) from e def _request_params_to_args( diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index 976ce9f0454..98a8740a74f 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -14,12 +14,12 @@ from aws_lambda_powertools.event_handler.openapi.params import ( Body, Dependant, + Form, Header, Param, ParamTypes, Query, _File, - _Form, analyze_param, create_response_field, get_flat_dependant, @@ -348,6 +348,7 @@ def get_body_field(*, dependant: Dependant, name: str) -> ModelField | None: alias="body", field_info=body_field_info(**body_field_info_kwargs), ) + return final_field @@ -369,9 +370,9 @@ def get_body_field_info( if any(isinstance(f.field_info, _File) for f in flat_dependant.body_params): # MAINTENANCE: body_field_info: type[Body] = _File raise NotImplementedError("_File fields are not supported in request bodies") - elif any(isinstance(f.field_info, _Form) for f in flat_dependant.body_params): - # MAINTENANCE: body_field_info: type[Body] = _Form - raise NotImplementedError("_Form fields are not supported in request bodies") + elif any(isinstance(f.field_info, Form) for f in flat_dependant.body_params): + body_field_info = Body + body_field_info_kwargs["media_type"] = "application/x-www-form-urlencoded" else: body_field_info = Body diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 7b1b1c06f49..8fc8d0becfa 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -737,9 +737,9 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({self.default})" -class _Form(Body): +class Form(Body): """ - A class used internally to represent a form parameter in a path operation. + A class used to represent a form parameter in a path operation. 
""" def __init__( @@ -809,9 +809,9 @@ def __init__( ) -class _File(_Form): +class _File(Form): """ - A class used internally to represent a file parameter in a path operation. + A class used to represent a file parameter in a path operation. """ def __init__( @@ -848,6 +848,14 @@ def __init__( json_schema_extra: dict[str, Any] | None = None, **extra: Any, ): + # For file uploads, ensure the OpenAPI schema has the correct format + # Also we can't test it + file_schema_extra = {"format": "binary"} # pragma: no cover + if json_schema_extra: # pragma: no cover + json_schema_extra.update(file_schema_extra) # pragma: no cover + else: # pragma: no cover + json_schema_extra = file_schema_extra # pragma: no cover + super().__init__( default=default, default_factory=default_factory, diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index 8afc1a0ca6e..2b7ef205227 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -523,6 +523,18 @@ In the following example, we use a new `Header` OpenAPI type to add [one out of 1. `cloudfront_viewer_country` is a list that must contain values from the `CountriesAllowed` enumeration. +#### Handling form data + +!!! info "You must set `enable_validation=True` to handle form data via type annotation." + +You can use the `Form` type to tell the Event Handler that a parameter expects form data. This automatically sets the correct OpenAPI schema for `application/x-www-form-urlencoded` requests. 
+ +=== "working_with_form_data.py" + + ```python hl_lines="4 11 12" + --8<-- "examples/event_handler_rest/src/working_with_form_data.py" + ``` + #### Supported types for response serialization With data validation enabled, we natively support serializing the following data types to JSON: diff --git a/examples/event_handler_rest/src/working_with_form_data.py b/examples/event_handler_rest/src/working_with_form_data.py new file mode 100644 index 00000000000..632626475da --- /dev/null +++ b/examples/event_handler_rest/src/working_with_form_data.py @@ -0,0 +1,19 @@ +from typing import Annotated + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from aws_lambda_powertools.event_handler.openapi.params import Form + +app = APIGatewayRestResolver(enable_validation=True) + + +@app.post("/submit_form") +def upload_file( + name: Annotated[str, Form(description="Your name")], + age: Annotated[str, Form(description="Your age")], +): + # You can access form data + return {"message": f"Your name is {name} and age is {age}"} + + +def lambda_handler(event, context): + return app.resolve(event, context) diff --git a/tests/functional/event_handler/_pydantic/test_openapi_params.py b/tests/functional/event_handler/_pydantic/test_openapi_params.py index fdaf23c5a0b..19b5287d66a 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_params.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_params.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from datetime import datetime -from typing import List, Tuple +from typing import List, Optional, Tuple from pydantic import BaseModel, Field from typing_extensions import Annotated @@ -14,6 +14,7 @@ ) from aws_lambda_powertools.event_handler.openapi.params import ( Body, + Form, Header, Param, ParamTypes, @@ -649,3 +650,129 @@ def handler( assert parameter.schema_.type == "integer" assert parameter.schema_.default == 1 assert parameter.schema_.title == "Count" + + +def 
test_openapi_form_only_parameters(): + """Test Form parameters generate application/x-www-form-urlencoded content type.""" + app = APIGatewayRestResolver(enable_validation=True) + + @app.post("/form-data") + def create_form_data( + name: Annotated[str, Form(description="User name")], + email: Annotated[str, Form(description="User email")] = "test@example.com", + ): + return {"name": name, "email": email} + + schema = app.get_openapi_schema() + + # Check that the endpoint is present + assert "/form-data" in schema.paths + + post_op = schema.paths["/form-data"].post + assert post_op is not None + + # Check request body + request_body = post_op.requestBody + assert request_body is not None + + # Check content type is application/x-www-form-urlencoded + assert "application/x-www-form-urlencoded" in request_body.content + + # Get the schema reference + form_content = request_body.content["application/x-www-form-urlencoded"] + assert form_content.schema_ is not None + + # Check that it references a component schema + schema_ref = form_content.schema_.ref + assert schema_ref is not None + assert schema_ref.startswith("#/components/schemas/") + + # Get the component schema + component_name = schema_ref.split("/")[-1] + assert component_name in schema.components.schemas + + component_schema = schema.components.schemas[component_name] + properties = component_schema.properties + + # Check form parameters + assert "name" in properties + name_prop = properties["name"] + assert name_prop.type == "string" + assert name_prop.description == "User name" + + assert "email" in properties + email_prop = properties["email"] + assert email_prop.type == "string" + assert email_prop.description == "User email" + assert email_prop.default == "test@example.com" + + # Check required fields (only name should be required since email has default) + assert component_schema.required == ["name"] + + +def test_openapi_mixed_body_media_types(): + """Test mixed Body parameters with different media 
types.""" + + class UserData(BaseModel): + name: str + email: str + + app = APIGatewayRestResolver(enable_validation=True) + + @app.post("/mixed-body") + def mixed_body_endpoint(user_data: Annotated[UserData, Body(media_type="application/json")]): + return {"status": "created"} + + schema = app.get_openapi_schema() + + # Check that the endpoint uses the specified media type + assert "/mixed-body" in schema.paths + + post_op = schema.paths["/mixed-body"].post + request_body = post_op.requestBody + + # Should use the specified media type + assert "application/json" in request_body.content + + +def test_openapi_form_parameter_edge_cases(): + """Test Form parameters with various edge cases.""" + + app = APIGatewayRestResolver(enable_validation=True) + + @app.post("/form-edge-cases") + def form_edge_cases( + required_field: Annotated[str, Form(description="Required field")], + optional_field: Annotated[Optional[str], Form(description="Optional field")] = None, + field_with_default: Annotated[str, Form(description="Field with default")] = "default_value", + ): + return {"required": required_field, "optional": optional_field, "default": field_with_default} + + schema = app.get_openapi_schema() + + # Check that the endpoint is present + assert "/form-edge-cases" in schema.paths + + post_op = schema.paths["/form-edge-cases"].post + request_body = post_op.requestBody + + # Should use application/x-www-form-urlencoded for form-only parameters + assert "application/x-www-form-urlencoded" in request_body.content + + # Get the component schema + form_content = request_body.content["application/x-www-form-urlencoded"] + schema_ref = form_content.schema_.ref + component_name = schema_ref.split("/")[-1] + component_schema = schema.components.schemas[component_name] + + properties = component_schema.properties + + # Check all fields are present + assert "required_field" in properties + assert "optional_field" in properties + assert "field_with_default" in properties + + # Check 
required vs optional handling + assert "required_field" in component_schema.required + assert "optional_field" not in component_schema.required # Optional + assert "field_with_default" not in component_schema.required # Has default diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index c1cc0462bf7..b41beda36bc 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -1,3 +1,4 @@ +import base64 import json from dataclasses import dataclass from enum import Enum @@ -18,7 +19,7 @@ VPCLatticeV2Resolver, ) from aws_lambda_powertools.event_handler.openapi.exceptions import ResponseValidationError -from aws_lambda_powertools.event_handler.openapi.params import Body, Header, Query +from aws_lambda_powertools.event_handler.openapi.params import Body, Form, Header, Query def test_validate_scalars(gw_event): @@ -1068,49 +1069,6 @@ def handler3(): assert any(text in result["body"] for text in expected_error_text) -def test_validation_with_alias(gw_event): - # GIVEN a REST API V2 proxy type event - app = APIGatewayRestResolver(enable_validation=True) - - # GIVEN that it has a multiple parameters called "parameter1" - gw_event["queryStringParameters"] = { - "parameter1": "value1,value2", - } - - @app.get("/my/path") - def my_path( - parameter: Annotated[Optional[str], Query(alias="parameter1")] = None, - ) -> str: - assert parameter == "value1" - return parameter - - result = app(gw_event, {}) - assert result["statusCode"] == 200 - - -def test_validation_with_http_single_param(gw_event_http): - # GIVEN a HTTP API V2 proxy type event - app = APIGatewayHttpResolver(enable_validation=True) - - # GIVEN that it has a single parameter called "parameter2" - gw_event_http["queryStringParameters"] = { - "parameter1": "value1,value2", - 
"parameter2": "value", - } - - # WHEN a handler is defined with a single parameter - @app.post("/my/path") - def my_path( - parameter2: str, - ) -> str: - assert parameter2 == "value" - return parameter2 - - # THEN the handler should be invoked and return 200 - result = app(gw_event_http, {}) - assert result["statusCode"] == 200 - - def test_validate_with_minimal_event(): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -1519,3 +1477,158 @@ def handler_custom_route_response_validation_error() -> Model: str(exception_info.value) == f"'{response_validation_error_http_code}' must be an integer representing an HTTP status code or an enum of type HTTPStatus." # noqa: E501 ) + + +def test_parse_form_data_url_encoded(gw_event): + """Test _parse_form_data method with URL-encoded form data""" + + app = APIGatewayRestResolver(enable_validation=True) + + @app.post("/form") + def post_form(name: Annotated[str, Form()], tags: Annotated[List[str], Form()]): + return {"name": name, "tags": tags} + + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/form" + gw_event["headers"]["content-type"] = "application/x-www-form-urlencoded" + gw_event["body"] = "name=test&tags=tag1&tags=tag2" + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + +def test_parse_form_data_wrong_value(gw_event): + """Test _parse_form_data method with URL-encoded form data""" + + app = APIGatewayRestResolver(enable_validation=True) + + @app.post("/form") + def post_form(name: Annotated[str, Form()], tags: Annotated[List[str], Form()]): + return {"name": name, "tags": tags} + + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/form" + gw_event["headers"]["content-type"] = "application/x-www-form-urlencoded" + gw_event["body"] = "123" + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + +def test_parse_form_data_empty_body(gw_event): + """Test _parse_form_data method with empty body""" + app = 
APIGatewayRestResolver(enable_validation=True) + + @app.post("/form") + def post_form(name: Annotated[str, Form()] = "default"): + return {"name": name} + + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/form" + gw_event["headers"]["content-type"] = "application/x-www-form-urlencoded" + gw_event["body"] = "" + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + +def test_form_data_parsing_exception(gw_event): + """Test _parse_form_data method exception handling""" + app = APIGatewayRestResolver(enable_validation=True) + + @app.post("/form") + def post_form(name: Annotated[str, Form()]): + return {"name": name} + + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/form" + gw_event["headers"]["content-type"] = "application/x-www-form-urlencoded" + # Set body to None to trigger exception handling + gw_event["body"] = None + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + # With None body, it becomes empty string and missing field validation triggers + assert "missing" in result["body"] + + +def test_prepare_response_content_nested_structures(): + """Test _prepare_response_content method with nested data structures""" + from dataclasses import dataclass + + app = APIGatewayRestResolver(enable_validation=True) + + @dataclass + class TestDataclass: + name: str + value: int + + class TestModel(BaseModel): + title: str + count: int + + @app.get("/complex") + def get_complex() -> dict: + # Return complex nested structure to trigger _prepare_response_content paths + return { + "models": [TestModel(title="test1", count=1), TestModel(title="test2", count=2)], + "dataclasses": [TestDataclass(name="dc1", value=10)], + "nested_dicts": {"inner": {"key": "value"}}, + "mixed_list": [{"a": 1}, TestModel(title="mixed", count=3)], + } + + event = { + "httpMethod": "GET", + "path": "/complex", + "headers": {}, + "queryStringParameters": None, + "body": None, + "isBase64Encoded": False, + "requestContext": {"requestId": "test"}, + 
"pathParameters": None, + } + + result = app(event, {}) + assert result["statusCode"] == 200 + + +def test_multipart_empty_parts(gw_event): + """Test handling of multipart data with empty parts.""" + app = APIGatewayRestResolver(enable_validation=True) + + @app.post("/test") + def handler(): + return {"status": "ok"} + + content_type = "multipart/form-data; boundary=----boundary" + + # Test with completely empty multipart content + empty_multipart = "------boundary--\r\n" + gw_event["body"] = base64.b64encode(empty_multipart.encode()).decode() + gw_event["headers"]["content-type"] = content_type + gw_event["isBase64Encoded"] = True + gw_event["path"] = "/test" + gw_event["httpMethod"] = "POST" + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + +def test_response_serialization_with_custom_serializer(): + """Test response serialization using custom serializer path.""" + app = APIGatewayRestResolver(enable_validation=True) + + class CustomModel(BaseModel): + id: int + name: str + + def dict(self, **kwargs): + # Custom dict method that triggers alternative serialization path + return {"custom": "value", "id": self.id, "name": self.name} + + @app.get("/test") + def handler() -> CustomModel: + return CustomModel(id=1, name="test") + + result = app({"httpMethod": "GET", "path": "/test"}, {}) + assert result["statusCode"] == 200