Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f6ca987
feat(event-handler): add support for Pydantic Field discriminator in …
dap0am Aug 21, 2025
63a0225
style(tests): remove inline comments to match project test style
dap0am Aug 21, 2025
d35753c
Merge branch 'develop' into develop
leandrodamascena Aug 21, 2025
0023b3a
style: run make format to fix CI formatting issues
dap0am Aug 21, 2025
62bb9c9
Merge branch 'develop' into develop
leandrodamascena Aug 29, 2025
ead3ee8
fix(event-handler): preserve FieldInfo subclass types in copy_field_info
dap0am Sep 3, 2025
0c3aad6
Merge branch 'develop' into develop
dreamorosi Sep 4, 2025
e2c8b49
refactor(event-handler): reduce cognitive complexity and address Sona…
dap0am Sep 4, 2025
d61e531
Merge branch 'develop' into develop
leandrodamascena Sep 4, 2025
95a5eba
style: fix formatting to pass CI format check
dap0am Sep 5, 2025
913c580
Merge branch 'develop' into develop
leandrodamascena Sep 8, 2025
d839919
fix: resolve mypy type error in _create_field_info function
dap0am Sep 8, 2025
3a53a67
Merge branch 'develop' into develop
dreamorosi Sep 8, 2025
5762c94
fix: use Union syntax for Python 3.9 compatibility
dap0am Sep 8, 2025
07ade23
Merge branch 'develop' into develop
leandrodamascena Sep 8, 2025
89fb5f8
Merge branch 'develop' into develop
leandrodamascena Sep 9, 2025
793a097
feat(event-handler): add documentation and example for Field discrimi…
dap0am Sep 9, 2025
9eac24a
Merge branch 'develop' into develop
leandrodamascena Sep 9, 2025
4400181
style: run make format to fix CI formatting issues
dap0am Sep 9, 2025
5cc265e
Merge branch 'develop' into develop
leandrodamascena Sep 10, 2025
587d2fa
small changes
leandrodamascena Sep 11, 2025
96ab0a6
Merge branch 'develop' into develop
leandrodamascena Sep 11, 2025
851992f
small changes
leandrodamascena Sep 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions aws_lambda_powertools/event_handler/openapi/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@

from collections import deque
from collections.abc import Mapping, Sequence

# MAINTENANCE: remove when deprecating Pydantic v1. Mypy doesn't handle two different code paths that import different
# versions of a module, so we need to ignore errors here.
from copy import copy
from dataclasses import dataclass, is_dataclass
from typing import TYPE_CHECKING, Any, Deque, FrozenSet, List, Set, Tuple, Union

Expand Down Expand Up @@ -80,9 +78,19 @@ def type_(self) -> Any:
return self.field_info.annotation

def __post_init__(self) -> None:
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
Annotated[self.field_info.annotation, self.field_info],
)
# If the field_info.annotation is already an Annotated type with discriminator metadata,
# use it directly instead of wrapping it again
annotation = self.field_info.annotation
if (
get_origin(annotation) is Annotated
and hasattr(self.field_info, "discriminator")
and self.field_info.discriminator is not None
):
self._type_adapter: TypeAdapter[Any] = TypeAdapter(annotation)
else:
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
Annotated[annotation, self.field_info],
)

def get_default(self) -> Any:
if self.field_info.is_required():
Expand Down Expand Up @@ -176,7 +184,11 @@ def model_rebuild(model: type[BaseModel]) -> None:


def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
return type(field_info).from_annotation(annotation)
# Create a shallow copy of the field_info to preserve its type and all attributes
new_field = copy(field_info)
# Update only the annotation to the new one
new_field.annotation = annotation
return new_field


