diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 6a276de20fb..6b1f37ae8a4 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -3,7 +3,6 @@ import dataclasses import json import logging -from copy import deepcopy from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence from urllib.parse import parse_qs @@ -15,6 +14,7 @@ _normalize_errors, _regenerate_error_with_loc, get_missing_field_error, + is_sequence_field, ) from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder @@ -150,11 +150,10 @@ def _parse_form_data(self, app: EventHandlerInstance) -> dict[str, Any]: """Parse URL-encoded form data from the request body.""" try: body = app.current_event.decoded_body or "" - # parse_qs returns dict[str, list[str]], but we want dict[str, str] for single values + # NOTE: Keep values as lists; we'll normalize per-field later based on the expected type. + # This avoids breaking List[...] fields when only a single value is provided. parsed = parse_qs(body, keep_blank_values=True) - - result: dict[str, Any] = {key: values[0] if len(values) == 1 else values for key, values in parsed.items()} - return result + return parsed except Exception as e: # pragma: no cover raise RequestValidationError( # pragma: no cover @@ -314,12 +313,12 @@ def _prepare_response_content( def _request_params_to_args( required_params: Sequence[ModelField], received_params: Mapping[str, Any], -) -> tuple[dict[str, Any], list[Any]]: +) -> tuple[dict[str, Any], list[dict[str, Any]]]: """ Convert the request params to a dictionary of values using validation, and returns a list of errors. """ - values = {} - errors = [] + values: dict[str, Any] = {} + errors: list[dict[str, Any]] = [] for field in required_params: field_info = field.field_info @@ -328,16 +327,12 @@ def _request_params_to_args( if not isinstance(field_info, Param): raise AssertionError(f"Expected Param field_info, got {field_info}") - value = received_params.get(field.alias) - loc = (field_info.in_.value, field.alias) + value = received_params.get(field.alias) # If we don't have a value, see if it's required or has a default if value is None: - if field.required: - errors.append(get_missing_field_error(loc=loc)) - else: - values[field.name] = deepcopy(field.default) + _handle_missing_field_value(field, values, errors, loc) continue # Finally, validate the value @@ -363,39 +358,64 @@ def _request_body_to_args( ) for field in required_params: - # This sets the location to: - # { "user": { object } } if field.alias == user - # { { object } if field_alias is omitted - loc: tuple[str, ...] = ("body", field.alias) - if field_alias_omitted: - loc = ("body",) + loc = _get_body_field_location(field, field_alias_omitted) + value = _extract_field_value_from_body(field, received_body, loc, errors) - value: Any | None = None - - # Now that we know what to look for, try to get the value from the received body - if received_body is not None: - try: - value = received_body.get(field.alias) - except AttributeError: - errors.append(get_missing_field_error(loc)) - continue - - # Determine if the field is required + # If we don't have a value, see if it's required or has a default if value is None: - if field.required: - errors.append(get_missing_field_error(loc)) - else: - values[field.name] = deepcopy(field.default) + _handle_missing_field_value(field, values, errors, loc) continue - # MAINTENANCE: Handle byte and file fields - - # Finally, validate the value + value = _normalize_field_value(field, value) values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors) return values, errors +def _get_body_field_location(field: ModelField, field_alias_omitted: bool) -> tuple[str, ...]: + """Get the location tuple for a body field based on whether the field alias is omitted.""" + if field_alias_omitted: + return ("body",) + return ("body", field.alias) + + +def _extract_field_value_from_body( + field: ModelField, + received_body: dict[str, Any] | None, + loc: tuple[str, ...], + errors: list[dict[str, Any]], +) -> Any | None: + """Extract field value from the received body, handling potential AttributeError.""" + if received_body is None: + return None + + try: + return received_body.get(field.alias) + except AttributeError: + errors.append(get_missing_field_error(loc)) + return None + + +def _handle_missing_field_value( + field: ModelField, + values: dict[str, Any], + errors: list[dict[str, Any]], + loc: tuple[str, ...], +) -> None: + """Handle the case when a field value is missing.""" + if field.required: + errors.append(get_missing_field_error(loc)) + else: + values[field.name] = field.get_default() + + +def _normalize_field_value(field: ModelField, value: Any) -> Any: + """Normalize field value, converting lists to single values for non-sequence fields.""" + if isinstance(value, list) and not is_sequence_field(field): + return value[0] + return value + + def _validate_field( *, field: ModelField, diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index de1add78d55..58a3f19e504 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -366,6 +366,100 @@ def handler(user: Annotated[Model, Body(embed=True)]) -> Model: assert result["statusCode"] == 200 +def test_validate_body_param_with_missing_body(gw_event): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler is defined with multiple body parameters + @app.post("/") + def handler(name: str, age: int): + return {"name": name, "age": age} + + # WHEN the event has no body + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/" + gw_event["body"] = None # simulate event without body + gw_event["headers"]["content-type"] = "application/json" + + result = app(gw_event, {}) + + # THEN the handler should be invoked and return 422 + assert result["statusCode"] == 422 + assert "missing" in result["body"] + + +def test_validate_body_param_with_empty_body(gw_event): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler is defined with multiple body parameters + @app.post("/") + def handler(name: str, age: int): + return {"name": name, "age": age} + + # WHEN the event has no body + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/" + gw_event["body"] = "[]" # JSON array -> received_body is a list (no .get) + gw_event["headers"]["content-type"] = "application/json" + + result = app(gw_event, {}) + + # THEN the handler should be invoked and return 422 + assert result["statusCode"] == 422 + assert "missing" in result["body"] + + +def test_validate_embed_body_param_with_missing_body(gw_event): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + + # WHEN a handler is defined with a body parameter + @app.post("/") + def handler(user: Annotated[Model, Body(embed=True)]) -> Model: + return user + + # WHEN the event has no body + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/" + gw_event["body"] = None # simulate event without body + gw_event["headers"]["content-type"] = "application/json" + + result = app(gw_event, {}) + + # THEN the handler should be invoked and return 422 + assert result["statusCode"] == 422 + assert "missing" in result["body"] + + +def test_validate_embed_body_param_with_empty_body(gw_event): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + + # WHEN a handler is defined with a body parameter + @app.post("/") + def handler(user: Annotated[Model, Body(embed=True)]) -> Model: + return user + + # WHEN the event has no body + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/" + gw_event["body"] = "[]" # JSON array -> received_body is a list (no .get) + gw_event["headers"]["content-type"] = "application/json" + + result = app(gw_event, {}) + + # THEN the handler should be invoked and return 422 + assert result["statusCode"] == 422 + assert "missing" in result["body"] + + def test_validate_response_return(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -1481,20 +1575,33 @@ def handler_custom_route_response_validation_error() -> Model: def test_parse_form_data_url_encoded(gw_event): """Test _parse_form_data method with URL-encoded form data""" - + # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @app.post("/form") def post_form(name: Annotated[str, Form()], tags: Annotated[List[str], Form()]): return {"name": name, "tags": tags} + # WHEN sending a POST request with URL-encoded form data gw_event["httpMethod"] = "POST" gw_event["path"] = "/form" gw_event["headers"]["content-type"] = "application/x-www-form-urlencoded" gw_event["body"] = "name=test&tags=tag1&tags=tag2" result = app(gw_event, {}) + + # THEN it should parse the form data correctly + assert result["statusCode"] == 200 + assert result["body"] == '{"name":"test","tags":["tag1","tag2"]}' + + # WHEN sending a POST request with a single value for a list field + gw_event["body"] = "name=test&tags=tag1" + + result = app(gw_event, {}) + + # THEN it should parse the form data correctly assert result["statusCode"] == 200 + assert result["body"] == '{"name":"test","tags":["tag1"]}' def test_parse_form_data_wrong_value(gw_event):