diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 407cd00781b..841f9372f18 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -92,7 +92,7 @@ Server, Tag, ) - from aws_lambda_powertools.event_handler.openapi.params import Dependant + from aws_lambda_powertools.event_handler.openapi.params import Dependant, Param from aws_lambda_powertools.event_handler.openapi.swagger_ui.oauth2 import ( OAuth2Config, ) @@ -812,46 +812,123 @@ def _openapi_operation_parameters( """ Returns the OpenAPI operation parameters. """ - from aws_lambda_powertools.event_handler.openapi.compat import ( - get_schema_from_model_field, - ) from aws_lambda_powertools.event_handler.openapi.params import Param - parameters = [] - parameter: dict[str, Any] = {} + parameters: list[dict[str, Any]] = [] for param in all_route_params: - field_info = param.field_info - field_info = cast(Param, field_info) + field_info = cast(Param, param.field_info) if not field_info.include_in_schema: continue - param_schema = get_schema_from_model_field( - field=param, - model_name_map=model_name_map, - field_mapping=field_mapping, - ) + # Check if this is a Pydantic model that should be expanded + if Route._is_pydantic_model_param(field_info): + parameters.extend(Route._expand_pydantic_model_parameters(field_info)) + else: + parameters.append(Route._create_regular_parameter(param, model_name_map, field_mapping)) - parameter = { - "name": param.alias, - "in": field_info.in_.value, - "required": param.required, - "schema": param_schema, - } + return parameters - if field_info.description: - parameter["description"] = field_info.description + @staticmethod + def _is_pydantic_model_param(field_info: Param) -> bool: + """Check if the field info represents a Pydantic model parameter.""" + from pydantic import BaseModel - if field_info.openapi_examples: - parameter["examples"] = field_info.openapi_examples + from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass - if field_info.deprecated: - parameter["deprecated"] = field_info.deprecated + return lenient_issubclass(field_info.annotation, BaseModel) - parameters.append(parameter) + @staticmethod + def _expand_pydantic_model_parameters(field_info: Param) -> list[dict[str, Any]]: + """Expand a Pydantic model into individual OpenAPI parameters.""" + from pydantic import BaseModel + + model_class = cast(type[BaseModel], field_info.annotation) + parameters: list[dict[str, Any]] = [] + + for field_name, field_def in model_class.model_fields.items(): + param_name = field_def.alias or field_name + individual_param = Route._create_pydantic_field_parameter( + param_name=param_name, + field_def=field_def, + param_location=field_info.in_.value, + ) + parameters.append(individual_param) return parameters + @staticmethod + def _create_pydantic_field_parameter( + param_name: str, + field_def: Any, + param_location: str, + ) -> dict[str, Any]: + """Create an OpenAPI parameter from a Pydantic field definition.""" + individual_param: dict[str, Any] = { + "name": param_name, + "in": param_location, + "required": field_def.is_required() if hasattr(field_def, "is_required") else field_def.default is ..., + "schema": Route._get_basic_type_schema(field_def.annotation or type(None)), + } + + if field_def.description: + individual_param["description"] = field_def.description + + return individual_param + + @staticmethod + def _create_regular_parameter( + param: ModelField, + model_name_map: dict[TypeModelOrEnum, str], + field_mapping: dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], + ) -> dict[str, Any]: + """Create an OpenAPI parameter from a regular ModelField.""" + from aws_lambda_powertools.event_handler.openapi.compat import get_schema_from_model_field + from aws_lambda_powertools.event_handler.openapi.params import Param + + field_info = cast(Param, param.field_info) + param_schema = get_schema_from_model_field( + field=param, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) + + parameter: dict[str, Any] = { + "name": param.alias, + "in": field_info.in_.value, + "required": param.required, + "schema": param_schema, + } + + # Add optional attributes if present + if field_info.description: + parameter["description"] = field_info.description + if field_info.openapi_examples: + parameter["examples"] = field_info.openapi_examples + if field_info.deprecated: + parameter["deprecated"] = field_info.deprecated + + return parameter + + @staticmethod + def _get_basic_type_schema(param_type: type) -> dict[str, str]: + """ + Get basic OpenAPI schema for simple types + """ + try: + # Check bool before int, since bool is a subclass of int in Python + if issubclass(param_type, bool): + return {"type": "boolean"} + elif issubclass(param_type, int): + return {"type": "integer"} + elif issubclass(param_type, float): + return {"type": "number"} + else: + return {"type": "string"} + except TypeError: + # param_type may not be a type (e.g., typing.Optional[int]), fallback to string + return {"type": "string"} + @staticmethod def _openapi_operation_return( *, diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 6a276de20fb..a120c5e8a3e 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -4,7 +4,7 @@ import json import logging from copy import deepcopy -from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence +from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence, cast, get_origin from urllib.parse import parse_qs from pydantic import BaseModel @@ -15,6 +15,7 @@ _normalize_errors, _regenerate_error_with_loc, get_missing_field_error, + lenient_issubclass, ) from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder @@ -64,7 +65,7 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> ) # Normalize query values before validate this - query_string = _normalize_multi_query_string_with_param( + query_string = _normalize_multi_params( app.current_event.resolved_query_string_parameters, route.dependant.query_params, ) @@ -76,7 +77,7 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> ) # Normalize header values before validate this - headers = _normalize_multi_header_values_with_param( + headers = _normalize_multi_params( app.current_event.resolved_headers_field, route.dependant.header_params, ) @@ -434,57 +435,80 @@ def _get_embed_body( return received_body, field_alias_omitted -def _normalize_multi_query_string_with_param( - query_string: dict[str, list[str]], +def _normalize_multi_params( + input_dict: MutableMapping[str, Any], params: Sequence[ModelField], -) -> dict[str, Any]: +) -> MutableMapping[str, Any]: """ - Extract and normalize resolved_query_string_parameters + Extract and normalize query string or header parameters with Pydantic model support. Parameters ---------- - query_string: dict - A dictionary containing the initial query string parameters. + input_dict: MutableMapping[str, Any] + A dictionary containing the initial query string or header parameters. params: Sequence[ModelField] A sequence of ModelField objects representing parameters. Returns ------- - A dictionary containing the processed multi_query_string_parameters. + MutableMapping[str, Any] + A dictionary containing the processed parameters with normalized values. """ - resolved_query_string: dict[str, Any] = query_string - for param in filter(is_scalar_field, params): - try: - # if the target parameter is a scalar, we keep the first value of the query string - # regardless if there are more in the payload - resolved_query_string[param.alias] = query_string[param.alias][0] - except KeyError: - pass - return resolved_query_string + for param in params: + if is_scalar_field(param): + _process_scalar_param(input_dict, param) + elif lenient_issubclass(param.field_info.annotation, BaseModel): + _process_model_param(input_dict, param) + return input_dict + + +def _process_scalar_param(input_dict: MutableMapping[str, Any], param: ModelField) -> None: + """Process a scalar parameter by normalizing single-item lists.""" + try: + val = input_dict[param.alias] + if isinstance(val, list) and len(val) == 1: + input_dict[param.alias] = val[0] + except KeyError: + pass -def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any], params: Sequence[ModelField]): - """ - Extract and normalize resolved_headers_field +def _process_model_param(input_dict: MutableMapping[str, Any], param: ModelField) -> None: + """Process a Pydantic model parameter by extracting model fields.""" + model_class = cast(type[BaseModel], param.field_info.annotation) - Parameters - ---------- - headers: MutableMapping[str, Any] - A dictionary containing the initial header parameters. - params: Sequence[ModelField] - A sequence of ModelField objects representing parameters. + model_data = {} + for field_name, field_def in model_class.model_fields.items(): + field_alias = field_def.alias or field_name + value = _get_param_value(input_dict, field_alias, field_name, model_class) - Returns - ------- - A dictionary containing the processed headers. - """ - if headers: - for param in filter(is_scalar_field, params): - try: - if len(headers[param.alias]) == 1: - # if the target parameter is a scalar and the list contains only 1 element - # we keep the first value of the headers regardless if there are more in the payload - headers[param.alias] = headers[param.alias][0] - except KeyError: - pass - return headers + if value is not None: + model_data[field_alias] = _normalize_field_value(value, field_def) + + input_dict[param.alias] = model_data + + +def _get_param_value( + input_dict: MutableMapping[str, Any], + field_alias: str, + field_name: str, + model_class: type[BaseModel], +) -> Any: + """Get parameter value, checking both alias and field name if needed.""" + value = input_dict.get(field_alias) + if value is not None: + return value + + if model_class.model_config.get("validate_by_name") or model_class.model_config.get("populate_by_name"): + value = input_dict.get(field_name) + + return value + + +def _normalize_field_value(value: Any, field_def: Any) -> Any: + """Normalize field value based on its type annotation.""" + if get_origin(field_def.annotation) is list: + return value + elif isinstance(value, list) and value: + return value[0] + else: + return value diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index 98a8740a74f..310cab68e66 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -9,16 +9,13 @@ create_body_model, evaluate_forwardref, is_scalar_field, - is_scalar_sequence_field, ) from aws_lambda_powertools.event_handler.openapi.params import ( Body, Dependant, Form, - Header, Param, ParamTypes, - Query, _File, analyze_param, create_response_field, @@ -275,7 +272,7 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool: return False elif is_scalar_field(field=param_field): return False - elif isinstance(param_field.field_info, (Query, Header)) and is_scalar_sequence_field(param_field): + elif isinstance(param_field.field_info, Param): return False else: if not isinstance(param_field.field_info, Body): diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 8fc8d0becfa..4c5f28e5d4b 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -4,7 +4,7 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Literal -from pydantic import BaseConfig +from pydantic import BaseConfig, BaseModel, create_model from pydantic.fields import FieldInfo from typing_extensions import Annotated, get_args, get_origin @@ -17,6 +17,7 @@ copy_field_info, field_annotation_is_scalar, get_annotation_from_field_info, + lenient_issubclass, ) if TYPE_CHECKING: @@ -1094,6 +1095,42 @@ def create_response_field( return ModelField(**kwargs) # type: ignore[arg-type] +def _apply_header_underscore_conversion( + field_info: FieldInfo, + type_annotation: Any, + param_name: str, +) -> tuple[FieldInfo, Any]: + """ + Apply underscore-to-dash conversion for Header parameters. + + For BaseModel: Creates new model with underscore-to-dash alias generator. + Note: If the BaseModel already has an alias generator, it will be replaced + with dash-case conversion since HTTP headers should use dash-case. + For all Header fields: Sets the parameter alias if convert_underscores is True + """ + if not isinstance(field_info, Header) or not field_info.convert_underscores: + return field_info, type_annotation + + # Always set the parameter alias for Header fields (if not already set) + if not field_info.alias: + field_info.alias = param_name.replace("_", "-") + + # Handle BaseModel case - create new model with dash-case alias generator + if lenient_issubclass(type_annotation, BaseModel): + # For HTTP headers, we should use dash-case regardless of existing alias generator + # This ensures consistent header naming conventions + header_aliased_model = create_model( + f"{type_annotation.__name__}WithHeaderAliases", + __base__=type_annotation, + __config__={"alias_generator": lambda name: name.replace("_", "-")}, + ) + + type_annotation = header_aliased_model + field_info.annotation = type_annotation + + return field_info, type_annotation + + def _create_model_field( field_info: FieldInfo | None, type_annotation: Any, @@ -1112,21 +1149,17 @@ def _create_model_field( elif isinstance(field_info, Param) and getattr(field_info, "in_", None) is None: field_info.in_ = ParamTypes.query + # Apply header underscore conversion + field_info, type_annotation = _apply_header_underscore_conversion(field_info, type_annotation, param_name) + # If the field_info is a Param, we use the `in_` attribute to determine the type annotation use_annotation = get_annotation_from_field_info(type_annotation, field_info, param_name) - # If the field doesn't have a defined alias, we use the param name - if not field_info.alias and getattr(field_info, "convert_underscores", None): - alias = param_name.replace("_", "-") - else: - alias = field_info.alias or param_name - field_info.alias = alias - return create_response_field( name=param_name, type_=use_annotation, default=field_info.default, - alias=alias, + alias=field_info.alias, required=field_info.default in (Required, Undefined), field_info=field_info, ) diff --git a/tests/functional/event_handler/_pydantic/test_openapi_params.py b/tests/functional/event_handler/_pydantic/test_openapi_params.py index 19b5287d66a..f797cd541a5 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_params.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_params.py @@ -25,6 +25,145 @@ JSON_CONTENT_TYPE = "application/json" +def test_openapi_pydantic_query_params(): + """Test that Pydantic models in Query parameters are expanded into individual fields in OpenAPI schema""" + app = APIGatewayRestResolver() + + class QueryParams(BaseModel): + limit: int = Field(default=10, ge=1, le=100, description="Number of items to return") + offset: int = Field(default=0, ge=0, description="Number of items to skip") + search: Optional[str] = Field(default=None, description="Search term") + + @app.get("/search") + def search_handler(params: Annotated[QueryParams, Query()]): + return {"message": "success"} + + schema = app.get_openapi_schema() + + # Check that the path exists + assert "/search" in schema.paths + path = schema.paths["/search"] + assert path.get is not None + + # Check that parameters are expanded + get_operation = path.get + assert get_operation.parameters is not None + assert len(get_operation.parameters) == 3 + + # Check individual parameters + param_names = [param.name for param in get_operation.parameters] + assert "limit" in param_names + assert "offset" in param_names + assert "search" in param_names + + # Check parameter details + for param in get_operation.parameters: + assert param.in_ == ParameterInType.query + if param.name == "limit": + assert param.required is False # Has default value + assert param.description == "Number of items to return" + assert param.schema_.type == "integer" + elif param.name == "offset": + assert param.required is False # Has default value + assert param.description == "Number of items to skip" + assert param.schema_.type == "integer" + elif param.name == "search": + assert param.required is False # Optional field + assert param.description == "Search term" + assert param.schema_.type == "string" + + +def test_openapi_pydantic_header_params(): + """Test that Pydantic models in Header parameters are expanded into individual fields in OpenAPI schema""" + app = APIGatewayRestResolver() + + class HeaderParams(BaseModel): + authorization: str = Field(description="Authorization token") + user_agent: str = Field(default="PowerTools/1.0", description="User agent") + language: Optional[str] = Field(default=None, alias="accept-language", description="Language preference") + + @app.get("/protected") + def protected_handler(headers: Annotated[HeaderParams, Header()]): + return {"message": "success"} + + schema = app.get_openapi_schema() + + # Check that the path exists + assert "/protected" in schema.paths + path = schema.paths["/protected"] + assert path.get is not None + + # Check that parameters are expanded + get_operation = path.get + assert get_operation.parameters is not None + assert len(get_operation.parameters) == 3 + + # Check individual parameters + param_names = [param.name for param in get_operation.parameters] + assert "authorization" in param_names + assert "user-agent" in param_names # headers are always spinal-case + assert "accept-language" in param_names # Should use alias + + # Check parameter details + for param in get_operation.parameters: + assert param.in_ == ParameterInType.header + if param.name == "authorization": + assert param.required is True # No default value + assert param.description == "Authorization token" + assert param.schema_.type == "string" + elif param.name == "user_agent": + assert param.required is False # Has default value + assert param.description == "User agent" + assert param.schema_.type == "string" + elif param.name == "accept-language": + assert param.required is False # Optional field + assert param.description == "Language preference" + assert param.schema_.type == "string" + + +def test_openapi_pydantic_mixed_params(): + """Test that mixed Pydantic models (Query + Header) work together""" + app = APIGatewayRestResolver() + + class QueryParams(BaseModel): + q: str = Field(description="Search query") + limit: int = Field(default=10, description="Number of results") + + class HeaderParams(BaseModel): + authorization: str = Field(description="Bearer token") + + @app.get("/mixed") + def mixed_handler(query: Annotated[QueryParams, Query()], headers: Annotated[HeaderParams, Header()]): + return {"message": "success"} + + schema = app.get_openapi_schema() + + # Check that the path exists + assert "/mixed" in schema.paths + path = schema.paths["/mixed"] + assert path.get is not None + + # Check that all parameters are expanded + get_operation = path.get + assert get_operation.parameters is not None + assert len(get_operation.parameters) == 3 # 2 query + 1 header + + # Check parameter types + query_params = [p for p in get_operation.parameters if p.in_ == ParameterInType.query] + header_params = [p for p in get_operation.parameters if p.in_ == ParameterInType.header] + + assert len(query_params) == 2 + assert len(header_params) == 1 + + # Check specific parameters + query_names = [p.name for p in query_params] + assert "q" in query_names + assert "limit" in query_names + + header_names = [p.name for p in header_params] + assert "authorization" in header_names + + def test_openapi_no_params(): app = APIGatewayRestResolver() @@ -776,3 +915,132 @@ def form_edge_cases( assert "required_field" in component_schema.required assert "optional_field" not in component_schema.required # Optional assert "field_with_default" not in component_schema.required # Has default + + +def test_openapi_pydantic_query_with_constraints(): + """Test that Pydantic field constraints are preserved in OpenAPI schema""" + app = APIGatewayRestResolver() + + class QueryParams(BaseModel): + limit: int = Field(ge=1, le=100, description="Number of items") + name: str = Field(min_length=1, max_length=50, description="Name filter") + + @app.get("/items") + def get_items(params: Annotated[QueryParams, Query()]): + return {"message": "success"} + + schema = app.get_openapi_schema() + path = schema.paths["/items"] + get_operation = path.get + + # Find the limit parameter + limit_param = next(p for p in get_operation.parameters if p.name == "limit") + assert limit_param.schema_.type == "integer" + assert limit_param.description == "Number of items" + + # Find the name parameter + name_param = next(p for p in get_operation.parameters if p.name == "name") + assert name_param.schema_.type == "string" + assert name_param.description == "Name filter" + + +def test_openapi_pydantic_header_with_alias(): + """Test that Pydantic field aliases work correctly in Header parameters""" + app = APIGatewayRestResolver() + + class HeaderParams(BaseModel): + content_type: str = Field(alias="content-type", description="Content type") + user_agent: str = Field(alias="user-agent", description="User agent") + + @app.get("/test") + def test_handler(headers: Annotated[HeaderParams, Header()]): + return {"message": "success"} + + schema = app.get_openapi_schema() + path = schema.paths["/test"] + get_operation = path.get + + # Check that aliases are used as parameter names + param_names = [param.name for param in get_operation.parameters] + assert "content-type" in param_names + assert "user-agent" in param_names + assert "content_type" not in param_names # Original field name should not be used + assert "user_agent" not in param_names + + +def test_openapi_pydantic_required_vs_optional(): + """Test that required vs optional fields are correctly identified""" + app = APIGatewayRestResolver() + + class QueryParams(BaseModel): + required_field: str = Field(description="Required field") + optional_with_default: str = Field(default="default", description="Optional with default") + optional_nullable: Optional[str] = Field(default=None, description="Optional nullable") + + @app.get("/test") + def test_handler(params: Annotated[QueryParams, Query()]): + return {"message": "success"} + + schema = app.get_openapi_schema() + path = schema.paths["/test"] + get_operation = path.get + + for param in get_operation.parameters: + if param.name == "required_field": + assert param.required is True + elif param.name == "optional_with_default": + assert param.required is False + elif param.name == "optional_nullable": + assert param.required is False + + +def test_openapi_pydantic_backward_compatibility(): + """Test that existing Body parameter behavior is unchanged""" + app = APIGatewayRestResolver() + + class BodyModel(BaseModel): + name: str = Field(description="Name") + age: int = Field(description="Age") + + @app.post("/users") + def create_user(user: BodyModel): # No annotation - should work as Body + return {"message": "success"} + + schema = app.get_openapi_schema() + path = schema.paths["/users"] + post_operation = path.post + + # Should have no parameters (body is handled separately) + assert post_operation.parameters is None or len(post_operation.parameters) == 0 + + # Should have request body + assert post_operation.requestBody is not None + assert "application/json" in post_operation.requestBody.content + + +def test_openapi_pydantic_complex_types(): + """Test that complex types are handled correctly""" + app = APIGatewayRestResolver() + + class QueryParams(BaseModel): + string_field: str = Field(description="String field") + int_field: int = Field(description="Integer field") + float_field: float = Field(description="Float field") + bool_field: bool = Field(description="Boolean field") + + @app.get("/complex") + def complex_handler(params: Annotated[QueryParams, Query()]): + return {"message": "success"} + + schema = app.get_openapi_schema() + path = schema.paths["/complex"] + get_operation = path.get + + type_mapping = {} + for param in get_operation.parameters: + type_mapping[param.name] = param.schema_.type + + assert type_mapping["string_field"] == "string" + assert type_mapping["int_field"] == "integer" + assert type_mapping["float_field"] == "number" + assert type_mapping["bool_field"] == "boolean" diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 1fd919b7b71..acadae3101a 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -3,10 +3,10 @@ from dataclasses import dataclass from enum import Enum from pathlib import PurePath -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import pytest -from pydantic import BaseModel +from pydantic import BaseModel, Field from typing_extensions import Annotated from aws_lambda_powertools.event_handler import ( @@ -47,6 +47,407 @@ def handler(user_id: int): assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) +def test_validate_pydantic_query_params(gw_event): + """Test that Pydantic models in Query parameters are validated correctly""" + + app = APIGatewayRestResolver(enable_validation=True) + del gw_event["multiValueQueryStringParameters"] + + class QueryParams(BaseModel): + limit: int = Field(default=10, ge=1, le=100, description="Number of items") + search: Optional[str] = Field(default=None, description="Search term") + + @app.get("/search") + def search_handler(params: Annotated[QueryParams, Query()]): + return { + "limit": params.limit, + "search": params.search, + } + + # Test valid request + gw_event["path"] = "/search" + gw_event["queryStringParameters"] = {"limit": "25", "search": "python"} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["limit"] == 25 + assert body["search"] == "python" + + # Test with default values + gw_event["queryStringParameters"] = {} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["limit"] == 10 # Default value + assert body["search"] is None # Default value + + # Test validation error (limit too high) + gw_event["queryStringParameters"] = {"limit": "150"} + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + assert any("limit" in str(error) for error in body["detail"]) + + +def test_validate_multi_value_query_params(gw_event): + """Test that multi-value query parameters are validated correctly""" + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/users") + def users_handler(ids: Annotated[List[int], Query()]): + return {"ids": ids} + + # Test valid request + gw_event["path"] = "/users" + gw_event["multiValueQueryStringParameters"] = {"ids": ["1", "2", "3"]} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["ids"] == [1, 2, 3] + + # Test with invalid value + gw_event["multiValueQueryStringParameters"] = {"ids": ["1", "abc", "3"]} + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + assert any("ids" in str(error) for error in body["detail"]) + + +def test_validate_pydantic_multi_value_query_params(gw_event): + """Test that Pydantic models in Multi-Value Query parameters are validated correctly""" + app = APIGatewayRestResolver(enable_validation=True) + + class QueryParams(BaseModel): + ids: List[int] = Field(..., description="List of user IDs") + + @app.get("/users") + def users_handler(params: Annotated[QueryParams, Query()]): + return {"ids": params.ids} + + # Test valid request + gw_event["path"] = "/users" + gw_event["multiValueQueryStringParameters"] = {"ids": ["1", "2", "3"]} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["ids"] == [1, 2, 3] + + # Test with invalid value + gw_event["multiValueQueryStringParameters"] = {"ids": ["1", "abc", "3"]} + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + assert any("ids" in str(error) for error in body["detail"]) + + +def test_validate_pydantic_query_params_detailed_errors(gw_event): + """Test that Pydantic validation errors include detailed field-level information""" + app = APIGatewayRestResolver(enable_validation=True) + + del gw_event["multiValueHeaders"] + del gw_event["multiValueQueryStringParameters"] + + class QueryParams(BaseModel): + full_name: str = Field(..., min_length=5, description="Full name with minimum 5 characters") + age: int = Field(..., ge=18, le=100, description="Age between 18 and 100") + + @app.get("/query-model") + def query_model(params: Annotated[QueryParams, Query()]): + return {"full_name": params.full_name, "age": params.age} + + # Test validation error with detailed field information + gw_event["path"] = "/query-model" + gw_event["queryStringParameters"] = {"full_name": "Jo", "age": "15"} # Both invalid + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + + # Check that we get detailed field-level errors + errors = body["detail"] + + # Should have errors for both fields + full_name_error = next((e for e in errors if "full_name" in e["loc"]), None) + age_error = next((e for e in errors if "age" in e["loc"]), None) + + assert full_name_error is not None, "Should have error for full_name field" + assert age_error is not None, "Should have error for age field" + + # Check error details for full_name + assert full_name_error["loc"] == ["query", "params", "full_name"] + assert full_name_error["type"] == "string_too_short" + + # Check error details for age + assert age_error["loc"] == ["query", "params", "age"] + assert age_error["type"] == "greater_than_equal" + + +def test_validate_pydantic_header_params(gw_event): + """Test that Pydantic models in Header parameters are validated correctly""" + app = APIGatewayRestResolver(enable_validation=True) + + del gw_event["multiValueHeaders"] + + class HeaderParams(BaseModel): + authorization: str = Field(description="Authorization token") + user_agent: str = Field(default="PowerTools/1.0", description="User agent") + + @app.get("/protected") + def protected_handler(my_headers: Annotated[HeaderParams, Header()]): + return { + "authorization": my_headers.authorization, + "user_agent": my_headers.user_agent, + } + + # Test valid request + gw_event["path"] = "/protected" + gw_event["headers"] = {"authorization": "Bearer token123", "user-agent": "TestClient/1.0"} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["authorization"] == "Bearer token123" + assert body["user_agent"] == "TestClient/1.0" + + # Test with default value + gw_event["headers"] = {"authorization": "Bearer token123"} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["authorization"] == "Bearer token123" + assert body["user_agent"] == "PowerTools/1.0" # Default value + + # Test missing required header + gw_event["headers"] = {} + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + assert any("my-headers" in str(error) for error in body["detail"]) + + +def test_validate_multi_value_header_params(gw_event): + """Test that multi-value headers are validated correctly without Pydantic""" + app = APIGatewayRestResolver(enable_validation=True) + + del gw_event["multiValueHeaders"] + + @app.get("/multi-value-headers") + def multi_value_handler(my_headers: Annotated[List[str], Header()]): + return {"items": my_headers} + + # Test valid request + gw_event["path"] = "/multi-value-headers" + gw_event["multiValueHeaders"] = {"my-headers": ["item1", "item2"]} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["items"] == ["item1", "item2"] + + # Test invalid request + gw_event["multiValueHeaders"] = {"items": "invalid"} + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + assert any("my-headers" in str(error) for error in body["detail"]) + + +def test_validate_pydantic_multi_value_header_params(gw_event): + """Test that multi-value headers are validated correctly""" + app = APIGatewayRestResolver(enable_validation=True) + + del gw_event["multiValueHeaders"] + + class MultiValueHeaderParams(BaseModel): + list_items: List[str] = Field(description="List of items") + + @app.get("/multi-value-headers") + def multi_value_handler(my_headers: Annotated[MultiValueHeaderParams, Header()]): + return {"items": my_headers.list_items} + + # Test valid request + gw_event["path"] = "/multi-value-headers" + gw_event["multiValueHeaders"] = {"list-items": ["item1", "item2"]} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["items"] == ["item1", "item2"] + + # Test invalid request + gw_event["multiValueHeaders"] = {"items": "invalid"} + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + assert any("my-headers" in str(error) for error in body["detail"]) + + +def test_validate_pydantic_header_snake_case_to_kebab_case_schema(gw_event): + """Test that snake_case header fields are converted to kebab-case in OpenAPI schema and validation""" + app = APIGatewayRestResolver(enable_validation=True) + app.enable_swagger() + + class HeaderParams(BaseModel): + correlation_id: str = Field(description="Correlation ID header") + user_agent: str = Field(default="PowerTools/1.0", description="User agent header") + accept: str = Field(default="application/json") # omit description to test optional description + + @app.get("/kebab-headers") + def kebab_handler(my_headers: Annotated[HeaderParams, Header()]): + return { + "correlation_id": my_headers.correlation_id, + "user_agent": my_headers.user_agent, + } + + # Test that OpenAPI schema uses kebab-case for headers + openapi_schema = app.get_openapi_schema() + operation = openapi_schema.paths["/kebab-headers"].get + parameters = operation.parameters + + # Find the correlation_id parameter + correlation_param = next((p for p in parameters if p.name == "correlation-id"), None) + assert correlation_param is not None, "Should have correlation-id parameter in kebab-case" + + # Find the user_agent parameter + user_agent_param = next((p for p in parameters if p.name == "user-agent"), None) + assert user_agent_param is not None, "Should have user-agent parameter in kebab-case" + + # Test validation with kebab-case headers + gw_event["path"] = "/kebab-headers" + gw_event["multiValueHeaders"] = {"correlation-id": "test-123", "user-agent": "TestClient/1.0"} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["correlation_id"] == "test-123" + assert body["user_agent"] == "TestClient/1.0" + + +def test_validate_pydantic_mixed_params(gw_event): + """Test that mixed Pydantic models (Query + Header) are validated correctly""" + app = APIGatewayRestResolver(enable_validation=True) + + del gw_event["multiValueHeaders"] + del gw_event["multiValueQueryStringParameters"] + + class QueryParams(BaseModel): + q: str = Field(description="Search query") + limit: int = Field(default=10, description="Number of results") + + class HeaderParams(BaseModel): + authorization: str = Field(description="Bearer token") + + @app.get("/mixed") + def mixed_handler(query: Annotated[QueryParams, Query()], headers: Annotated[HeaderParams, Header()]): + return { + "query": {"q": query.q, "limit": query.limit}, + "headers": {"authorization": headers.authorization}, + } + + # Test valid request + gw_event["path"] = "/mixed" + gw_event["queryStringParameters"] = {"q": "python", "limit": "25"} + gw_event["headers"] = {"authorization": "Bearer token123"} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["query"]["q"] == "python" + assert body["query"]["limit"] == 25 + assert body["headers"]["authorization"] == "Bearer token123" + + # Test missing required query parameter + gw_event["queryStringParameters"] = {"limit": "25"} # Missing 'q' + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + assert any("q" in str(error) for error in body["detail"]) + + # Test missing required header + gw_event["queryStringParameters"] = {"q": "python"} + gw_event["headers"] = {} # Missing 'authorization' + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + assert any("headers" in str(error) for error in body["detail"]) + + +def test_validate_pydantic_with_alias(gw_event): + """Test that Pydantic models with field aliases work correctly""" + app = APIGatewayRestResolver(enable_validation=True) + + del gw_event["multiValueHeaders"] + del gw_event["multiValueQueryStringParameters"] + + class HeaderParams(BaseModel): + accept_language: str = Field(alias="accept-language", description="Language preference") + + @app.get("/alias") + def alias_handler(headers: Annotated[HeaderParams, Header()]): + return {"accept_language": headers.accept_language} + + # Test with alias in request + gw_event["path"] = "/alias" + gw_event["headers"] = {"accept-language": "en-US"} + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["accept_language"] == "en-US" + + # Test missing aliased field + gw_event["headers"] = {} + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + assert any("headers" in str(error) for error in body["detail"]) + + def test_validate_scalars_with_default(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -724,8 +1125,8 @@ def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, He @app.get("/users") def handler3( - header2: Annotated[List[str], Header(name="Header2")], - header1: Annotated[str, Header(name="Header1")], + header2: Annotated[List[str], Header(alias="Header2")], + header1: Annotated[str, Header(alias="Header1")], ): print(header2) @@ -790,8 +1191,8 @@ def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, He @app.get("/users") def handler3( - header2: Annotated[List[str], Header(name="Header2")], - header1: Annotated[str, Header(name="Header1")], + header2: Annotated[List[str], Header(alias="Header2")], + header1: Annotated[str, Header(alias="Header1")], ): print(header2) @@ -853,8 +1254,8 @@ def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, He @app.get("/users") def handler3( - header2: Annotated[List[str], Header(name="Header2")], - header1: Annotated[str, Header(name="Header1")], + header2: Annotated[List[str], Header(alias="Header2")], + header1: Annotated[str, Header(alias="Header1")], ): print(header2) @@ -918,8 +1319,8 @@ def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, He @app.get("/users") def handler3( - header2: Annotated[List[str], Header(name="Header2")], - header1: Annotated[str, Header(name="Header1")], + header2: Annotated[List[str], Header(alias="Header2")], + header1: Annotated[str, Header(alias="Header1")], ): print(header2) @@ -982,8 +1383,8 @@ def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, He @app.get("/users") def handler3( - header2: Annotated[List[str], Header(name="Header2")], - header1: Annotated[str, Header(name="Header1")], + header2: Annotated[List[str], Header(alias="Header2")], + header1: Annotated[str, Header(alias="Header1")], ): print(header2) @@ -1046,8 +1447,8 @@ def handler1(header2: Annotated[List[int], Header()], header1: Annotated[str, He @app.get("/users") def handler3( - header2: Annotated[List[str], Header(name="Header2")], - header1: Annotated[str, Header(name="Header1")], + header2: Annotated[List[str], Header(alias="Header2")], + header1: Annotated[str, Header(alias="Header1")], ): print(header2) @@ -1983,3 +2384,141 @@ def get_user(user_id: int) -> UserModel: assert response_body["name"] == "User123" assert response_body["age"] == 143 assert response_body["email"] == "user123@example.com" + + +def test_validate_pydantic_query_params_with_config_dict_and_validators(gw_event): + """Test that Pydantic models with ConfigDict, aliases, and validators work correctly""" + from typing import Any + + from pydantic import AfterValidator, Base64UrlStr, ConfigDict, StringConstraints, alias_generators + + del gw_event["multiValueHeaders"] + del gw_event["multiValueQueryStringParameters"] + + app = APIGatewayRestResolver(enable_validation=True) + app.enable_swagger(path="/swagger") + + def _validate_powertools(value: str) -> str: + if not value.startswith("Powertools"): + raise ValueError("Full name must start with 'Powertools'") + return value + + class QuerySimple(BaseModel): + full_name: Annotated[str, StringConstraints(min_length=5), AfterValidator(_validate_powertools)] + next_token: Base64UrlStr + search_id: str + + @app.get("/query-model-simple") + def query_model(params: Annotated[QuerySimple, Query()]) -> Dict[str, Any]: + return { + "fullName": params.full_name, + "nextToken": params.next_token, + "searchId": params.search_id, + } + + class QueryAdvanced(BaseModel): + full_name: Annotated[str, StringConstraints(min_length=5)] + next_token: str + search_id: Annotated[str, Field(alias="id")] # Using str instead of UUID4 for simpler testing + + model_config = ConfigDict( + alias_generator=alias_generators.to_camel, + validate_by_alias=True, + validate_by_name=True, + serialize_by_alias=True, + ) + + @app.get("/query-model-advanced") + def query_model_advanced(params: Annotated[QueryAdvanced, Query()]) -> Dict[str, Any]: + return params.model_dump() + + # Test QuerySimple with validators + gw_event["path"] = "/query-model-simple" + gw_event["queryStringParameters"] = { + "full_name": "Powertools Lambda", + "next_token": "dGVzdA==", # base64url encoded "test" + "search_id": "search-123", + } + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["fullName"] == "Powertools Lambda" + assert body["nextToken"] == "test" + assert body["searchId"] == "search-123" + + # Test QuerySimple validation error (name doesn't start with "Powertools") + gw_event["queryStringParameters"] = { + "full_name": "Lambda Powertools", + "next_token": "dGVzdA==", + "search_id": "search-123", + } + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + errors = body["detail"] + + # Should have validation error for full_name with proper location + full_name_error = next((e for e in errors if "full_name" in e["loc"]), None) + + assert full_name_error is not None, "Should have error for full_name field" + + # Check error details for full_name + assert full_name_error["loc"] == ["query", "params", "full_name"] + assert full_name_error["type"] == "value_error" + + # Test QueryAdvanced with ConfigDict and alias_generator + gw_event["path"] = "/query-model-advanced" + gw_event["queryStringParameters"] = { + "fullName": "Advanced Test", # camelCase from alias_generator + "nextToken": "dGVzdA==", # camelCase from alias_generator + "id": "search-456", # explicit alias + } + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + # Should return with camelCase keys due to serialize_by_alias=True + assert body["fullName"] == "Advanced Test" + assert body["nextToken"] == "dGVzdA==" + assert body["id"] == "search-456" + + # Test QueryAdvanced with snake_case field names due to validate_by_name=True + gw_event["queryStringParameters"] = { + "full_name": "Snake Case Test", # snake_case field name + "next_token": "dGVzdA==", # snake_case field name + "search_id": "search-789", # snake_case field name + } + + gw_event["path"] = "/query-model-advanced" + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + body = json.loads(result["body"]) + assert body["fullName"] == "Snake Case Test" + assert body["nextToken"] == "dGVzdA==" + assert body["id"] == "search-789" + + # Test QueryAdvanced validation error (full_name too short) + gw_event["queryStringParameters"] = { + "fullName": "Bad", # Too short (min_length=5) + "nextToken": "dGVzdA==", + "id": "search-456", + } + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + body = json.loads(result["body"]) + assert "detail" in body + errors = body["detail"] + + # Should have validation error for full_name with proper location + full_name_error = next((e for e in errors if "full_name" in e["loc"] or "fullName" in e["loc"]), None) + assert full_name_error is not None + assert full_name_error["type"] == "string_too_short"