def get_missing_field_error(loc: tuple[str, ...]) -> dict[str, Any]:
Expand Down
102 changes: 84 additions & 18 deletions aws_lambda_powertools/event_handler/openapi/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,35 +1037,101 @@ def get_field_info_response_type(annotation, value) -> tuple[FieldInfo | None, A
return get_field_info_and_type_annotation(inner_type, value, False, True)


def _has_discriminator(field_info: FieldInfo) -> bool:
"""Check if a FieldInfo has a discriminator."""
return hasattr(field_info, "discriminator") and field_info.discriminator is not None


def _handle_discriminator_with_param(
annotations: list[FieldInfo],
annotation: Any,
) -> tuple[FieldInfo | None, Any, bool]:
"""
Handle the special case of Field(discriminator) + Body() combination.

Returns:
tuple of (powertools_annotation, type_annotation, has_discriminator_with_body)
"""
field_obj = None
body_obj = None

for ann in annotations:
if isinstance(ann, Body):
body_obj = ann
elif _has_discriminator(ann):
field_obj = ann

if field_obj and body_obj:
# Use Body as the primary annotation, preserve full annotation for validation
return body_obj, annotation, True

raise AssertionError("Only one FieldInfo can be used per parameter")


def _create_field_info(
powertools_annotation: FieldInfo,
type_annotation: Any,
has_discriminator_with_body: bool,
) -> FieldInfo:
"""Create or copy FieldInfo based on the annotation type."""
field_info: FieldInfo
if has_discriminator_with_body:
# For discriminator + Body case, create a new Body instance directly
field_info = Body()
field_info.annotation = type_annotation
else:
# Copy field_info because we mutate field_info.default later
field_info = copy_field_info(
field_info=powertools_annotation,
annotation=type_annotation,
)
return field_info


def _set_field_default(field_info: FieldInfo, value: Any, is_path_param: bool) -> None:
"""Set the default value for a field."""
if field_info.default not in [Undefined, Required]:
raise AssertionError("FieldInfo needs to have a default value of Undefined or Required")

if value is not inspect.Signature.empty:
if is_path_param:
raise AssertionError("Cannot use a FieldInfo as a path parameter and pass a value")
field_info.default = value
else:
field_info.default = Required


def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tuple[FieldInfo | None, Any]:
"""
Get the FieldInfo and type annotation from an Annotated type.
"""
field_info: FieldInfo | None = None
annotated_args = get_args(annotation)
type_annotation = annotated_args[0]
powertools_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, FieldInfo)]

if len(powertools_annotations) > 1:
raise AssertionError("Only one FieldInfo can be used per parameter")

powertools_annotation = next(iter(powertools_annotations), None)
# Determine which annotation to use
powertools_annotation: FieldInfo | None = None
has_discriminator_with_param = False

