1
1
from __future__ import annotations
2
2
3
+ import sys
3
4
from dataclasses import is_dataclass
4
- from typing import TYPE_CHECKING , Any
5
+ from types import UnionType
6
+ from typing import TYPE_CHECKING , Annotated , Any , Optional , Union , get_args , get_origin
5
7
6
8
from aws_lambda_powertools .utilities .kafka .serialization .custom_dict import CustomDictOutputSerializer
7
9
from aws_lambda_powertools .utilities .kafka .serialization .dataclass import DataclassOutputSerializer
@@ -17,10 +19,6 @@ def _get_output_serializer(output: type[T] | Callable | None = None) -> Any:
17
19
Returns the appropriate serializer for the given output class.
18
20
Uses lazy imports to avoid unnecessary dependencies.
19
21
"""
20
- if output is None :
21
- # Return a pass-through serializer if no output class is specified
22
- return CustomDictOutputSerializer ()
23
-
24
22
# Check if it's a dataclass
25
23
if is_dataclass (output ):
26
24
return DataclassOutputSerializer ()
@@ -40,6 +38,14 @@ def _is_pydantic_model(obj: Any) -> bool:
40
38
has_model_fields = getattr (obj , "model_fields" , None ) is not None
41
39
has_model_validate = callable (getattr (obj , "model_validate" , None ))
42
40
return has_model_fields and has_model_validate
41
+
42
+ origin = get_origin (obj )
43
+ if origin in (Union , Optional , Annotated ) or (sys .version_info >= (3 , 10 ) and origin in (Union , UnionType )):
44
+ # Check if any element in the Union is a Pydantic model
45
+ for arg in get_args (obj ):
46
+ if _is_pydantic_model (arg ):
47
+ return True
48
+
43
49
return False
44
50
45
51
0 commit comments