Skip to content

Commit fb05d17

Browse files
authored
[CHIA-3703] Add Enum support to streamable (#19983)
This PR adds support for `Enum`s to the streamable framework. In order to support an `Enum` in your streamable class you must decorate it with the `streamable_enum` decorator and specify a proxy streamable object to use for serializing the enum's values. Once you've done this, the object can be used in a streamble object.
2 parents 2708a7e + 0f1d690 commit fb05d17

File tree

2 files changed

+112
-5
lines changed

2 files changed

+112
-5
lines changed

chia/_tests/core/util/test_streamable.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import io
44
import re
55
from dataclasses import dataclass, field, fields
6+
from enum import Enum
67
from typing import Any, Callable, ClassVar, Optional, get_type_hints
78

89
import pytest
@@ -27,6 +28,7 @@
2728
function_to_parse_one_item,
2829
function_to_stream_one_item,
2930
is_type_Dict,
31+
is_type_Enum,
3032
is_type_List,
3133
is_type_SpecificOptional,
3234
is_type_Tuple,
@@ -39,6 +41,7 @@
3941
parse_uint32,
4042
recurse_jsonify,
4143
streamable,
44+
streamable_enum,
4245
streamable_from_dict,
4346
write_uint32,
4447
)
@@ -376,6 +379,25 @@ def test_basic_optional() -> None:
376379
assert not is_type_SpecificOptional(list[int])
377380

378381

382+
class BasicEnum(Enum):
383+
A = 1
384+
B = 2
385+
386+
387+
def test_basic_enum() -> None:
388+
assert is_type_Enum(BasicEnum)
389+
assert not is_type_Enum(list[int])
390+
391+
392+
def test_enum_needs_proxy() -> None:
393+
with pytest.raises(UnsupportedType):
394+
395+
@streamable
396+
@dataclass(frozen=True)
397+
class EnumStreamable(Streamable):
398+
enum: BasicEnum
399+
400+
379401
@streamable
380402
@dataclass(frozen=True)
381403
class PostInitTestClassBasic(Streamable):
@@ -423,6 +445,25 @@ class PostInitTestClassDict(Streamable):
423445
b: dict[bytes32, dict[uint8, str]]
424446

425447

448+
@streamable_enum(uint32)
449+
class IntegerEnum(Enum):
450+
A = 1
451+
B = 2
452+
453+
454+
@streamable_enum(str)
455+
class StringEnum(Enum):
456+
A = "foo"
457+
B = "bar"
458+
459+
460+
@streamable
461+
@dataclass(frozen=True)
462+
class PostInitTestClassEnum(Streamable):
463+
a: IntegerEnum
464+
b: StringEnum
465+
466+
426467
@pytest.mark.parametrize(
427468
"test_class, args",
428469
[
@@ -433,6 +474,7 @@ class PostInitTestClassDict(Streamable):
433474
(PostInitTestClassTuple, ((1, "test"), ((200, "test_2"), b"\xba" * 32))),
434475
(PostInitTestClassDict, ({1: "bar"}, {bytes32.zeros: {1: "bar"}})),
435476
(PostInitTestClassOptional, (12, None, 13, None)),
477+
(PostInitTestClassEnum, (IntegerEnum.A, StringEnum.B)),
436478
],
437479
)
438480
def test_post_init_valid(test_class: type[Any], args: tuple[Any, ...]) -> None:
@@ -453,6 +495,8 @@ def validate_item_type(type_in: type[Any], item: object) -> bool:
453495
return validate_item_type(key_type, next(iter(item.keys()))) and validate_item_type(
454496
value_type, next(iter(item.values()))
455497
)
498+
if is_type_Enum(type_in):
499+
return validate_item_type(type_in._streamable_proxy, type_in._streamable_proxy(item.value)) # type: ignore[attr-defined]
456500
return isinstance(item, type_in)
457501

458502
test_object = test_class(*args)
@@ -497,6 +541,8 @@ class TestClass(Streamable):
497541
f: Optional[uint32]
498542
g: tuple[uint32, str, bytes]
499543
h: dict[uint32, str]
544+
i: IntegerEnum
545+
j: StringEnum
500546

501547
# we want to test invalid here, hence the ignore.
502548
a = TestClass(
@@ -508,6 +554,8 @@ class TestClass(Streamable):
508554
None,
509555
(uint32(383), "hello", b"goodbye"),
510556
{uint32(1): "foo"},
557+
IntegerEnum.A,
558+
StringEnum.B,
511559
)
512560

513561
b: bytes = bytes(a)
@@ -619,10 +667,21 @@ class TestClassUint(Streamable):
619667
a: uint32
620668

621669
# Does not have the required uint size
622-
with pytest.raises(ValueError):
670+
with pytest.raises(ValueError, match=re.escape("uint32.from_bytes() requires 4 bytes but got: 2")):
623671
TestClassUint.from_bytes(b"\x00\x00")
624672

625673

