Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f6ca987
feat(event-handler): add support for Pydantic Field discriminator in …
dap0am Aug 21, 2025
63a0225
style(tests): remove inline comments to match project test style
dap0am Aug 21, 2025
d35753c
Merge branch 'develop' into develop
leandrodamascena Aug 21, 2025
0023b3a
style: run make format to fix CI formatting issues
dap0am Aug 21, 2025
62bb9c9
Merge branch 'develop' into develop
leandrodamascena Aug 29, 2025
ead3ee8
fix(event-handler): preserve FieldInfo subclass types in copy_field_info
dap0am Sep 3, 2025
0c3aad6
Merge branch 'develop' into develop
dreamorosi Sep 4, 2025
e2c8b49
refactor(event-handler): reduce cognitive complexity and address Sona…
dap0am Sep 4, 2025
d61e531
Merge branch 'develop' into develop
leandrodamascena Sep 4, 2025
95a5eba
style: fix formatting to pass CI format check
dap0am Sep 5, 2025
913c580
Merge branch 'develop' into develop
leandrodamascena Sep 8, 2025
d839919
fix: resolve mypy type error in _create_field_info function
dap0am Sep 8, 2025
3a53a67
Merge branch 'develop' into develop
dreamorosi Sep 8, 2025
5762c94
fix: use Union syntax for Python 3.9 compatibility
dap0am Sep 8, 2025
07ade23
Merge branch 'develop' into develop
leandrodamascena Sep 8, 2025
89fb5f8
Merge branch 'develop' into develop
leandrodamascena Sep 9, 2025
793a097
feat(event-handler): add documentation and example for Field discrimi…
dap0am Sep 9, 2025
9eac24a
Merge branch 'develop' into develop
leandrodamascena Sep 9, 2025
4400181
style: run make format to fix CI formatting issues
dap0am Sep 9, 2025
5cc265e
Merge branch 'develop' into develop
leandrodamascena Sep 10, 2025
587d2fa
small changes
leandrodamascena Sep 11, 2025
96ab0a6
Merge branch 'develop' into develop
leandrodamascena Sep 11, 2025
851992f
small changes
leandrodamascena Sep 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions aws_lambda_powertools/event_handler/openapi/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,19 @@ def type_(self) -> Any:
return self.field_info.annotation

def __post_init__(self) -> None:
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
Annotated[self.field_info.annotation, self.field_info],
)
# If the field_info.annotation is already an Annotated type with discriminator metadata,
# use it directly instead of wrapping it again
annotation = self.field_info.annotation
if (
get_origin(annotation) is Annotated
and hasattr(self.field_info, "discriminator")
and self.field_info.discriminator is not None
):
self._type_adapter: TypeAdapter[Any] = TypeAdapter(annotation)
else:
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
Annotated[annotation, self.field_info],
)

def get_default(self) -> Any:
if self.field_info.is_required():
Expand Down
52 changes: 44 additions & 8 deletions aws_lambda_powertools/event_handler/openapi/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,17 +1046,47 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup
type_annotation = annotated_args[0]
powertools_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, FieldInfo)]

if len(powertools_annotations) > 1:
# Special case: handle Field(discriminator) + Body() combination
# This happens when using Annotated[Union[A, B], Field(discriminator='...')] with Body()
has_discriminator_with_body = False
powertools_annotation: FieldInfo | None = None

if len(powertools_annotations) == 2:
field_obj = None
body_obj = None
for ann in powertools_annotations:
if isinstance(ann, Body):
body_obj = ann
elif isinstance(ann, FieldInfo) and hasattr(ann, "discriminator") and ann.discriminator is not None:
field_obj = ann

if field_obj and body_obj:
# Use Body as the primary annotation
powertools_annotation = body_obj
# Preserve the full annotation including discriminator for proper validation
# This ensures the discriminator is available when creating the TypeAdapter
type_annotation = annotation
has_discriminator_with_body = True
else:
raise AssertionError("Only one FieldInfo can be used per parameter")
elif len(powertools_annotations) > 1:
raise AssertionError("Only one FieldInfo can be used per parameter")

powertools_annotation = next(iter(powertools_annotations), None)
else:
powertools_annotation = next(iter(powertools_annotations), None)

