Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 109 additions & 26 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
Server,
Tag,
)
from aws_lambda_powertools.event_handler.openapi.params import Dependant
from aws_lambda_powertools.event_handler.openapi.params import Dependant, Param
from aws_lambda_powertools.event_handler.openapi.swagger_ui.oauth2 import (
OAuth2Config,
)
Expand Down Expand Up @@ -812,46 +812,129 @@ def _openapi_operation_parameters(
"""
Returns the OpenAPI operation parameters.
"""
from aws_lambda_powertools.event_handler.openapi.compat import (
get_schema_from_model_field,
)
from aws_lambda_powertools.event_handler.openapi.params import Param

parameters = []
parameter: dict[str, Any] = {}
parameters: list[dict[str, Any]] = []

for param in all_route_params:
field_info = param.field_info
field_info = cast(Param, field_info)
field_info = cast(Param, param.field_info)
if not field_info.include_in_schema:
continue

param_schema = get_schema_from_model_field(
field=param,
model_name_map=model_name_map,
field_mapping=field_mapping,
)
# Check if this is a Pydantic model that should be expanded
if Route._is_pydantic_model_param(field_info):
parameters.extend(Route._expand_pydantic_model_parameters(field_info))
else:
parameters.append(Route._create_regular_parameter(param, model_name_map, field_mapping))

parameter = {
"name": param.alias,
"in": field_info.in_.value,
"required": param.required,
"schema": param_schema,
}
return parameters

if field_info.description:
parameter["description"] = field_info.description
@staticmethod
def _is_pydantic_model_param(field_info: ModelField | Param) -> bool:
"""Check if the field info represents a Pydantic model parameter."""
from pydantic import BaseModel

if field_info.openapi_examples:
parameter["examples"] = field_info.openapi_examples
from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass
from aws_lambda_powertools.event_handler.openapi.params import Param

if field_info.deprecated:
parameter["deprecated"] = field_info.deprecated
if not isinstance(field_info, Param):
return False
return lenient_issubclass(field_info.annotation, BaseModel)

@staticmethod
def _expand_pydantic_model_parameters(field_info: Param) -> list[dict[str, Any]]:
"""Expand a Pydantic model into individual OpenAPI parameters."""
from pydantic import BaseModel

parameters.append(parameter)
model_class = cast(type[BaseModel], field_info.annotation)
parameters: list[dict[str, Any]] = []

for field_name, field_def in model_class.model_fields.items():
if not field_def.annotation:
continue

param_name = field_def.alias or field_name
individual_param = Route._create_pydantic_field_parameter(
param_name=param_name,
field_def=field_def,
param_location=field_info.in_.value,
)
parameters.append(individual_param)

return parameters

@staticmethod
def _create_pydantic_field_parameter(
param_name: str,
field_def: Any,
param_location: str,
) -> dict[str, Any]:
"""Create an OpenAPI parameter from a Pydantic field definition."""
individual_param: dict[str, Any] = {
"name": param_name,
"in": param_location,
"required": field_def.is_required() if hasattr(field_def, "is_required") else field_def.default is ...,
"schema": Route._get_basic_type_schema(field_def.annotation or type(None)),
}

if field_def.description:
individual_param["description"] = field_def.description

return individual_param

@staticmethod
def _create_regular_parameter(
param: ModelField,
model_name_map: dict[TypeModelOrEnum, str],
field_mapping: dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
) -> dict[str, Any]:
"""Create an OpenAPI parameter from a regular ModelField."""
from aws_lambda_powertools.event_handler.openapi.compat import get_schema_from_model_field
from aws_lambda_powertools.event_handler.openapi.params import Param

field_info = cast(Param, param.field_info)
param_schema = get_schema_from_model_field(
field=param,
model_name_map=model_name_map,
field_mapping=field_mapping,
)

parameter: dict[str, Any] = {
"name": param.alias,
"in": field_info.in_.value,
"required": param.required,
"schema": param_schema,
}

# Add optional attributes if present
if field_info.description:
parameter["description"] = field_info.description
if field_info.openapi_examples:
parameter["examples"] = field_info.openapi_examples
if field_info.deprecated:
parameter["deprecated"] = field_info.deprecated

return parameter

@staticmethod
def _get_basic_type_schema(param_type: type) -> dict[str, str]:
"""
Get basic OpenAPI schema for simple types
"""
try:
# Check bool before int, since bool is a subclass of int in Python
if issubclass(param_type, bool):
return {"type": "boolean"}
elif issubclass(param_type, int):
return {"type": "integer"}
elif issubclass(param_type, float):
return {"type": "number"}
else:
return {"type": "string"}
except TypeError:
# param_type may not be a type (e.g., typing.Optional[int]), fallback to string
return {"type": "string"}

@staticmethod
def _openapi_operation_return(
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import logging
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence, cast, get_origin
from urllib.parse import parse_qs

from pydantic import BaseModel
Expand All @@ -15,6 +15,7 @@
_normalize_errors,
_regenerate_error_with_loc,
get_missing_field_error,
lenient_issubclass,
)
from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
Expand Down Expand Up @@ -64,7 +65,7 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
)