674+
def test_ambiguous_deserialization_int_enum() -> None:
675+
@streamable
676+
@dataclass(frozen=True)
677+
class TestClassIntegerEnum(Streamable):
678+
a: IntegerEnum
679+
680+
# passed bytes are incorrect size for serialization proxy
681+
with pytest.raises(ValueError, match=re.escape("uint32.from_bytes() requires 4 bytes but got: 2")):
682+
TestClassIntegerEnum.from_bytes(b"\x00\x00")
683+
684+
626685
def test_ambiguous_deserialization_list() -> None:
627686
@streamable
628687
@dataclass(frozen=True)
@@ -656,6 +715,28 @@ class TestClassStr(Streamable):
656715
TestClassStr.from_bytes(bytes([0, 0, 100, 24, 52]))
657716

658717

718+
def test_ambiguous_deserialization_str_enum() -> None:
719+
@streamable
720+
@dataclass(frozen=True)
721+
class TestClassStr(Streamable):
722+
a: StringEnum
723+
724+
# passed bytes are incorrect size for serialization proxy
725+
with pytest.raises(AssertionError):
726+
TestClassStr.from_bytes(bytes([0, 0, 100, 24, 52]))
727+
728+
729+
def test_deserialization_to_invalid_enum() -> None:
730+
@streamable
731+
@dataclass(frozen=True)
732+
class TestClassStr(Streamable):
733+
a: StringEnum
734+
735+
# encodes the string "baz" which is not a valid value for StringEnum
736+
with pytest.raises(ValueError, match=re.escape("'baz' is not a valid StringEnum")):
737+
TestClassStr.from_bytes(bytes([0, 0, 0, 3, 98, 97, 122]))
738+
739+
659740
def test_ambiguous_deserialization_bytes() -> None:
660741
@streamable
661742
@dataclass(frozen=True)

chia/util/streamable.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pprint
99
import traceback
1010
from collections.abc import Collection
11-
from enum import Enum
11+
from enum import Enum, EnumMeta
1212
from typing import TYPE_CHECKING, Any, BinaryIO, Callable, ClassVar, Optional, TypeVar, Union, get_type_hints
1313

1414
from chia_rs.sized_bytes import bytes32
@@ -130,6 +130,10 @@ def is_type_Dict(f_type: object) -> bool:
130130
return get_origin(f_type) is dict or f_type is dict
131131

132132

133+
def is_type_Enum(f_type: object) -> bool:
134+
return type(f_type) is EnumMeta
135+
136+
133137
def convert_optional(convert_func: ConvertFunctionType, item: Any) -> Any:
134138
if item is None:
135139
return None
@@ -307,11 +311,10 @@ def recurse_jsonify(
307311
val, None, **next_recursion_env
308312
)
309313
return new_dict
310-
314+
elif isinstance(d, Enum):
315+
return next_recursion_step(d.value, None, **next_recursion_env)
311316
elif issubclass(type(d), bytes):
312317
return f"0x{bytes(d).hex()}"
313-
elif isinstance(d, Enum):
314-
return d.name
315318
elif isinstance(d, bool):
316319
return d
317320
elif isinstance(d, int):
@@ -439,6 +442,10 @@ def function_to_parse_one_item(f_type: type[Any]) -> ParseFunctionType:
439442
key_parse_inner_type_f = function_to_parse_one_item(inner_types[0])
440443
value_parse_inner_type_f = function_to_parse_one_item(inner_types[1])
441444
return lambda f: parse_dict(f, key_parse_inner_type_f, value_parse_inner_type_f)
445+
if is_type_Enum(f_type):
446+
if not hasattr(f_type, "_streamable_proxy"):
447+
raise UnsupportedType(f"Using Enum ({f_type}) in streamable requires a 'streamable_enum' wrapper.")
448+
return lambda f: f_type(function_to_parse_one_item(f_type._streamable_proxy)(f))
442449
if f_type is str:
443450
return parse_str
444451
raise UnsupportedType(f"Type {f_type} does not have parse")
@@ -529,6 +536,13 @@ def function_to_stream_one_item(f_type: type[Any]) -> StreamFunctionType:
529536
key_stream_inner_type_func = function_to_stream_one_item(inner_types[0])
530537
value_stream_inner_type_func = function_to_stream_one_item(inner_types[1])
531538
return lambda item, f: stream_dict(key_stream_inner_type_func, value_stream_inner_type_func, item, f)
539+
elif is_type_Enum(f_type):
540+
if not hasattr(f_type, "_streamable_proxy"):
541+
raise UnsupportedType(f"Using Enum ({f_type}) in streamable requires a 'streamable_enum' wrapper.")
542+
return lambda item, f: function_to_stream_one_item(f_type._streamable_proxy)(
543+
f_type._streamable_proxy(item.value), # type: ignore[attr-defined]
544+
f,
545+
)
532546
elif f_type is str:
533547
return stream_str
534548
elif f_type is bool:
@@ -700,3 +714,15 @@ class UInt32Range(Streamable):
700714
class UInt64Range(Streamable):
701715
start: uint64 = uint64(0)
702716
stop: uint64 = uint64.MAXIMUM
717+
718+
719+
_T_Enum = TypeVar("_T_Enum", bound=EnumMeta)
720+
721+
722+
def streamable_enum(proxy: type[object]) -> Callable[[_T_Enum], _T_Enum]:
723+
def streamable_enum_wrapper(cls: _T_Enum) -> _T_Enum:
724+
setattr(cls, "_streamable_proxy", proxy)
725+
setattr(cls, "_ignore_", ["_streamable_proxy"])
726+
return cls
727+
728+
return streamable_enum_wrapper

0 commit comments

Comments
 (0)