From f6ca987403a11f5a4c32b19ba3a41cff85b21d45 Mon Sep 17 00:00:00 2001 From: dap0am Date: Thu, 21 Aug 2025 09:27:07 +0100 Subject: [PATCH 1/4] feat(event-handler): add support for Pydantic Field discriminator in validation (#5953) Enable use of Field(discriminator='...') with tagged unions in event handler validation. This allows developers to use Pydantic's native discriminator syntax instead of requiring Powertools-specific Param annotations. - Handle Field(discriminator) + Body() combination in get_field_info_annotated_type - Preserve discriminator metadata when creating TypeAdapter in ModelField - Add comprehensive tests for discriminator validation and Field features --- .../event_handler/openapi/compat.py | 17 +++- .../event_handler/openapi/params.py | 52 ++++++++-- .../test_openapi_validation_middleware.py | 96 ++++++++++++++++++- 3 files changed, 152 insertions(+), 13 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py index d3340f34e4b..d9c975e3396 100644 --- a/aws_lambda_powertools/event_handler/openapi/compat.py +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -80,9 +80,20 @@ def type_(self) -> Any: return self.field_info.annotation def __post_init__(self) -> None: - self._type_adapter: TypeAdapter[Any] = TypeAdapter( - Annotated[self.field_info.annotation, self.field_info], - ) + + # If the field_info.annotation is already an Annotated type with discriminator metadata, + # use it directly instead of wrapping it again + annotation = self.field_info.annotation + if ( + get_origin(annotation) is Annotated + and hasattr(self.field_info, "discriminator") + and self.field_info.discriminator is not None + ): + self._type_adapter: TypeAdapter[Any] = TypeAdapter(annotation) + else: + self._type_adapter: TypeAdapter[Any] = TypeAdapter( + Annotated[annotation, self.field_info], + ) def get_default(self) -> Any: if self.field_info.is_required(): diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 8fc8d0becfa..3743ac0eff7 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -1046,17 +1046,47 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup type_annotation = annotated_args[0] powertools_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, FieldInfo)] - if len(powertools_annotations) > 1: + # Special case: handle Field(discriminator) + Body() combination + # This happens when using Annotated[Union[A, B], Field(discriminator='...')] with Body() + has_discriminator_with_body = False + powertools_annotation: FieldInfo | None = None + + if len(powertools_annotations) == 2: + field_obj = None + body_obj = None + for ann in powertools_annotations: + if isinstance(ann, Body): + body_obj = ann + elif isinstance(ann, FieldInfo) and hasattr(ann, "discriminator") and ann.discriminator is not None: + field_obj = ann + + if field_obj and body_obj: + # Use Body as the primary annotation + powertools_annotation = body_obj + # Preserve the full annotation including discriminator for proper validation + # This ensures the discriminator is available when creating the TypeAdapter + type_annotation = annotation + has_discriminator_with_body = True + else: + raise AssertionError("Only one FieldInfo can be used per parameter") + elif len(powertools_annotations) > 1: raise AssertionError("Only one FieldInfo can be used per parameter") - - powertools_annotation = next(iter(powertools_annotations), None) + else: + powertools_annotation = next(iter(powertools_annotations), None) if isinstance(powertools_annotation, FieldInfo): - # Copy `field_info` because we mutate `field_info.default` later - field_info = copy_field_info( - field_info=powertools_annotation, - annotation=annotation, - ) + if has_discriminator_with_body: + # For discriminator + Body case, create a new Body instance directly + # This avoids issues with copy_field_info trying to process the Field + field_info = Body() + field_info.annotation = type_annotation + else: + # Copy `field_info` because we mutate `field_info.default` later + # Use the possibly modified type_annotation for copy_field_info + field_info = copy_field_info( + field_info=powertools_annotation, + annotation=type_annotation, + ) if field_info.default not in [Undefined, Required]: raise AssertionError("FieldInfo needs to have a default value of Undefined or Required") @@ -1067,6 +1097,12 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup else: field_info.default = Required + # Preserve the full annotated type if it contains discriminator metadata + # This is crucial for tagged unions to work properly + if hasattr(powertools_annotation, "discriminator") and powertools_annotation.discriminator is not None: + # Store the full annotated type for discriminated unions + type_annotation = annotation + return field_info, type_annotation diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 1fd919b7b71..9c07a7313ad 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -3,10 +3,10 @@ from dataclasses import dataclass from enum import Enum from pathlib import PurePath -from typing import List, Optional, Tuple +from typing import List, Literal, Optional, Tuple import pytest -from pydantic import BaseModel +from pydantic import BaseModel, Field from typing_extensions import Annotated from aws_lambda_powertools.event_handler import ( @@ -1983,3 +1983,95 @@ def get_user(user_id: int) -> UserModel: assert response_body["name"] == "User123" assert response_body["age"] == 143 assert response_body["email"] == "user123@example.com" + + +def test_field_discriminator_validation(gw_event): + """Test that Pydantic Field discriminator works with event_handler validation""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class FooAction(BaseModel): + action: Literal["foo"] + foo_data: str + + class BarAction(BaseModel): + action: Literal["bar"] + bar_data: int + + # This should work with Field discriminator (issue #5953) + Action = Annotated[FooAction | BarAction, Field(discriminator="action")] + + @app.post("/actions") + def create_action(action: Annotated[Action, Body()]): + return {"received_action": action.action, "data": action.model_dump()} + + # WHEN sending a valid foo action + gw_event["path"] = "/actions" + gw_event["httpMethod"] = "POST" + gw_event["headers"]["content-type"] = "application/json" + gw_event["body"] = '{"action": "foo", "foo_data": "test"}' + + # THEN the handler should be invoked and return 200 + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + response_body = json.loads(result["body"]) + assert response_body["received_action"] == "foo" + assert response_body["data"]["action"] == "foo" + assert response_body["data"]["foo_data"] == "test" + + # WHEN sending a valid bar action + gw_event["body"] = '{"action": "bar", "bar_data": 123}' + + # THEN the handler should be invoked and return 200 + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + response_body = json.loads(result["body"]) + assert response_body["received_action"] == "bar" + assert response_body["data"]["action"] == "bar" + assert response_body["data"]["bar_data"] == 123 + + # WHEN sending an invalid discriminator + gw_event["body"] = '{"action": "invalid", "some_data": "test"}' + + # THEN the handler should return 422 (validation error) + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + +def test_field_other_features_still_work(gw_event): + """Test that other Field features still work after discriminator fix""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class UserInput(BaseModel): + name: Annotated[str, Field(min_length=2, max_length=50, description="User name")] + age: Annotated[int, Field(ge=18, le=120, description="User age")] + email: Annotated[str, Field(pattern=r".+@.+\..+", description="User email")] + + @app.post("/users") + def create_user(user: UserInput): + return {"created": user.model_dump()} + + # WHEN sending valid data + gw_event["path"] = "/users" + gw_event["httpMethod"] = "POST" + gw_event["headers"]["content-type"] = "application/json" + gw_event["body"] = '{"name": "John", "age": 25, "email": "john@example.com"}' + + # THEN the handler should return 200 + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + response_body = json.loads(result["body"]) + assert response_body["created"]["name"] == "John" + assert response_body["created"]["age"] == 25 + assert response_body["created"]["email"] == "john@example.com" + + # WHEN sending data with validation error (age too low) + gw_event["body"] = '{"name": "John", "age": 16, "email": "john@example.com"}' + + # THEN the handler should return 422 (validation error) + result = app(gw_event, {}) + assert result["statusCode"] == 422 From 63a0225f147c2abd9ad2ded9d0e2fbbedf32cea9 Mon Sep 17 00:00:00 2001 From: dap0am Date: Thu, 21 Aug 2025 09:43:04 +0100 Subject: [PATCH 2/4] style(tests): remove inline comments to match project test style --- .../_pydantic/test_openapi_validation_middleware.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 9c07a7313ad..dc55cd55772 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -1987,7 +1987,6 @@ def get_user(user_id: int) -> UserModel: def test_field_discriminator_validation(gw_event): """Test that Pydantic Field discriminator works with event_handler validation""" - # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) class FooAction(BaseModel): @@ -1998,20 +1997,17 @@ class BarAction(BaseModel): action: Literal["bar"] bar_data: int - # This should work with Field discriminator (issue #5953) Action = Annotated[FooAction | BarAction, Field(discriminator="action")] @app.post("/actions") def create_action(action: Annotated[Action, Body()]): return {"received_action": action.action, "data": action.model_dump()} - # WHEN sending a valid foo action gw_event["path"] = "/actions" gw_event["httpMethod"] = "POST" gw_event["headers"]["content-type"] = "application/json" gw_event["body"] = '{"action": "foo", "foo_data": "test"}' - # THEN the handler should be invoked and return 200 result = app(gw_event, {}) assert result["statusCode"] == 200 @@ -2020,10 +2016,8 @@ def create_action(action: Annotated[Action, Body()]): assert response_body["data"]["action"] == "foo" assert response_body["data"]["foo_data"] == "test" - # WHEN sending a valid bar action gw_event["body"] = '{"action": "bar", "bar_data": 123}' - # THEN the handler should be invoked and return 200 result = app(gw_event, {}) assert result["statusCode"] == 200 @@ -2032,17 +2026,14 @@ def create_action(action: Annotated[Action, Body()]): assert response_body["data"]["action"] == "bar" assert response_body["data"]["bar_data"] == 123 - # WHEN sending an invalid discriminator gw_event["body"] = '{"action": "invalid", "some_data": "test"}' - # THEN the handler should return 422 (validation error) result = app(gw_event, {}) assert result["statusCode"] == 422 def test_field_other_features_still_work(gw_event): """Test that other Field features still work after discriminator fix""" - # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) class UserInput(BaseModel): @@ -2054,13 +2045,11 @@ class UserInput(BaseModel): def create_user(user: UserInput): return {"created": user.model_dump()} - # WHEN sending valid data gw_event["path"] = "/users" gw_event["httpMethod"] = "POST" gw_event["headers"]["content-type"] = "application/json" gw_event["body"] = '{"name": "John", "age": 25, "email": "john@example.com"}' - # THEN the handler should return 200 result = app(gw_event, {}) assert result["statusCode"] == 200 @@ -2069,9 +2058,7 @@ def create_user(user: UserInput): assert response_body["created"]["age"] == 25 assert response_body["created"]["email"] == "john@example.com" - # WHEN sending data with validation error (age too low) gw_event["body"] = '{"name": "John", "age": 16, "email": "john@example.com"}' - # THEN the handler should return 422 (validation error) result = app(gw_event, {}) assert result["statusCode"] == 422 From 0023b3ac15a63538abb310373865d264aba52308 Mon Sep 17 00:00:00 2001 From: dap0am Date: Thu, 21 Aug 2025 10:39:39 +0100 Subject: [PATCH 3/4] style: run make format to fix CI formatting issues --- aws_lambda_powertools/event_handler/openapi/compat.py | 1 - aws_lambda_powertools/event_handler/openapi/params.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py index d9c975e3396..74945748921 100644 --- a/aws_lambda_powertools/event_handler/openapi/compat.py +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -80,7 +80,6 @@ def type_(self) -> Any: return self.field_info.annotation def __post_init__(self) -> None: - # If the field_info.annotation is already an Annotated type with discriminator metadata, # use it directly instead of wrapping it again annotation = self.field_info.annotation diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 3743ac0eff7..0d19928020c 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -1050,7 +1050,7 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup # This happens when using Annotated[Union[A, B], Field(discriminator='...')] with Body() has_discriminator_with_body = False powertools_annotation: FieldInfo | None = None - + if len(powertools_annotations) == 2: field_obj = None body_obj = None From ead3ee8563476190eddb066cbb1e8c34e6d2fc6b Mon Sep 17 00:00:00 2001 From: dap0am Date: Wed, 3 Sep 2025 13:30:55 +0100 Subject: [PATCH 4/4] fix(event-handler): preserve FieldInfo subclass types in copy_field_info Fix regression where copy_field_info was losing custom FieldInfo subclass types (Body, Query, etc.) by using shallow copy instead of from_annotation. This resolves the failing test_validate_embed_body_param while maintaining the discriminator functionality. --- aws_lambda_powertools/event_handler/openapi/compat.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py index 74945748921..af5c2d5bc87 100644 --- a/aws_lambda_powertools/event_handler/openapi/compat.py +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -186,7 +186,13 @@ def model_rebuild(model: type[BaseModel]) -> None: def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo: - return type(field_info).from_annotation(annotation) + # Create a shallow copy of the field_info to preserve its type and all attributes + import copy + + new_field = copy.copy(field_info) + # Update only the annotation to the new one + new_field.annotation = annotation + return new_field def get_missing_field_error(loc: tuple[str, ...]) -> dict[str, Any]: