Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import dataclasses
import json
import logging
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence
from urllib.parse import parse_qs

Expand All @@ -15,6 +14,7 @@
_normalize_errors,
_regenerate_error_with_loc,
get_missing_field_error,
is_sequence_field,
)
from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
Expand Down Expand Up @@ -150,11 +150,10 @@ def _parse_form_data(self, app: EventHandlerInstance) -> dict[str, Any]:
"""Parse URL-encoded form data from the request body."""
try:
body = app.current_event.decoded_body or ""
# parse_qs returns dict[str, list[str]], but we want dict[str, str] for single values
# NOTE: Keep values as lists; we'll normalize per-field later based on the expected type.
# This avoids breaking List[...] fields when only a single value is provided.
parsed = parse_qs(body, keep_blank_values=True)

result: dict[str, Any] = {key: values[0] if len(values) == 1 else values for key, values in parsed.items()}
return result
return parsed

except Exception as e: # pragma: no cover
raise RequestValidationError( # pragma: no cover
Expand Down Expand Up @@ -314,12 +313,12 @@ def _prepare_response_content(
def _request_params_to_args(
required_params: Sequence[ModelField],
received_params: Mapping[str, Any],
) -> tuple[dict[str, Any], list[Any]]:
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
"""
Convert the request params to a dictionary of values using validation, and returns a list of errors.
"""
values = {}
errors = []
values: dict[str, Any] = {}
errors: list[dict[str, Any]] = []

for field in required_params:
field_info = field.field_info
Expand All @@ -328,16 +327,12 @@ def _request_params_to_args(
if not isinstance(field_info, Param):
raise AssertionError(f"Expected Param field_info, got {field_info}")

value = received_params.get(field.alias)

loc = (field_info.in_.value, field.alias)
value = received_params.get(field.alias)

# If we don't have a value, see if it's required or has a default
if value is None:
if field.required:
errors.append(get_missing_field_error(loc=loc))
else:
values[field.name] = deepcopy(field.default)
_handle_missing_field_value(field, values, errors, loc)
continue

# Finally, validate the value
Expand All @@ -363,39 +358,64 @@ def _request_body_to_args(
)

for field in required_params:
# This sets the location to:
# { "user": { object } } if field.alias == user
# { { object } if field_alias is omitted
loc: tuple[str, ...] = ("body", field.alias)
if field_alias_omitted:
loc = ("body",)
loc = _get_body_field_location(field, field_alias_omitted)
value = _extract_field_value_from_body(field, received_body, loc, errors)

value: Any | None = None

# Now that we know what to look for, try to get the value from the received body
if received_body is not None:
try:
value = received_body.get(field.alias)
except AttributeError:
errors.append(get_missing_field_error(loc))
continue

# Determine if the field is required
# If we don't have a value, see if it's required or has a default
if value is None:
if field.required:
errors.append(get_missing_field_error(loc))
else:
values[field.name] = deepcopy(field.default)
_handle_missing_field_value(field, values, errors, loc)
continue

# MAINTENANCE: Handle byte and file fields

# Finally, validate the value
value = _normalize_field_value(field, value)
values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors)

return values, errors


def _get_body_field_location(field: ModelField, field_alias_omitted: bool) -> tuple[str, ...]:
"""Get the location tuple for a body field based on whether the field alias is omitted."""
if field_alias_omitted:
return ("body",)
return ("body", field.alias)


def _extract_field_value_from_body(
field: ModelField,
received_body: dict[str, Any] | None,
loc: tuple[str, ...],
errors: list[dict[str, Any]],
) -> Any | None:
"""Extract field value from the received body, handling potential AttributeError."""
if received_body is None:
return None

try:
return received_body.get(field.alias)
except AttributeError:
errors.append(get_missing_field_error(loc))
return None


def _handle_missing_field_value(
field: ModelField,
values: dict[str, Any],
errors: list[dict[str, Any]],
loc: tuple[str, ...],
) -> None:
"""Handle the case when a field value is missing."""
if field.required:
errors.append(get_missing_field_error(loc))
else:
values[field.name] = field.get_default()


def _normalize_field_value(field: ModelField, value: Any) -> Any:
"""Normalize field value, converting lists to single values for non-sequence fields."""
if isinstance(value, list) and not is_sequence_field(field):
return value[0]
return value


def _validate_field(
*,
field: ModelField,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,100 @@ def handler(user: Annotated[Model, Body(embed=True)]) -> Model:
assert result["statusCode"] == 200


def test_validate_body_param_with_missing_body(gw_event):
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

# WHEN a handler is defined with multiple body parameters
@app.post("/")
def handler(name: str, age: int):
return {"name": name, "age": age}

# WHEN the event has no body
gw_event["httpMethod"] = "POST"
gw_event["path"] = "/"
gw_event["body"] = None # simulate event without body
gw_event["headers"]["content-type"] = "application/json"

result = app(gw_event, {})

# THEN the handler should be invoked and return 422
assert result["statusCode"] == 422
assert "missing" in result["body"]


def test_validate_body_param_with_empty_body(gw_event):
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

# WHEN a handler is defined with multiple body parameters
@app.post("/")
def handler(name: str, age: int):
return {"name": name, "age": age}

# WHEN the event has no body
gw_event["httpMethod"] = "POST"
gw_event["path"] = "/"
gw_event["body"] = "[]" # JSON array -> received_body is a list (no .get)
gw_event["headers"]["content-type"] = "application/json"

result = app(gw_event, {})

# THEN the handler should be invoked and return 422
assert result["statusCode"] == 422
assert "missing" in result["body"]


def test_validate_embed_body_param_with_missing_body(gw_event):
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

class Model(BaseModel):
name: str

# WHEN a handler is defined with a body parameter
@app.post("/")
def handler(user: Annotated[Model, Body(embed=True)]) -> Model:
return user

# WHEN the event has no body
gw_event["httpMethod"] = "POST"
gw_event["path"] = "/"
gw_event["body"] = None # simulate event without body
gw_event["headers"]["content-type"] = "application/json"

result = app(gw_event, {})

# THEN the handler should be invoked and return 422
assert result["statusCode"] == 422
assert "missing" in result["body"]


def test_validate_embed_body_param_with_empty_body(gw_event):
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

class Model(BaseModel):
name: str

# WHEN a handler is defined with a body parameter
@app.post("/")
def handler(user: Annotated[Model, Body(embed=True)]) -> Model:
return user

# WHEN the event has no body
gw_event["httpMethod"] = "POST"
gw_event["path"] = "/"
gw_event["body"] = "[]" # JSON array -> received_body is a list (no .get)
gw_event["headers"]["content-type"] = "application/json"

result = app(gw_event, {})

# THEN the handler should be invoked and return 422
assert result["statusCode"] == 422
assert "missing" in result["body"]


def test_validate_response_return(gw_event):
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)
Expand Down Expand Up @@ -1481,20 +1575,33 @@ def handler_custom_route_response_validation_error() -> Model:

def test_parse_form_data_url_encoded(gw_event):
"""Test _parse_form_data method with URL-encoded form data"""

# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

@app.post("/form")
def post_form(name: Annotated[str, Form()], tags: Annotated[List[str], Form()]):
return {"name": name, "tags": tags}

# WHEN sending a POST request with URL-encoded form data
gw_event["httpMethod"] = "POST"
gw_event["path"] = "/form"
gw_event["headers"]["content-type"] = "application/x-www-form-urlencoded"
gw_event["body"] = "name=test&tags=tag1&tags=tag2"

result = app(gw_event, {})

# THEN it should parse the form data correctly
assert result["statusCode"] == 200
assert result["body"] == '{"name":"test","tags":["tag1","tag2"]}'

# WHEN sending a POST request with a single value for a list field
gw_event["body"] = "name=test&tags=tag1"

result = app(gw_event, {})

# THEN it should parse the form data correctly
assert result["statusCode"] == 200
assert result["body"] == '{"name":"test","tags":["tag1"]}'


def test_parse_form_data_wrong_value(gw_event):
Expand Down