diff --git a/beanie/odm/utils/encoder.py b/beanie/odm/utils/encoder.py index d8ab3fd53..f53cbe2be 100644 --- a/beanie/odm/utils/encoder.py +++ b/beanie/odm/utils/encoder.py @@ -159,7 +159,11 @@ def _iter_model_items( for key, value in obj.__iter__(): field_info = get_model_field(key) if field_info is not None: - key = field_info.alias or key + key = ( + getattr(field_info, "serialization_alias", None) + or field_info.alias + or key + ) if key not in exclude and (value is not None or keep_nulls): yield key, value diff --git a/beanie/odm/utils/init.py b/beanie/odm/utils/init.py index 3ba8b4561..c5aafde41 100644 --- a/beanie/odm/utils/init.py +++ b/beanie/odm/utils/init.py @@ -403,7 +403,7 @@ def init_document_fields(self, cls) -> None: if cls._link_fields is None: cls._link_fields = {} for k, v in get_model_fields(cls).items(): - path = v.alias or k + path = getattr(v, "serialization_alias", None) or v.alias or k setattr(cls, k, ExpressionField(path)) link_info = self.detect_link(v, k) @@ -516,7 +516,9 @@ async def init_indexes(self, cls, allow_index_dropping: bool = False): IndexModel( [ ( - fvalue.alias or k, + getattr(fvalue, "serialization_alias", None) + or fvalue.alias + or k, indexed_attrs[0], ) ], @@ -639,7 +641,7 @@ def init_view_fields(self, cls) -> None: if cls._link_fields is None: cls._link_fields = {} for k, v in get_model_fields(cls).items(): - path = v.alias or k + path = getattr(v, "serialization_alias", None) or v.alias or k setattr(cls, k, ExpressionField(path)) link_info = self.detect_link(v, k) depth_level = cls.get_settings().max_nesting_depths_per_field.get( diff --git a/beanie/odm/utils/projection.py b/beanie/odm/utils/projection.py index 3be3cc31d..c9e9830f8 100644 --- a/beanie/odm/utils/projection.py +++ b/beanie/odm/utils/projection.py @@ -32,5 +32,7 @@ def get_projection( document_projection: Dict[str, int] = {} for name, field in get_model_fields(model).items(): - document_projection[field.alias or name] = 1 + document_projection[ + getattr(field, "serialization_alias", None) or field.alias or name + ] = 1 return document_projection diff --git a/beanie/odm/utils/pydantic.py b/beanie/odm/utils/pydantic.py index 820b70bd2..1cb19facd 100644 --- a/beanie/odm/utils/pydantic.py +++ b/beanie/odm/utils/pydantic.py @@ -37,6 +37,10 @@ def get_model_fields(model): def parse_model(model_type: Type[BaseModel], data: Any): if IS_PYDANTIC_V2: + for k, field in get_model_fields(model_type).items(): + if field.alias and field.alias != field.serialization_alias: + data[k] = data[field.serialization_alias] + del data[field.serialization_alias] return model_type.model_validate(data) else: return model_type.parse_obj(data) diff --git a/tests/odm/conftest.py b/tests/odm/conftest.py index c31c29cb7..3d0245a56 100644 --- a/tests/odm/conftest.py +++ b/tests/odm/conftest.py @@ -29,6 +29,7 @@ DocumentTestModelWithIndexFlagsAliases, DocumentTestModelWithLink, DocumentTestModelWithModelConfigExtraAllow, + DocumentTestModelWithSerializationAlias, DocumentTestModelWithSimpleIndex, DocumentTestModelWithSoftDelete, DocumentToBeLinked, @@ -115,6 +116,7 @@ DocumentWithExtras, DocumentWithPydanticConfig, DocumentTestModel, + DocumentTestModelWithSerializationAlias, DocumentTestModelWithSoftDelete, DocumentTestModelWithLink, DocumentTestModelWithCustomCollectionName, diff --git a/tests/odm/models.py b/tests/odm/models.py index be9731254..ed64c4340 100644 --- a/tests/odm/models.py +++ b/tests/odm/models.py @@ -174,6 +174,11 @@ class Settings: use_state_management = True +class DocumentTestModelWithSerializationAlias(Document): + test_int: int = Field(serialization_alias="test_int_serialize") + test_str: str = Field(serialization_alias="test_str_serialize") + + class DocumentTestModelWithLink(Document): test_link: Link[DocumentTestModel] diff --git a/tests/odm/test_beanie_serialization.py b/tests/odm/test_beanie_serialization.py new file mode 100644 index 000000000..474946cb0 --- /dev/null +++ b/tests/odm/test_beanie_serialization.py @@ -0,0 +1,27 @@ +import pytest + +from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 +from tests.odm.models import DocumentTestModelWithSerializationAlias + + +def data_maker(): + return DocumentTestModelWithSerializationAlias(test_int=1, test_str="test") + + +@pytest.mark.skipif( + not IS_PYDANTIC_V2, + reason="model serialization_alias is not supported in pydantic V1", +) +async def test_serialization_types_preserved_after_insertion(): + result = await DocumentTestModelWithSerializationAlias.insert_one( + data_maker() + ) + document = await DocumentTestModelWithSerializationAlias.get(result.id) + assert document is not None + assert document.test_int is not None + assert document.test_str is not None + dumped = document.model_dump(by_alias=True) + assert "test_int_serialize" in dumped + assert "test_str_serialize" in dumped + assert "test_int" not in dumped + assert "test_str" not in dumped