Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f6ca987
feat(event-handler): add support for Pydantic Field discriminator in …
dap0am Aug 21, 2025
63a0225
style(tests): remove inline comments to match project test style
dap0am Aug 21, 2025
d35753c
Merge branch 'develop' into develop
leandrodamascena Aug 21, 2025
0023b3a
style: run make format to fix CI formatting issues
dap0am Aug 21, 2025
62bb9c9
Merge branch 'develop' into develop
leandrodamascena Aug 29, 2025
ead3ee8
fix(event-handler): preserve FieldInfo subclass types in copy_field_info
dap0am Sep 3, 2025
0c3aad6
Merge branch 'develop' into develop
dreamorosi Sep 4, 2025
e2c8b49
refactor(event-handler): reduce cognitive complexity and address Sona…
dap0am Sep 4, 2025
d61e531
Merge branch 'develop' into develop
leandrodamascena Sep 4, 2025
95a5eba
style: fix formatting to pass CI format check
dap0am Sep 5, 2025
913c580
Merge branch 'develop' into develop
leandrodamascena Sep 8, 2025
d839919
fix: resolve mypy type error in _create_field_info function
dap0am Sep 8, 2025
3a53a67
Merge branch 'develop' into develop
dreamorosi Sep 8, 2025
5762c94
fix: use Union syntax for Python 3.9 compatibility
dap0am Sep 8, 2025
07ade23
Merge branch 'develop' into develop
leandrodamascena Sep 8, 2025
89fb5f8
Merge branch 'develop' into develop
leandrodamascena Sep 9, 2025
793a097
feat(event-handler): add documentation and example for Field discrimi…
dap0am Sep 9, 2025
9eac24a
Merge branch 'develop' into develop
leandrodamascena Sep 9, 2025
4400181
style: run make format to fix CI formatting issues
dap0am Sep 9, 2025
5cc265e
Merge branch 'develop' into develop
leandrodamascena Sep 10, 2025
587d2fa
small changes
leandrodamascena Sep 11, 2025
96ab0a6
Merge branch 'develop' into develop
leandrodamascena Sep 11, 2025
851992f
small changes
leandrodamascena Sep 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions aws_lambda_powertools/event_handler/openapi/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,19 @@ def type_(self) -> Any:
return self.field_info.annotation

def __post_init__(self) -> None:
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
Annotated[self.field_info.annotation, self.field_info],
)
# If the field_info.annotation is already an Annotated type with discriminator metadata,
# use it directly instead of wrapping it again
annotation = self.field_info.annotation
if (
get_origin(annotation) is Annotated
and hasattr(self.field_info, "discriminator")
and self.field_info.discriminator is not None
):
self._type_adapter: TypeAdapter[Any] = TypeAdapter(annotation)
else:
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
Annotated[annotation, self.field_info],
)

def get_default(self) -> Any:
if self.field_info.is_required():
Expand Down Expand Up @@ -176,7 +186,13 @@ def model_rebuild(model: type[BaseModel]) -> None:


def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
return type(field_info).from_annotation(annotation)
# Create a shallow copy of the field_info to preserve its type and all attributes
from copy import copy

new_field = copy(field_info)
# Update only the annotation to the new one
new_field.annotation = annotation
return new_field


def get_missing_field_error(loc: tuple[str, ...]) -> dict[str, Any]:
Expand Down
100 changes: 83 additions & 17 deletions aws_lambda_powertools/event_handler/openapi/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,35 +1037,101 @@ def get_field_info_response_type(annotation, value) -> tuple[FieldInfo | None, A
return get_field_info_and_type_annotation(inner_type, value, False, True)


def _has_discriminator(field_info: FieldInfo) -> bool:
"""Check if a FieldInfo has a discriminator."""
return hasattr(field_info, "discriminator") and field_info.discriminator is not None


def _handle_discriminator_with_body(
annotations: list[FieldInfo],
annotation: Any,
) -> tuple[FieldInfo | None, Any, bool]:
"""
Handle the special case of Field(discriminator) + Body() combination.

Returns:
tuple of (powertools_annotation, type_annotation, has_discriminator_with_body)
"""
field_obj = None
body_obj = None

for ann in annotations:
if isinstance(ann, Body):
body_obj = ann
elif _has_discriminator(ann):
field_obj = ann

if field_obj and body_obj:
# Use Body as the primary annotation, preserve full annotation for validation
return body_obj, annotation, True

raise AssertionError("Only one FieldInfo can be used per parameter")


def _create_field_info(
powertools_annotation: FieldInfo,
type_annotation: Any,
has_discriminator_with_body: bool,
) -> FieldInfo:
"""Create or copy FieldInfo based on the annotation type."""
field_info: FieldInfo
if has_discriminator_with_body:
# For discriminator + Body case, create a new Body instance directly
field_info = Body()
field_info.annotation = type_annotation
else:
# Copy field_info because we mutate field_info.default later
field_info = copy_field_info(
field_info=powertools_annotation,
annotation=type_annotation,
)
return field_info


def _set_field_default(field_info: FieldInfo, value: Any, is_path_param: bool) -> None:
"""Set the default value for a field."""
if field_info.default not in [Undefined, Required]:
raise AssertionError("FieldInfo needs to have a default value of Undefined or Required")

if value is not inspect.Signature.empty:
if is_path_param:
raise AssertionError("Cannot use a FieldInfo as a path parameter and pass a value")
field_info.default = value
else:
field_info.default = Required


def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tuple[FieldInfo | None, Any]:
"""
Get the FieldInfo and type annotation from an Annotated type.
"""
field_info: FieldInfo | None = None
annotated_args = get_args(annotation)
type_annotation = annotated_args[0]
powertools_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, FieldInfo)]

