diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py index d3340f34e4b..af5c2d5bc87 100644 --- a/aws_lambda_powertools/event_handler/openapi/compat.py +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -80,9 +80,19 @@ def type_(self) -> Any: return self.field_info.annotation def __post_init__(self) -> None: - self._type_adapter: TypeAdapter[Any] = TypeAdapter( - Annotated[self.field_info.annotation, self.field_info], - ) + # If the field_info.annotation is already an Annotated type with discriminator metadata, + # use it directly instead of wrapping it again + annotation = self.field_info.annotation + if ( + get_origin(annotation) is Annotated + and hasattr(self.field_info, "discriminator") + and self.field_info.discriminator is not None + ): + self._type_adapter: TypeAdapter[Any] = TypeAdapter(annotation) + else: + self._type_adapter: TypeAdapter[Any] = TypeAdapter( + Annotated[annotation, self.field_info], + ) def get_default(self) -> Any: if self.field_info.is_required(): @@ -176,7 +186,13 @@ def model_rebuild(model: type[BaseModel]) -> None: def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo: - return type(field_info).from_annotation(annotation) + # Create a shallow copy of the field_info to preserve its type and all attributes + import copy + + new_field = copy.copy(field_info) + # Update only the annotation to the new one + new_field.annotation = annotation + return new_field def get_missing_field_error(loc: tuple[str, ...]) -> dict[str, Any]: diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 8fc8d0becfa..0d19928020c 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -1046,17 +1046,47 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup type_annotation = annotated_args[0] powertools_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, FieldInfo)] - if len(powertools_annotations) > 1: + # Special case: handle Field(discriminator) + Body() combination + # This happens when using Annotated[Union[A, B], Field(discriminator='...')] with Body() + has_discriminator_with_body = False + powertools_annotation: FieldInfo | None = None + + if len(powertools_annotations) == 2: + field_obj = None + body_obj = None + for ann in powertools_annotations: + if isinstance(ann, Body): + body_obj = ann + elif isinstance(ann, FieldInfo) and hasattr(ann, "discriminator") and ann.discriminator is not None: + field_obj = ann + + if field_obj and body_obj: + # Use Body as the primary annotation + powertools_annotation = body_obj + # Preserve the full annotation including discriminator for proper validation + # This ensures the discriminator is available when creating the TypeAdapter + type_annotation = annotation + has_discriminator_with_body = True + else: + raise AssertionError("Only one FieldInfo can be used per parameter") + elif len(powertools_annotations) > 1: raise AssertionError("Only one FieldInfo can be used per parameter") - - powertools_annotation = next(iter(powertools_annotations), None) + else: + powertools_annotation = next(iter(powertools_annotations), None) if isinstance(powertools_annotation, FieldInfo): - # Copy `field_info` because we mutate `field_info.default` later - field_info = copy_field_info( - field_info=powertools_annotation, - annotation=annotation, - ) + if has_discriminator_with_body: + # For discriminator + Body case, create a new Body instance directly + # This avoids issues with copy_field_info trying to process the Field + field_info = Body() + field_info.annotation = type_annotation + else: + # Copy `field_info` because we mutate `field_info.default` later + # Use the possibly modified type_annotation for copy_field_info + field_info = copy_field_info( + field_info=powertools_annotation, + annotation=type_annotation, + ) if field_info.default not in [Undefined, Required]: raise AssertionError("FieldInfo needs to have a default value of Undefined or Required") @@ -1067,6 +1097,12 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup else: field_info.default = Required + # Preserve the full annotated type if it contains discriminator metadata + # This is crucial for tagged unions to work properly + if hasattr(powertools_annotation, "discriminator") and powertools_annotation.discriminator is not None: + # Store the full annotated type for discriminated unions + type_annotation = annotation + return field_info, type_annotation diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 1fd919b7b71..dc55cd55772 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -3,10 +3,10 @@ from dataclasses import dataclass from enum import Enum from pathlib import PurePath -from typing import List, Optional, Tuple +from typing import List, Literal, Optional, Tuple import pytest -from pydantic import BaseModel +from pydantic import BaseModel, Field from typing_extensions import Annotated from aws_lambda_powertools.event_handler import ( @@ -1983,3 +1983,82 @@ def get_user(user_id: int) -> UserModel: assert response_body["name"] == "User123" assert response_body["age"] == 143 assert response_body["email"] == "user123@example.com" + + +def test_field_discriminator_validation(gw_event): + """Test that Pydantic Field discriminator works with event_handler validation""" + app = APIGatewayRestResolver(enable_validation=True) + + class FooAction(BaseModel): + action: Literal["foo"] + foo_data: str + + class BarAction(BaseModel): + action: Literal["bar"] + bar_data: int + + Action = Annotated[FooAction | BarAction, Field(discriminator="action")] + + @app.post("/actions") + def create_action(action: Annotated[Action, Body()]): + return {"received_action": action.action, "data": action.model_dump()} + + gw_event["path"] = "/actions" + gw_event["httpMethod"] = "POST" + gw_event["headers"]["content-type"] = "application/json" + gw_event["body"] = '{"action": "foo", "foo_data": "test"}' + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + response_body = json.loads(result["body"]) + assert response_body["received_action"] == "foo" + assert response_body["data"]["action"] == "foo" + assert response_body["data"]["foo_data"] == "test" + + gw_event["body"] = '{"action": "bar", "bar_data": 123}' + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + response_body = json.loads(result["body"]) + assert response_body["received_action"] == "bar" + assert response_body["data"]["action"] == "bar" + assert response_body["data"]["bar_data"] == 123 + + gw_event["body"] = '{"action": "invalid", "some_data": "test"}' + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + +def test_field_other_features_still_work(gw_event): + """Test that other Field features still work after discriminator fix""" + app = APIGatewayRestResolver(enable_validation=True) + + class UserInput(BaseModel): + name: Annotated[str, Field(min_length=2, max_length=50, description="User name")] + age: Annotated[int, Field(ge=18, le=120, description="User age")] + email: Annotated[str, Field(pattern=r".+@.+\..+", description="User email")] + + @app.post("/users") + def create_user(user: UserInput): + return {"created": user.model_dump()} + + gw_event["path"] = "/users" + gw_event["httpMethod"] = "POST" + gw_event["headers"]["content-type"] = "application/json" + gw_event["body"] = '{"name": "John", "age": 25, "email": "john@example.com"}' + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + response_body = json.loads(result["body"]) + assert response_body["created"]["name"] == "John" + assert response_body["created"]["age"] == 25 + assert response_body["created"]["email"] == "john@example.com" + + gw_event["body"] = '{"name": "John", "age": 16, "email": "john@example.com"}' + + result = app(gw_event, {}) + assert result["statusCode"] == 422