Skip to content

Commit d3b46be

Browse files
committed
fix typing
1 parent 2d523bd commit d3b46be

File tree

2 files changed

+61
-34
lines changed

2 files changed

+61
-34
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -812,10 +812,8 @@ def _openapi_operation_parameters(
812812
"""
813813
Returns the OpenAPI operation parameters.
814814
"""
815-
from aws_lambda_powertools.event_handler.openapi.compat import (
816-
get_schema_from_model_field,
817-
)
818-
from aws_lambda_powertools.event_handler.openapi.params import Form, Header, Param, Query
815+
from aws_lambda_powertools.event_handler.openapi.compat import get_schema_from_model_field
816+
from aws_lambda_powertools.event_handler.openapi.params import Param
819817

820818
parameters = []
821819
parameter: dict[str, Any] = {}
@@ -831,11 +829,13 @@ def _openapi_operation_parameters(
831829

832830
from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass
833831

834-
if isinstance(field_info, (Query, Header, Form)) and lenient_issubclass(field_info.annotation, BaseModel):
832+
if lenient_issubclass(field_info.annotation, BaseModel):
835833
# Expand Pydantic model into individual parameters
836-
model_class = field_info.annotation
834+
model_class = cast(type[BaseModel], field_info.annotation)
837835

838836
for field_name, field_def in model_class.model_fields.items():
837+
if not field_def.annotation:
838+
continue
839839
# Create individual parameter for each model field
840840
param_name = field_def.alias or field_name
841841

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

Lines changed: 55 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import json
55
import logging
66
from copy import deepcopy
7-
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence, get_origin
7+
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence, cast, get_origin
88
from urllib.parse import parse_qs
99

1010
from pydantic import BaseModel
@@ -456,32 +456,59 @@ def _normalize_multi_params(
456456
"""
457457
for param in params:
458458
if is_scalar_field(param):
459-
try:
460-
val = input_dict[param.alias]
461-
if isinstance(val, list) and len(val) == 1:
462-
input_dict[param.alias] = val[0]
463-
elif isinstance(val, list):
464-
pass # leave as list for multi-value
465-
# If it's a string, leave as is
466-
except KeyError:
467-
pass
459+
_process_scalar_param(input_dict, param)
468460
elif lenient_issubclass(param.field_info.annotation, BaseModel):
469-
model_class = param.field_info.annotation
470-
model_data = {}
471-
472-
for field_name, field_def in model_class.model_fields.items():
473-
field_alias = field_def.alias or field_name
474-
value = input_dict.get(field_alias)
475-
if value is None and (
476-
model_class.model_config.get("validate_by_name") or model_class.model_config.get("populate_by_name")
477-
):
478-
value = input_dict.get(field_name)
479-
if value is not None:
480-
if get_origin(field_def.annotation) is list:
481-
model_data[field_alias] = value
482-
elif isinstance(value, list):
483-
model_data[field_alias] = value[0]
484-
else:
485-
model_data[field_alias] = value
486-
input_dict[param.alias] = model_data
461+
_process_model_param(input_dict, param)
487462
return input_dict
463+
464+
465+
def _process_scalar_param(input_dict: MutableMapping[str, Any], param: ModelField) -> None:
466+
"""Process a scalar parameter by normalizing single-item lists."""
467+
try:
468+
val = input_dict[param.alias]
469+
if isinstance(val, list) and len(val) == 1:
470+
input_dict[param.alias] = val[0]
471+
except KeyError:
472+
pass
473+
474+
475+
def _process_model_param(input_dict: MutableMapping[str, Any], param: ModelField) -> None:
476+
"""Process a Pydantic model parameter by extracting model fields."""
477+
model_class = cast(type[BaseModel], param.field_info.annotation)
478+
479+
model_data = {}
480+
for field_name, field_def in model_class.model_fields.items():
481+
field_alias = field_def.alias or field_name
482+
value = _get_param_value(input_dict, field_alias, field_name, model_class)
483+
484+
if value is not None:
485+
model_data[field_alias] = _normalize_field_value(value, field_def)
486+
487+
input_dict[param.alias] = model_data
488+
489+
490+
def _get_param_value(
491+
input_dict: MutableMapping[str, Any],
492+
field_alias: str,
493+
field_name: str,
494+
model_class: type[BaseModel],
495+
) -> Any:
496+
"""Get parameter value, checking both alias and field name if needed."""
497+
value = input_dict.get(field_alias)
498+
if value is not None:
499+
return value
500+
501+
if model_class.model_config.get("validate_by_name") or model_class.model_config.get("populate_by_name"):
502+
value = input_dict.get(field_name)
503+
504+
return value
505+
506+
507+
def _normalize_field_value(value: Any, field_def: Any) -> Any:
508+
"""Normalize field value based on its type annotation."""
509+
if get_origin(field_def.annotation) is list:
510+
return value
511+
elif isinstance(value, list) and value:
512+
return value[0]
513+
else:
514+
return value

0 commit comments

Comments
 (0)