Skip to content

Commit d5922a3

Browse files
committed
EM-1209 Test enum on model serialisation
1 parent aa33793 commit d5922a3

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

tests/base_model_store_tests.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys
44
from dataclasses import dataclass, field
55
from datetime import date, datetime
6+
from enum import Enum
67
from typing import Any, Dict, Generator, List, Mapping, Optional, Type, TypedDict, Union, cast
78
from uuid import uuid4
89

@@ -102,6 +103,11 @@ class NestedItem:
102103
timestamp: datetime = field(default_factory=datetime.utcnow)
103104

104105

106+
class SomeEnum(str, Enum):
107+
FIELD_ONE = "one"
108+
FIELD_TWO = "two"
109+
110+
105111
@dataclass
106112
class MyDerivedModel(MyBaseModel):
107113
id: str
@@ -121,6 +127,7 @@ class MyDerivedModel(MyBaseModel):
121127
union_date: Union[datetime, None] = field(default_factory=datetime.utcnow)
122128
bytes_type: Union[bytes, None] = None
123129
bytearray_type: Union[bytearray, None] = None
130+
some_enum: SomeEnum = field(default_factory=lambda: SomeEnum.FIELD_ONE)
124131

125132
def get_key(self) -> _MyModelKey:
126133
return _MyModelKey(my_pk=f"AA#{self.id}", my_sk=self.sk_field or "#")
@@ -489,6 +496,9 @@ async def test_serialize_deserialize_model(store: MyModelStore):
489496
assert isinstance(serialized["nested_item"], dict)
490497
assert serialized["nested_item"]["event"] == "created"
491498

499+
assert isinstance(serialized["some_enum"], str)
500+
assert serialized["some_enum"] == "one"
501+
492502
assert model.none_string is None
493503
assert "none_thing" not in serialized
494504
assert model.none_list is None
@@ -505,6 +515,8 @@ async def test_serialize_deserialize_model(store: MyModelStore):
505515
assert deserialized.nested_item.event == "created"
506516
assert deserialized.last_modified == model.last_modified
507517
assert deserialized.today == model.today
518+
assert isinstance(deserialized.some_enum, SomeEnum)
519+
assert deserialized.some_enum == SomeEnum.FIELD_ONE
508520

509521

510522
async def test_transact_get_put_model(store: MyModelStore):

0 commit comments

Comments
 (0)