if len(powertools_annotations) > 1:
raise AssertionError("Only one FieldInfo can be used per parameter")
# Determine which annotation to use
powertools_annotation: FieldInfo | None = None
has_discriminator_with_body = False

powertools_annotation = next(iter(powertools_annotations), None)
if len(powertools_annotations) == 2:
powertools_annotation, type_annotation, has_discriminator_with_body = _handle_discriminator_with_body(
powertools_annotations,
annotation,
)
elif len(powertools_annotations) > 1:
raise AssertionError("Only one FieldInfo can be used per parameter")
else:
powertools_annotation = next(iter(powertools_annotations), None)

# Process the annotation if it exists
field_info: FieldInfo | None = None
if isinstance(powertools_annotation, FieldInfo):
# Copy `field_info` because we mutate `field_info.default` later
field_info = copy_field_info(
field_info=powertools_annotation,
annotation=annotation,
)
if field_info.default not in [Undefined, Required]:
raise AssertionError("FieldInfo needs to have a default value of Undefined or Required")
field_info = _create_field_info(powertools_annotation, type_annotation, has_discriminator_with_body)
_set_field_default(field_info, value, is_path_param)

if value is not inspect.Signature.empty:
if is_path_param:
raise AssertionError("Cannot use a FieldInfo as a path parameter and pass a value")
field_info.default = value
else:
field_info.default = Required
# Preserve full annotated type for discriminated unions
if _has_discriminator(powertools_annotation):
type_annotation = annotation

return field_info, type_annotation

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from dataclasses import dataclass
from enum import Enum
from pathlib import PurePath
from typing import List, Optional, Tuple
from typing import List, Literal, Optional, Tuple, Union

import pytest
from pydantic import BaseModel
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from aws_lambda_powertools.event_handler import (
Expand Down Expand Up @@ -1983,3 +1983,82 @@ def get_user(user_id: int) -> UserModel:
assert response_body["name"] == "User123"
assert response_body["age"] == 143
assert response_body["email"] == "[email protected]"


def test_field_discriminator_validation(gw_event):
"""Test that Pydantic Field discriminator works with event_handler validation"""
app = APIGatewayRestResolver(enable_validation=True)

class FooAction(BaseModel):
action: Literal["foo"]
foo_data: str

class BarAction(BaseModel):
action: Literal["bar"]
bar_data: int

action_type = Annotated[Union[FooAction, BarAction], Field(discriminator="action")]

@app.post("/actions")
def create_action(action: Annotated[action_type, Body()]):
return {"received_action": action.action, "data": action.model_dump()}

gw_event["path"] = "/actions"
gw_event["httpMethod"] = "POST"
gw_event["headers"]["content-type"] = "application/json"
gw_event["body"] = '{"action": "foo", "foo_data": "test"}'

result = app(gw_event, {})
assert result["statusCode"] == 200

response_body = json.loads(result["body"])
assert response_body["received_action"] == "foo"
assert response_body["data"]["action"] == "foo"
assert response_body["data"]["foo_data"] == "test"

gw_event["body"] = '{"action": "bar", "bar_data": 123}'

result = app(gw_event, {})
assert result["statusCode"] == 200

response_body = json.loads(result["body"])
assert response_body["received_action"] == "bar"
assert response_body["data"]["action"] == "bar"
assert response_body["data"]["bar_data"] == 123

gw_event["body"] = '{"action": "invalid", "some_data": "test"}'

result = app(gw_event, {})
assert result["statusCode"] == 422


def test_field_other_features_still_work(gw_event):
"""Test that other Field features still work after discriminator fix"""
app = APIGatewayRestResolver(enable_validation=True)

class UserInput(BaseModel):
name: Annotated[str, Field(min_length=2, max_length=50, description="User name")]
age: Annotated[int, Field(ge=18, le=120, description="User age")]
email: Annotated[str, Field(pattern=r".+@.+\..+", description="User email")]

@app.post("/users")
def create_user(user: UserInput):
return {"created": user.model_dump()}

gw_event["path"] = "/users"
gw_event["httpMethod"] = "POST"
gw_event["headers"]["content-type"] = "application/json"
gw_event["body"] = '{"name": "John", "age": 25, "email": "[email protected]"}'

result = app(gw_event, {})
assert result["statusCode"] == 200

response_body = json.loads(result["body"])
assert response_body["created"]["name"] == "John"
assert response_body["created"]["age"] == 25
assert response_body["created"]["email"] == "[email protected]"

gw_event["body"] = '{"name": "John", "age": 16, "email": "[email protected]"}'

result = app(gw_event, {})
assert result["statusCode"] == 422