Skip to content

Commit b0d6aed

Browse files
Refactoring tests
1 parent 34072ed commit b0d6aed

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

aws_lambda_powertools/utilities/kafka/serialization/serialization.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

3+
import sys
34
from dataclasses import is_dataclass
4-
from typing import TYPE_CHECKING, Any
5+
from types import UnionType
6+
from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, get_args, get_origin
57

68
from aws_lambda_powertools.utilities.kafka.serialization.custom_dict import CustomDictOutputSerializer
79
from aws_lambda_powertools.utilities.kafka.serialization.dataclass import DataclassOutputSerializer
@@ -17,10 +19,6 @@ def _get_output_serializer(output: type[T] | Callable | None = None) -> Any:
1719
Returns the appropriate serializer for the given output class.
1820
Uses lazy imports to avoid unnecessary dependencies.
1921
"""
20-
if output is None:
21-
# Return a pass-through serializer if no output class is specified
22-
return CustomDictOutputSerializer()
23-
2422
# Check if it's a dataclass
2523
if is_dataclass(output):
2624
return DataclassOutputSerializer()
@@ -40,6 +38,14 @@ def _is_pydantic_model(obj: Any) -> bool:
4038
has_model_fields = getattr(obj, "model_fields", None) is not None
4139
has_model_validate = callable(getattr(obj, "model_validate", None))
4240
return has_model_fields and has_model_validate
41+
42+
origin = get_origin(obj)
43+
if origin in (Union, Optional, Annotated) or (sys.version_info >= (3, 10) and origin in (Union, UnionType)):
44+
# Check if any element in the Union is a Pydantic model
45+
for arg in get_args(obj):
46+
if _is_pydantic_model(arg):
47+
return True
48+
4349
return False
4450

4551

tests/functional/kafka_consumer/_pydantic/test_kafka_consumer_with_pydantic.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import base64
22
import json
3-
from typing import Literal, Union
3+
from typing import Annotated, Literal, Union
44

55
import pytest
66
from pydantic import BaseModel, Field
@@ -94,10 +94,9 @@ class UserValueModel2(BaseModel):
9494
name: Literal["Not using"]
9595
email: str
9696

97-
class Model(BaseModel):
98-
name: Union[UserValueModel, UserValueModel2] = Field(discriminator="name")
97+
UnionModel = Annotated[Union[UserValueModel, UserValueModel2], Field(discriminator="name")]
9998

100-
schema_config = SchemaConfig(value_schema_type="JSON", value_output_serializer=UserValueModel)
99+
schema_config = SchemaConfig(value_schema_type="JSON", value_output_serializer=UnionModel)
101100

102101
@kafka_consumer(schema_config=schema_config)
103102
def handler(event: ConsumerRecords, context):

0 commit comments

Comments
 (0)