# Normalize query values before validate this
query_string = _normalize_multi_query_string_with_param(
query_string = _normalize_multi_params(
app.current_event.resolved_query_string_parameters,
route.dependant.query_params,
)
Expand All @@ -76,7 +77,7 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
)

# Normalize header values before validate this
headers = _normalize_multi_header_values_with_param(
headers = _normalize_multi_params(
app.current_event.resolved_headers_field,
route.dependant.header_params,
)
Expand Down Expand Up @@ -434,57 +435,80 @@ def _get_embed_body(
return received_body, field_alias_omitted


def _normalize_multi_query_string_with_param(
query_string: dict[str, list[str]],
def _normalize_multi_params(
input_dict: MutableMapping[str, Any],
params: Sequence[ModelField],
) -> dict[str, Any]:
) -> MutableMapping[str, Any]:
"""
Extract and normalize resolved_query_string_parameters
Extract and normalize query string or header parameters with Pydantic model support.

Parameters
----------
query_string: dict
A dictionary containing the initial query string parameters.
input_dict: MutableMapping[str, Any]
A dictionary containing the initial query string or header parameters.
params: Sequence[ModelField]
A sequence of ModelField objects representing parameters.

Returns
-------
A dictionary containing the processed multi_query_string_parameters.
MutableMapping[str, Any]
A dictionary containing the processed parameters with normalized values.
"""
resolved_query_string: dict[str, Any] = query_string
for param in filter(is_scalar_field, params):
try:
# if the target parameter is a scalar, we keep the first value of the query string
# regardless if there are more in the payload
resolved_query_string[param.alias] = query_string[param.alias][0]
except KeyError:
pass
return resolved_query_string
for param in params:
if is_scalar_field(param):
_process_scalar_param(input_dict, param)
elif lenient_issubclass(param.field_info.annotation, BaseModel):
_process_model_param(input_dict, param)
return input_dict


def _process_scalar_param(input_dict: MutableMapping[str, Any], param: ModelField) -> None:
"""Process a scalar parameter by normalizing single-item lists."""
try:
val = input_dict[param.alias]
if isinstance(val, list) and len(val) == 1:
input_dict[param.alias] = val[0]
except KeyError:
pass


def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any], params: Sequence[ModelField]):
"""
Extract and normalize resolved_headers_field
def _process_model_param(input_dict: MutableMapping[str, Any], param: ModelField) -> None:
"""Process a Pydantic model parameter by extracting model fields."""
model_class = cast(type[BaseModel], param.field_info.annotation)

Parameters
----------
headers: MutableMapping[str, Any]
A dictionary containing the initial header parameters.
params: Sequence[ModelField]
A sequence of ModelField objects representing parameters.
model_data = {}
for field_name, field_def in model_class.model_fields.items():
field_alias = field_def.alias or field_name
value = _get_param_value(input_dict, field_alias, field_name, model_class)

Returns
-------
A dictionary containing the processed headers.
"""
if headers:
for param in filter(is_scalar_field, params):
try:
if len(headers[param.alias]) == 1:
# if the target parameter is a scalar and the list contains only 1 element
# we keep the first value of the headers regardless if there are more in the payload
headers[param.alias] = headers[param.alias][0]
except KeyError:
pass
return headers
if value is not None:
model_data[field_alias] = _normalize_field_value(value, field_def)

input_dict[param.alias] = model_data


def _get_param_value(
input_dict: MutableMapping[str, Any],
field_alias: str,
field_name: str,
model_class: type[BaseModel],
) -> Any:
"""Get parameter value, checking both alias and field name if needed."""
value = input_dict.get(field_alias)
if value is not None:
return value

if model_class.model_config.get("validate_by_name") or model_class.model_config.get("populate_by_name"):
value = input_dict.get(field_name)

return value


def _normalize_field_value(value: Any, field_def: Any) -> Any:
"""Normalize field value based on its type annotation."""
if get_origin(field_def.annotation) is list:
return value
elif isinstance(value, list) and value:
return value[0]
else:
return value
3 changes: 1 addition & 2 deletions aws_lambda_powertools/event_handler/openapi/dependant.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
create_body_model,
evaluate_forwardref,
is_scalar_field,
is_scalar_sequence_field,
)
from aws_lambda_powertools.event_handler.openapi.params import (
Body,
Expand Down Expand Up @@ -275,7 +274,7 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
return False
elif is_scalar_field(field=param_field):
return False
elif isinstance(param_field.field_info, (Query, Header)) and is_scalar_sequence_field(param_field):
elif isinstance(param_field.field_info, (Query, Header)):
return False
else:
if not isinstance(param_field.field_info, Body):
Expand Down
Loading