From 9c056fe10deae69f28ab158623467be239787787 Mon Sep 17 00:00:00 2001 From: Nico Tonnhofer Date: Mon, 22 Sep 2025 22:07:43 +0200 Subject: [PATCH 1/4] fix: parse single list items in form data --- .../middlewares/openapi_validation.py | 12 ++++++++---- .../test_openapi_validation_middleware.py | 15 ++++++++++++++- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 6a276de20fb..19754b71ac3 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -15,6 +15,7 @@ _normalize_errors, _regenerate_error_with_loc, get_missing_field_error, + is_sequence_field, ) from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder @@ -150,11 +151,10 @@ def _parse_form_data(self, app: EventHandlerInstance) -> dict[str, Any]: """Parse URL-encoded form data from the request body.""" try: body = app.current_event.decoded_body or "" - # parse_qs returns dict[str, list[str]], but we want dict[str, str] for single values + # NOTE: Keep values as lists; we'll normalize per-field later based on the expected type. + # This avoids breaking List[...] fields when only a single value is provided. parsed = parse_qs(body, keep_blank_values=True) - - result: dict[str, Any] = {key: values[0] if len(values) == 1 else values for key, values in parsed.items()} - return result + return parsed except Exception as e: # pragma: no cover raise RequestValidationError( # pragma: no cover @@ -388,6 +388,10 @@ def _request_body_to_args( values[field.name] = deepcopy(field.default) continue + # Normalize lists for non-sequence fields + if isinstance(value, list) and not is_sequence_field(field): + value = value[0] + # MAINTENANCE: Handle byte and file fields # Finally, validate the value diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 1fd919b7b71..3ceaf6c0790 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -1481,20 +1481,33 @@ def handler_custom_route_response_validation_error() -> Model: def test_parse_form_data_url_encoded(gw_event): """Test _parse_form_data method with URL-encoded form data""" - + # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @app.post("/form") def post_form(name: Annotated[str, Form()], tags: Annotated[List[str], Form()]): return {"name": name, "tags": tags} + # WHEN sending a POST request with URL-encoded form data gw_event["httpMethod"] = "POST" gw_event["path"] = "/form" gw_event["headers"]["content-type"] = "application/x-www-form-urlencoded" gw_event["body"] = "name=test&tags=tag1&tags=tag2" result = app(gw_event, {}) + + # THEN it should parse the form data correctly + assert result["statusCode"] == 200 + assert result["body"] == '{"name":"test","tags":["tag1","tag2"]}' + + # WHEN sending a POST request with a single value for a list field + gw_event["body"] = "name=test&tags=tag1" + + result = app(gw_event, {}) + + # THEN it should parse the form data correctly assert result["statusCode"] == 200 + assert result["body"] == '{"name":"test","tags":["tag1"]}' def test_parse_form_data_wrong_value(gw_event): From c2659aaeaf6cc40ce6ab8cc72c1230736545274a Mon Sep 17 00:00:00 2001 From: Nico Tonnhofer Date: Tue, 23 Sep 2025 12:19:03 +0200 Subject: [PATCH 2/4] fix: reduce cognitive complexity --- .../middlewares/openapi_validation.py | 92 +++++++++++-------- 1 file changed, 54 insertions(+), 38 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 19754b71ac3..6b1f37ae8a4 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -3,7 +3,6 @@ import dataclasses import json import logging -from copy import deepcopy from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence from urllib.parse import parse_qs @@ -314,12 +313,12 @@ def _prepare_response_content( def _request_params_to_args( required_params: Sequence[ModelField], received_params: Mapping[str, Any], -) -> tuple[dict[str, Any], list[Any]]: +) -> tuple[dict[str, Any], list[dict[str, Any]]]: """ Convert the request params to a dictionary of values using validation, and returns a list of errors. """ - values = {} - errors = [] + values: dict[str, Any] = {} + errors: list[dict[str, Any]] = [] for field in required_params: field_info = field.field_info @@ -328,16 +327,12 @@ def _request_params_to_args( if not isinstance(field_info, Param): raise AssertionError(f"Expected Param field_info, got {field_info}") - value = received_params.get(field.alias) - loc = (field_info.in_.value, field.alias) + value = received_params.get(field.alias) # If we don't have a value, see if it's required or has a default if value is None: - if field.required: - errors.append(get_missing_field_error(loc=loc)) - else: - values[field.name] = deepcopy(field.default) + _handle_missing_field_value(field, values, errors, loc) continue # Finally, validate the value @@ -363,43 +358,64 @@ def _request_body_to_args( ) for field in required_params: - # This sets the location to: - # { "user": { object } } if field.alias == user - # { { object } if field_alias is omitted - loc: tuple[str, ...] = ("body", field.alias) - if field_alias_omitted: - loc = ("body",) - - value: Any | None = None + loc = _get_body_field_location(field, field_alias_omitted) + value = _extract_field_value_from_body(field, received_body, loc, errors) - # Now that we know what to look for, try to get the value from the received body - if received_body is not None: - try: - value = received_body.get(field.alias) - except AttributeError: - errors.append(get_missing_field_error(loc)) - continue - - # Determine if the field is required + # If we don't have a value, see if it's required or has a default if value is None: - if field.required: - errors.append(get_missing_field_error(loc)) - else: - values[field.name] = deepcopy(field.default) + _handle_missing_field_value(field, values, errors, loc) continue - # Normalize lists for non-sequence fields - if isinstance(value, list) and not is_sequence_field(field): - value = value[0] - - # MAINTENANCE: Handle byte and file fields - - # Finally, validate the value + value = _normalize_field_value(field, value) values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors) return values, errors +def _get_body_field_location(field: ModelField, field_alias_omitted: bool) -> tuple[str, ...]: + """Get the location tuple for a body field based on whether the field alias is omitted.""" + if field_alias_omitted: + return ("body",) + return ("body", field.alias) + + +def _extract_field_value_from_body( + field: ModelField, + received_body: dict[str, Any] | None, + loc: tuple[str, ...], + errors: list[dict[str, Any]], +) -> Any | None: + """Extract field value from the received body, handling potential AttributeError.""" + if received_body is None: + return None + + try: + return received_body.get(field.alias) + except AttributeError: + errors.append(get_missing_field_error(loc)) + return None + + +def _handle_missing_field_value( + field: ModelField, + values: dict[str, Any], + errors: list[dict[str, Any]], + loc: tuple[str, ...], +) -> None: + """Handle the case when a field value is missing.""" + if field.required: + errors.append(get_missing_field_error(loc)) + else: + values[field.name] = field.get_default() + + +def _normalize_field_value(field: ModelField, value: Any) -> Any: + """Normalize field value, converting lists to single values for non-sequence fields.""" + if isinstance(value, list) and not is_sequence_field(field): + return value[0] + return value + + def _validate_field( *, field: ModelField, From ed3791874558e7d96a2d5ee4b28185d2b43ba66e Mon Sep 17 00:00:00 2001 From: Nico Tonnhofer Date: Tue, 23 Sep 2025 16:51:23 +0200 Subject: [PATCH 3/4] feat: add tests for requests with empty body --- .../test_openapi_validation_middleware.py | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 3ceaf6c0790..f5aca36ab1f 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -366,6 +366,53 @@ def handler(user: Annotated[Model, Body(embed=True)]) -> Model: assert result["statusCode"] == 200 +def test_validate_body_params_with_missing_body_sets_received_body_none(gw_event): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler is defined with multiple body parameters + @app.post("/") + def handler(name: str, age: int): + return {"name": name, "age": age} + + # WHEN the event has no body + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/" + gw_event["body"] = None # simulate event without body + gw_event["headers"]["content-type"] = "application/json" + + result = app(gw_event, {}) + + # THEN the handler should be invoked and return 422 + assert result["statusCode"] == 422 + assert "missing" in result["body"] + + +def test_validate_embed_body_param_with_missing_body_sets_received_body_none(gw_event): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + + # WHEN a handler is defined with a body parameter + @app.post("/") + def handler(user: Annotated[Model, Body(embed=True)]) -> Model: + return user + + # WHEN the event has no body + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/" + gw_event["body"] = None # simulate event without body + gw_event["headers"]["content-type"] = "application/json" + + result = app(gw_event, {}) + + # THEN the handler should be invoked and return 422 + assert result["statusCode"] == 422 + assert "missing" in result["body"] + + def test_validate_response_return(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) From 54bf884fe8f0facc14e16d8ab149b4f2891874c5 Mon Sep 17 00:00:00 2001 From: Nico Tonnhofer Date: Tue, 23 Sep 2025 17:28:07 +0200 Subject: [PATCH 4/4] feat: test empty body events --- .../test_openapi_validation_middleware.py | 51 ++++++++++++++++++- 1 file changed, 49 insertions(+), 2 deletions(-) diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 44239404ac1..58a3f19e504 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -366,7 +366,7 @@ def handler(user: Annotated[Model, Body(embed=True)]) -> Model: assert result["statusCode"] == 200 -def test_validate_body_params_with_missing_body_sets_received_body_none(gw_event): +def test_validate_body_param_with_missing_body(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -388,7 +388,29 @@ def handler(name: str, age: int): assert "missing" in result["body"] -def test_validate_embed_body_param_with_missing_body_sets_received_body_none(gw_event): +def test_validate_body_param_with_empty_body(gw_event): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler is defined with multiple body parameters + @app.post("/") + def handler(name: str, age: int): + return {"name": name, "age": age} + + # WHEN the event has no body + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/" + gw_event["body"] = "[]" # JSON array -> received_body is a list (no .get) + gw_event["headers"]["content-type"] = "application/json" + + result = app(gw_event, {}) + + # THEN the handler should be invoked and return 422 + assert result["statusCode"] == 422 + assert "missing" in result["body"] + + +def test_validate_embed_body_param_with_missing_body(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -413,6 +435,31 @@ def handler(user: Annotated[Model, Body(embed=True)]) -> Model: assert "missing" in result["body"] +def test_validate_embed_body_param_with_empty_body(gw_event): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + + # WHEN a handler is defined with a body parameter + @app.post("/") + def handler(user: Annotated[Model, Body(embed=True)]) -> Model: + return user + + # WHEN the event has no body + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/" + gw_event["body"] = "[]" # JSON array -> received_body is a list (no .get) + gw_event["headers"]["content-type"] = "application/json" + + result = app(gw_event, {}) + + # THEN the handler should be invoked and return 422 + assert result["statusCode"] == 422 + assert "missing" in result["body"] + + def test_validate_response_return(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True)