if isinstance(powertools_annotation, FieldInfo):
# Copy `field_info` because we mutate `field_info.default` later
field_info = copy_field_info(
field_info=powertools_annotation,
annotation=annotation,
if len(powertools_annotations) == 2:
powertools_annotation, type_annotation, has_discriminator_with_param = _handle_discriminator_with_param(
powertools_annotations,
annotation,
)
if field_info.default not in [Undefined, Required]:
raise AssertionError("FieldInfo needs to have a default value of Undefined or Required")
elif len(powertools_annotations) > 1:
raise AssertionError("Only one FieldInfo can be used per parameter")
else:
powertools_annotation = next(iter(powertools_annotations), None)

if value is not inspect.Signature.empty:
if is_path_param:
raise AssertionError("Cannot use a FieldInfo as a path parameter and pass a value")
field_info.default = value
else:
field_info.default = Required
# Process the annotation if it exists
field_info: FieldInfo | None = None
if isinstance(powertools_annotation, FieldInfo): # pragma: no cover
field_info = _create_field_info(powertools_annotation, type_annotation, has_discriminator_with_param)
_set_field_default(field_info, value, is_path_param)

# Preserve full annotated type for discriminated unions
if _has_discriminator(powertools_annotation): # pragma: no cover
type_annotation = annotation # pragma: no cover

return field_info, type_annotation

Expand Down
11 changes: 11 additions & 0 deletions docs/core/event_handler/api_gateway.md
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,17 @@ We use the `Annotated` and OpenAPI `Body` type to instruct Event Handler that ou
--8<-- "examples/event_handler_rest/src/validating_payload_subset_output.json"
```

##### Discriminated unions

You can use Pydantic's `Field(discriminator="...")` with union types to create discriminated unions (also known as tagged unions). This allows the Event Handler to automatically determine which model to use based on a discriminator field in the request body.

```python hl_lines="3 4 8 31 36" title="discriminated_unions.py"
--8<-- "examples/event_handler_rest/src/discriminated_unions.py"
```

1. `Field(discriminator="action")` tells Pydantic to use the `action` field to determine which model to instantiate
2. `Body()` annotation tells the Event Handler to parse the request body using the discriminated union

#### Validating responses

You can use `response_validation_error_http_code` to set a custom HTTP code for failed response validation. When this field is set, we will raise a `ResponseValidationError` instead of a `RequestValidationError`.
Expand Down
47 changes: 47 additions & 0 deletions examples/event_handler_rest/src/discriminated_unions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated

from aws_lambda_powertools import Logger, Tracer
from aws_lambda_powertools.event_handler import APIGatewayRestResolver
from aws_lambda_powertools.event_handler.openapi.params import Body
from aws_lambda_powertools.logging import correlation_paths
from aws_lambda_powertools.utilities.typing import LambdaContext

tracer = Tracer()
logger = Logger()
app = APIGatewayRestResolver(enable_validation=True)


class FooAction(BaseModel):
"""Action type for foo operations."""

action: Literal["foo"] = "foo"
foo_data: str


class BarAction(BaseModel):
"""Action type for bar operations."""

action: Literal["bar"] = "bar"
bar_data: int


ActionType = Annotated[Union[FooAction, BarAction], Field(discriminator="action")] # (1)!


@app.post("/actions")
@tracer.capture_method
def handle_action(action: Annotated[ActionType, Body(description="Action to perform")]): # (2)!
"""Handle different action types using discriminated unions."""
if isinstance(action, FooAction):
return {"message": f"Handling foo action with data: {action.foo_data}"}
elif isinstance(action, BarAction):
return {"message": f"Handling bar action with data: {action.bar_data}"}


@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_HTTP)
@tracer.capture_lambda_handler
def lambda_handler(event: dict, context: LambdaContext) -> dict:
return app.resolve(event, context)
34 changes: 34 additions & 0 deletions tests/e2e/event_handler/handlers/data_validation_with_fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from __future__ import annotations

from typing import Annotated, Literal

from pydantic import BaseModel, Field

from aws_lambda_powertools.event_handler import APIGatewayRestResolver
from aws_lambda_powertools.event_handler.openapi.params import Body

app = APIGatewayRestResolver(enable_validation=True)
app.enable_swagger()


class FooAction(BaseModel):
action: Literal["foo"]
foo_data: str


class BarAction(BaseModel):
action: Literal["bar"]
bar_data: int


Action = Annotated[FooAction | BarAction, Field(discriminator="action")]


@app.post("/data_validation_with_fields")
def create_action(action: Annotated[Action, Body(discriminator="action")]):
return {"message": "Powertools e2e API"}


def lambda_handler(event, context):
print(event)
return app.resolve(event, context)
4 changes: 4 additions & 0 deletions tests/e2e/event_handler/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def create_resources(self):
functions["OpenapiHandler"],
functions["OpenapiHandlerWithPep563"],
functions["DataValidationAndMiddleware"],
functions["DataValidationWithFields"],
],
)
self._create_api_gateway_http(function=functions["ApiGatewayHttpHandler"])
Expand Down Expand Up @@ -105,6 +106,9 @@ def _create_api_gateway_rest(self, function: list[Function]):
openapi_schema = apigw.root.add_resource("data_validation_middleware")
openapi_schema.add_method("GET", apigwv1.LambdaIntegration(function[3], proxy=True))

openapi_schema = apigw.root.add_resource("data_validation_with_fields")
openapi_schema.add_method("POST", apigwv1.LambdaIntegration(function[4], proxy=True))

CfnOutput(self.stack, "APIGatewayRestUrl", value=apigw.url)

def _create_lambda_function_url(self, function: Function):
Expand Down
17 changes: 17 additions & 0 deletions tests/e2e/event_handler/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,20 @@ def test_get_openapi_validation_and_middleware(apigw_rest_endpoint):
)

assert response.status_code == 202


def test_openapi_with_fields_discriminator(apigw_rest_endpoint):
# GIVEN
url = f"{apigw_rest_endpoint}data_validation_with_fields"

# WHEN
response = data_fetcher.get_http_response(
Request(
method="POST",
url=url,
json={"action": "foo", "foo_data": "foo data working"},
),
)

assert "Powertools e2e API" in response.text
assert response.status_code == 200
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from dataclasses import dataclass
from enum import Enum
from pathlib import PurePath
from typing import List, Optional, Tuple
from typing import List, Literal, Optional, Tuple, Union

import pytest
from pydantic import BaseModel
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from aws_lambda_powertools.event_handler import (
Expand Down Expand Up @@ -1983,3 +1983,50 @@ def get_user(user_id: int) -> UserModel:
assert response_body["name"] == "User123"
assert response_body["age"] == 143
assert response_body["email"] == "[email protected]"


def test_field_discriminator_validation(gw_event):
"""Test that Pydantic Field discriminator works with event_handler validation"""
app = APIGatewayRestResolver(enable_validation=True)

class FooAction(BaseModel):
action: Literal["foo"]
foo_data: str

class BarAction(BaseModel):
action: Literal["bar"]
bar_data: int

action_type = Annotated[Union[FooAction, BarAction], Field(discriminator="action")]

@app.post("/actions")
def create_action(action: Annotated[action_type, Body()]):
return {"received_action": action.action, "data": action.model_dump()}

gw_event["path"] = "/actions"
gw_event["httpMethod"] = "POST"
gw_event["headers"]["content-type"] = "application/json"
gw_event["body"] = '{"action": "foo", "foo_data": "test"}'

result = app(gw_event, {})
assert result["statusCode"] == 200

response_body = json.loads(result["body"])
assert response_body["received_action"] == "foo"
assert response_body["data"]["action"] == "foo"
assert response_body["data"]["foo_data"] == "test"

gw_event["body"] = '{"action": "bar", "bar_data": 123}'

result = app(gw_event, {})
assert result["statusCode"] == 200

response_body = json.loads(result["body"])
assert response_body["received_action"] == "bar"
assert response_body["data"]["action"] == "bar"
assert response_body["data"]["bar_data"] == 123

gw_event["body"] = '{"action": "invalid", "some_data": "test"}'

result = app(gw_event, {})
assert result["statusCode"] == 422