Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions aws_lambda_powertools/event_handler/openapi/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,19 @@ def type_(self) -> Any:
return self.field_info.annotation

def __post_init__(self) -> None:
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
Annotated[self.field_info.annotation, self.field_info],
)
# If the field_info.annotation is already an Annotated type with discriminator metadata,
# use it directly instead of wrapping it again
annotation = self.field_info.annotation
if (
get_origin(annotation) is Annotated
and hasattr(self.field_info, "discriminator")
and self.field_info.discriminator is not None
):
self._type_adapter: TypeAdapter[Any] = TypeAdapter(annotation)
else:
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
Annotated[annotation, self.field_info],
)

def get_default(self) -> Any:
if self.field_info.is_required():
Expand Down
52 changes: 44 additions & 8 deletions aws_lambda_powertools/event_handler/openapi/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,17 +1046,47 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup
type_annotation = annotated_args[0]
powertools_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, FieldInfo)]

if len(powertools_annotations) > 1:
# Special case: handle Field(discriminator) + Body() combination
# This happens when using Annotated[Union[A, B], Field(discriminator='...')] with Body()
has_discriminator_with_body = False
powertools_annotation: FieldInfo | None = None

if len(powertools_annotations) == 2:
field_obj = None
body_obj = None
for ann in powertools_annotations:
if isinstance(ann, Body):
body_obj = ann
elif isinstance(ann, FieldInfo) and hasattr(ann, "discriminator") and ann.discriminator is not None:
field_obj = ann

if field_obj and body_obj:
# Use Body as the primary annotation
powertools_annotation = body_obj
# Preserve the full annotation including discriminator for proper validation
# This ensures the discriminator is available when creating the TypeAdapter
type_annotation = annotation
has_discriminator_with_body = True
else:
raise AssertionError("Only one FieldInfo can be used per parameter")
elif len(powertools_annotations) > 1:
raise AssertionError("Only one FieldInfo can be used per parameter")

powertools_annotation = next(iter(powertools_annotations), None)
else:
powertools_annotation = next(iter(powertools_annotations), None)

if isinstance(powertools_annotation, FieldInfo):
# Copy `field_info` because we mutate `field_info.default` later
field_info = copy_field_info(
field_info=powertools_annotation,
annotation=annotation,
)
if has_discriminator_with_body:
# For discriminator + Body case, create a new Body instance directly
# This avoids issues with copy_field_info trying to process the Field
field_info = Body()
field_info.annotation = type_annotation
else:
# Copy `field_info` because we mutate `field_info.default` later
# Use the possibly modified type_annotation for copy_field_info
field_info = copy_field_info(
field_info=powertools_annotation,
annotation=type_annotation,
)
if field_info.default not in [Undefined, Required]:
raise AssertionError("FieldInfo needs to have a default value of Undefined or Required")

Expand All @@ -1067,6 +1097,12 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup
else:
field_info.default = Required

# Preserve the full annotated type if it contains discriminator metadata
# This is crucial for tagged unions to work properly
if hasattr(powertools_annotation, "discriminator") and powertools_annotation.discriminator is not None:
# Store the full annotated type for discriminated unions
type_annotation = annotation

return field_info, type_annotation


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from dataclasses import dataclass
from enum import Enum
from pathlib import PurePath
from typing import List, Optional, Tuple
from typing import List, Literal, Optional, Tuple

import pytest
from pydantic import BaseModel
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from aws_lambda_powertools.event_handler import (
Expand Down Expand Up @@ -1983,3 +1983,82 @@ def get_user(user_id: int) -> UserModel:
assert response_body["name"] == "User123"
assert response_body["age"] == 143
assert response_body["email"] == "[email protected]"


def test_field_discriminator_validation(gw_event):
"""Test that Pydantic Field discriminator works with event_handler validation"""
app = APIGatewayRestResolver(enable_validation=True)

class FooAction(BaseModel):
action: Literal["foo"]
foo_data: str

class BarAction(BaseModel):
action: Literal["bar"]
bar_data: int

Action = Annotated[FooAction | BarAction, Field(discriminator="action")]

@app.post("/actions")
def create_action(action: Annotated[Action, Body()]):
return {"received_action": action.action, "data": action.model_dump()}

gw_event["path"] = "/actions"
gw_event["httpMethod"] = "POST"
gw_event["headers"]["content-type"] = "application/json"
gw_event["body"] = '{"action": "foo", "foo_data": "test"}'

result = app(gw_event, {})
assert result["statusCode"] == 200

response_body = json.loads(result["body"])
assert response_body["received_action"] == "foo"
assert response_body["data"]["action"] == "foo"
assert response_body["data"]["foo_data"] == "test"

gw_event["body"] = '{"action": "bar", "bar_data": 123}'

result = app(gw_event, {})
assert result["statusCode"] == 200

response_body = json.loads(result["body"])
assert response_body["received_action"] == "bar"
assert response_body["data"]["action"] == "bar"
assert response_body["data"]["bar_data"] == 123

gw_event["body"] = '{"action": "invalid", "some_data": "test"}'

result = app(gw_event, {})
assert result["statusCode"] == 422


def test_field_other_features_still_work(gw_event):
"""Test that other Field features still work after discriminator fix"""
app = APIGatewayRestResolver(enable_validation=True)

class UserInput(BaseModel):
name: Annotated[str, Field(min_length=2, max_length=50, description="User name")]
age: Annotated[int, Field(ge=18, le=120, description="User age")]
email: Annotated[str, Field(pattern=r".+@.+\..+", description="User email")]

@app.post("/users")
def create_user(user: UserInput):
return {"created": user.model_dump()}

gw_event["path"] = "/users"
gw_event["httpMethod"] = "POST"
gw_event["headers"]["content-type"] = "application/json"
gw_event["body"] = '{"name": "John", "age": 25, "email": "[email protected]"}'

result = app(gw_event, {})
assert result["statusCode"] == 200

response_body = json.loads(result["body"])
assert response_body["created"]["name"] == "John"
assert response_body["created"]["age"] == 25
assert response_body["created"]["email"] == "[email protected]"

gw_event["body"] = '{"name": "John", "age": 16, "email": "[email protected]"}'

result = app(gw_event, {})
assert result["statusCode"] == 422
Loading