if isinstance(powertools_annotation, FieldInfo):
# Copy `field_info` because we mutate `field_info.default` later
field_info = copy_field_info(
field_info=powertools_annotation,
annotation=annotation,
)
if has_discriminator_with_body:
# For discriminator + Body case, create a new Body instance directly
# This avoids issues with copy_field_info trying to process the Field
field_info = Body()
field_info.annotation = type_annotation
else:
# Copy `field_info` because we mutate `field_info.default` later
# Use the possibly modified type_annotation for copy_field_info
field_info = copy_field_info(
field_info=powertools_annotation,
annotation=type_annotation,
)
if field_info.default not in [Undefined, Required]:
raise AssertionError("FieldInfo needs to have a default value of Undefined or Required")

Expand All @@ -1067,6 +1097,12 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup
else:
field_info.default = Required

# Preserve the full annotated type if it contains discriminator metadata
# This is crucial for tagged unions to work properly
if hasattr(powertools_annotation, "discriminator") and powertools_annotation.discriminator is not None:
# Store the full annotated type for discriminated unions
type_annotation = annotation

return field_info, type_annotation


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from dataclasses import dataclass
from enum import Enum
from pathlib import PurePath
from typing import List, Optional, Tuple
from typing import List, Literal, Optional, Tuple

import pytest
from pydantic import BaseModel
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from aws_lambda_powertools.event_handler import (
Expand Down Expand Up @@ -1983,3 +1983,82 @@ def get_user(user_id: int) -> UserModel:
assert response_body["name"] == "User123"
assert response_body["age"] == 143
assert response_body["email"] == "[email protected]"


def test_field_discriminator_validation(gw_event):
"""Test that Pydantic Field discriminator works with event_handler validation"""
app = APIGatewayRestResolver(enable_validation=True)

class FooAction(BaseModel):
action: Literal["foo"]
foo_data: str

class BarAction(BaseModel):
action: Literal["bar"]
bar_data: int

Action = Annotated[FooAction | BarAction, Field(discriminator="action")]

@app.post("/actions")
def create_action(action: Annotated[Action, Body()]):
return {"received_action": action.action, "data": action.model_dump()}

gw_event["path"] = "/actions"
gw_event["httpMethod"] = "POST"
gw_event["headers"]["content-type"] = "application/json"
gw_event["body"] = '{"action": "foo", "foo_data": "test"}'

result = app(gw_event, {})
assert result["statusCode"] == 200

response_body = json.loads(result["body"])
assert response_body["received_action"] == "foo"
assert response_body["data"]["action"] == "foo"
assert response_body["data"]["foo_data"] == "test"

gw_event["body"] = '{"action": "bar", "bar_data": 123}'

result = app(gw_event, {})
assert result["statusCode"] == 200

response_body = json.loads(result["body"])
assert response_body["received_action"] == "bar"
assert response_body["data"]["action"] == "bar"
assert response_body["data"]["bar_data"] == 123

gw_event["body"] = '{"action": "invalid", "some_data": "test"}'

result = app(gw_event, {})
assert result["statusCode"] == 422


def test_field_other_features_still_work(gw_event):
"""Test that other Field features still work after discriminator fix"""
app = APIGatewayRestResolver(enable_validation=True)

class UserInput(BaseModel):
name: Annotated[str, Field(min_length=2, max_length=50, description="User name")]
age: Annotated[int, Field(ge=18, le=120, description="User age")]
email: Annotated[str, Field(pattern=r".+@.+\..+", description="User email")]

@app.post("/users")
def create_user(user: UserInput):
return {"created": user.model_dump()}

gw_event["path"] = "/users"
gw_event["httpMethod"] = "POST"
gw_event["headers"]["content-type"] = "application/json"
gw_event["body"] = '{"name": "John", "age": 25, "email": "[email protected]"}'

result = app(gw_event, {})
assert result["statusCode"] == 200

response_body = json.loads(result["body"])
assert response_body["created"]["name"] == "John"
assert response_body["created"]["age"] == 25
assert response_body["created"]["email"] == "[email protected]"

gw_event["body"] = '{"name": "John", "age": 16, "email": "[email protected]"}'

result = app(gw_event, {})
assert result["statusCode"] == 422
Loading