diff --git a/chia/_tests/core/util/test_streamable.py b/chia/_tests/core/util/test_streamable.py index b75aaddfe867..665491d9793a 100644 --- a/chia/_tests/core/util/test_streamable.py +++ b/chia/_tests/core/util/test_streamable.py @@ -3,6 +3,7 @@ import io import re from dataclasses import dataclass, field, fields +from enum import Enum from typing import Any, Callable, ClassVar, Optional, get_type_hints import pytest @@ -27,6 +28,7 @@ function_to_parse_one_item, function_to_stream_one_item, is_type_Dict, + is_type_Enum, is_type_List, is_type_SpecificOptional, is_type_Tuple, @@ -39,6 +41,7 @@ parse_uint32, recurse_jsonify, streamable, + streamable_enum, streamable_from_dict, write_uint32, ) @@ -376,6 +379,25 @@ def test_basic_optional() -> None: assert not is_type_SpecificOptional(list[int]) +class BasicEnum(Enum): + A = 1 + B = 2 + + +def test_basic_enum() -> None: + assert is_type_Enum(BasicEnum) + assert not is_type_Enum(list[int]) + + +def test_enum_needs_proxy() -> None: + with pytest.raises(UnsupportedType): + + @streamable + @dataclass(frozen=True) + class EnumStreamable(Streamable): + enum: BasicEnum + + @streamable @dataclass(frozen=True) class PostInitTestClassBasic(Streamable): @@ -423,6 +445,25 @@ class PostInitTestClassDict(Streamable): b: dict[bytes32, dict[uint8, str]] +@streamable_enum(uint32) +class IntegerEnum(Enum): + A = 1 + B = 2 + + +@streamable_enum(str) +class StringEnum(Enum): + A = "foo" + B = "bar" + + +@streamable +@dataclass(frozen=True) +class PostInitTestClassEnum(Streamable): + a: IntegerEnum + b: StringEnum + + @pytest.mark.parametrize( "test_class, args", [ @@ -433,6 +474,7 @@ class PostInitTestClassDict(Streamable): (PostInitTestClassTuple, ((1, "test"), ((200, "test_2"), b"\xba" * 32))), (PostInitTestClassDict, ({1: "bar"}, {bytes32.zeros: {1: "bar"}})), (PostInitTestClassOptional, (12, None, 13, None)), + (PostInitTestClassEnum, (IntegerEnum.A, StringEnum.B)), ], ) def test_post_init_valid(test_class: type[Any], args: tuple[Any, ...]) -> None: @@ -453,6 +495,8 @@ def validate_item_type(type_in: type[Any], item: object) -> bool: return validate_item_type(key_type, next(iter(item.keys()))) and validate_item_type( value_type, next(iter(item.values())) ) + if is_type_Enum(type_in): + return validate_item_type(type_in._streamable_proxy, type_in._streamable_proxy(item.value)) # type: ignore[attr-defined] return isinstance(item, type_in) test_object = test_class(*args) @@ -497,6 +541,8 @@ class TestClass(Streamable): f: Optional[uint32] g: tuple[uint32, str, bytes] h: dict[uint32, str] + i: IntegerEnum + j: StringEnum # we want to test invalid here, hence the ignore. a = TestClass( @@ -508,6 +554,8 @@ class TestClass(Streamable): None, (uint32(383), "hello", b"goodbye"), {uint32(1): "foo"}, + IntegerEnum.A, + StringEnum.B, ) b: bytes = bytes(a) @@ -619,10 +667,21 @@ class TestClassUint(Streamable): a: uint32 # Does not have the required uint size - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("uint32.from_bytes() requires 4 bytes but got: 2")): TestClassUint.from_bytes(b"\x00\x00") +def test_ambiguous_deserialization_int_enum() -> None: + @streamable + @dataclass(frozen=True) + class TestClassIntegerEnum(Streamable): + a: IntegerEnum + + # passed bytes are incorrect size for serialization proxy + with pytest.raises(ValueError, match=re.escape("uint32.from_bytes() requires 4 bytes but got: 2")): + TestClassIntegerEnum.from_bytes(b"\x00\x00") + + def test_ambiguous_deserialization_list() -> None: @streamable @dataclass(frozen=True) @@ -656,6 +715,28 @@ class TestClassStr(Streamable): TestClassStr.from_bytes(bytes([0, 0, 100, 24, 52])) +def test_ambiguous_deserialization_str_enum() -> None: + @streamable + @dataclass(frozen=True) + class TestClassStr(Streamable): + a: StringEnum + + # passed bytes are incorrect size for serialization proxy + with pytest.raises(AssertionError): + TestClassStr.from_bytes(bytes([0, 0, 100, 24, 52])) + + +def test_deserialization_to_invalid_enum() -> None: + @streamable + @dataclass(frozen=True) + class TestClassStr(Streamable): + a: StringEnum + + # encodes the string "baz" which is not a valid value for StringEnum + with pytest.raises(ValueError, match=re.escape("'baz' is not a valid StringEnum")): + TestClassStr.from_bytes(bytes([0, 0, 0, 3, 98, 97, 122])) + + def test_ambiguous_deserialization_bytes() -> None: @streamable @dataclass(frozen=True) diff --git a/chia/util/streamable.py b/chia/util/streamable.py index 7d04210a01c1..ccca8ecb5d48 100644 --- a/chia/util/streamable.py +++ b/chia/util/streamable.py @@ -8,7 +8,7 @@ import pprint import traceback from collections.abc import Collection -from enum import Enum +from enum import Enum, EnumMeta from typing import TYPE_CHECKING, Any, BinaryIO, Callable, ClassVar, Optional, TypeVar, Union, get_type_hints from chia_rs.sized_bytes import bytes32 @@ -130,6 +130,10 @@ def is_type_Dict(f_type: object) -> bool: return get_origin(f_type) is dict or f_type is dict +def is_type_Enum(f_type: object) -> bool: + return type(f_type) is EnumMeta + + def convert_optional(convert_func: ConvertFunctionType, item: Any) -> Any: if item is None: return None @@ -307,11 +311,10 @@ def recurse_jsonify( val, None, **next_recursion_env ) return new_dict - + elif isinstance(d, Enum): + return next_recursion_step(d.value, None, **next_recursion_env) elif issubclass(type(d), bytes): return f"0x{bytes(d).hex()}" - elif isinstance(d, Enum): - return d.name elif isinstance(d, bool): return d elif isinstance(d, int): @@ -439,6 +442,10 @@ def function_to_parse_one_item(f_type: type[Any]) -> ParseFunctionType: key_parse_inner_type_f = function_to_parse_one_item(inner_types[0]) value_parse_inner_type_f = function_to_parse_one_item(inner_types[1]) return lambda f: parse_dict(f, key_parse_inner_type_f, value_parse_inner_type_f) + if is_type_Enum(f_type): + if not hasattr(f_type, "_streamable_proxy"): + raise UnsupportedType(f"Using Enum ({f_type}) in streamable requires a 'streamable_enum' wrapper.") + return lambda f: f_type(function_to_parse_one_item(f_type._streamable_proxy)(f)) if f_type is str: return parse_str raise UnsupportedType(f"Type {f_type} does not have parse") @@ -529,6 +536,13 @@ def function_to_stream_one_item(f_type: type[Any]) -> StreamFunctionType: key_stream_inner_type_func = function_to_stream_one_item(inner_types[0]) value_stream_inner_type_func = function_to_stream_one_item(inner_types[1]) return lambda item, f: stream_dict(key_stream_inner_type_func, value_stream_inner_type_func, item, f) + elif is_type_Enum(f_type): + if not hasattr(f_type, "_streamable_proxy"): + raise UnsupportedType(f"Using Enum ({f_type}) in streamable requires a 'streamable_enum' wrapper.") + return lambda item, f: function_to_stream_one_item(f_type._streamable_proxy)( + f_type._streamable_proxy(item.value), # type: ignore[attr-defined] + f, + ) elif f_type is str: return stream_str elif f_type is bool: @@ -700,3 +714,15 @@ class UInt32Range(Streamable): class UInt64Range(Streamable): start: uint64 = uint64(0) stop: uint64 = uint64.MAXIMUM + + +_T_Enum = TypeVar("_T_Enum", bound=EnumMeta) + + +def streamable_enum(proxy: type[object]) -> Callable[[_T_Enum], _T_Enum]: + def streamable_enum_wrapper(cls: _T_Enum) -> _T_Enum: + setattr(cls, "_streamable_proxy", proxy) + setattr(cls, "_ignore_", ["_streamable_proxy"]) + return cls + + return streamable_enum_wrapper