From 9af0b7c97ccf9b27a6fb802445d5647d4d24d82d Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 21 Aug 2025 15:00:07 -0700 Subject: [PATCH 1/6] Add Enum support to streamable --- chia/_tests/core/util/test_streamable.py | 70 ++++++++++++++++++++++++ chia/util/streamable.py | 45 +++++++++++++-- 2 files changed, 111 insertions(+), 4 deletions(-) diff --git a/chia/_tests/core/util/test_streamable.py b/chia/_tests/core/util/test_streamable.py index b75aaddfe867..2c722dd440c0 100644 --- a/chia/_tests/core/util/test_streamable.py +++ b/chia/_tests/core/util/test_streamable.py @@ -3,6 +3,7 @@ import io import re from dataclasses import dataclass, field, fields +from enum import Enum from typing import Any, Callable, ClassVar, Optional, get_type_hints import pytest @@ -27,6 +28,7 @@ function_to_parse_one_item, function_to_stream_one_item, is_type_Dict, + is_type_Enum, is_type_List, is_type_SpecificOptional, is_type_Tuple, @@ -39,6 +41,7 @@ parse_uint32, recurse_jsonify, streamable, + streamable_enum, streamable_from_dict, write_uint32, ) @@ -376,6 +379,25 @@ def test_basic_optional() -> None: assert not is_type_SpecificOptional(list[int]) +class BasicEnum(Enum): + A = 1 + B = 2 + + +def test_basic_enum() -> None: + assert is_type_Enum(BasicEnum) + assert not is_type_Enum(list[int]) + + +def test_enum_needs_proxy() -> None: + with pytest.raises(UnsupportedType): + + @streamable + @dataclass(frozen=True) + class EnumStreamable(Streamable): + enum: BasicEnum + + @streamable @dataclass(frozen=True) class PostInitTestClassBasic(Streamable): @@ -423,6 +445,25 @@ class PostInitTestClassDict(Streamable): b: dict[bytes32, dict[uint8, str]] +@streamable_enum(uint32) +class IntEnum(Enum): + A = 1 + B = 2 + + +@streamable_enum(str) +class StrEnum(Enum): + A = "foo" + B = "bar" + + +@streamable +@dataclass(frozen=True) +class PostInitTestClassEnum(Streamable): + a: IntEnum + b: StrEnum + + @pytest.mark.parametrize( "test_class, args", [ @@ -433,6 +474,7 @@ class PostInitTestClassDict(Streamable): (PostInitTestClassTuple, ((1, "test"), ((200, "test_2"), b"\xba" * 32))), (PostInitTestClassDict, ({1: "bar"}, {bytes32.zeros: {1: "bar"}})), (PostInitTestClassOptional, (12, None, 13, None)), + (PostInitTestClassEnum, (IntEnum.A, StrEnum.B)), ], ) def test_post_init_valid(test_class: type[Any], args: tuple[Any, ...]) -> None: @@ -453,6 +495,8 @@ def validate_item_type(type_in: type[Any], item: object) -> bool: return validate_item_type(key_type, next(iter(item.keys()))) and validate_item_type( value_type, next(iter(item.values())) ) + if is_type_Enum(type_in): + return validate_item_type(type_in._streamable_proxy, item) return isinstance(item, type_in) test_object = test_class(*args) @@ -497,6 +541,8 @@ class TestClass(Streamable): f: Optional[uint32] g: tuple[uint32, str, bytes] h: dict[uint32, str] + i: IntEnum + j: StrEnum # we want to test invalid here, hence the ignore. a = TestClass( @@ -508,6 +554,8 @@ class TestClass(Streamable): None, (uint32(383), "hello", b"goodbye"), {uint32(1): "foo"}, + IntEnum.A, + StrEnum.B, ) b: bytes = bytes(a) @@ -623,6 +671,17 @@ class TestClassUint(Streamable): TestClassUint.from_bytes(b"\x00\x00") +def test_ambiguous_deserialization_int_enum() -> None: + @streamable + @dataclass(frozen=True) + class TestClassIntEnum(Streamable): + a: IntEnum + + # Does not have the required uint size + with pytest.raises(ValueError): + TestClassIntEnum.from_bytes(b"\x00\x00") + + def test_ambiguous_deserialization_list() -> None: @streamable @dataclass(frozen=True) @@ -656,6 +715,17 @@ class TestClassStr(Streamable): TestClassStr.from_bytes(bytes([0, 0, 100, 24, 52])) +def test_ambiguous_deserialization_str_enum() -> None: + @streamable + @dataclass(frozen=True) + class TestClassStr(Streamable): + a: StrEnum + + # Does not have the required str size + with pytest.raises(AssertionError): + TestClassStr.from_bytes(bytes([0, 0, 100, 24, 52])) + + def test_ambiguous_deserialization_bytes() -> None: @streamable @dataclass(frozen=True) diff --git a/chia/util/streamable.py b/chia/util/streamable.py index 7d04210a01c1..1c13b158fe10 100644 --- a/chia/util/streamable.py +++ b/chia/util/streamable.py @@ -8,7 +8,7 @@ import pprint import traceback from collections.abc import Collection -from enum import Enum +from enum import Enum, EnumType from typing import TYPE_CHECKING, Any, BinaryIO, Callable, ClassVar, Optional, TypeVar, Union, get_type_hints from chia_rs.sized_bytes import bytes32 @@ -130,6 +130,10 @@ def is_type_Dict(f_type: object) -> bool: return get_origin(f_type) is dict or f_type is dict +def is_type_Enum(f_type: object) -> bool: + return type(f_type) is EnumType + + def convert_optional(convert_func: ConvertFunctionType, item: Any) -> Any: if item is None: return None @@ -156,6 +160,10 @@ def convert_dict( return {key_converter(key): value_converter(value) for key, value in mapping.items()} +def convert_enum(convert_func: ConvertFunctionType, enum: Enum) -> Any: + return convert_func(enum.value) + + def convert_hex_string(item: str) -> bytes: if not isinstance(item, str): raise InvalidTypeError(str, type(item)) @@ -228,6 +236,11 @@ def function_to_convert_one_item( key_converter = function_to_convert_one_item(inner_types[0], json_parser) value_converter = function_to_convert_one_item(inner_types[1], json_parser) return lambda mapping: convert_dict(key_converter, value_converter, mapping) # type: ignore[arg-type] + elif is_type_Enum(f_type): + if not hasattr(f_type, "_streamable_proxy"): + raise UnsupportedType(f"Using Enum ({f_type}) in streamable requires a 'streamable_enum' wrapper.") + convert_func = function_to_convert_one_item(f_type._streamable_proxy, json_parser) + return lambda enum: convert_enum(convert_func, enum) # type: ignore[arg-type] elif hasattr(f_type, "from_json_dict"): if json_parser is None: json_parser = f_type.from_json_dict @@ -277,6 +290,11 @@ def function_to_post_init_process_one_item(f_type: type[object]) -> ConvertFunct key_converter = function_to_post_init_process_one_item(inner_types[0]) value_converter = function_to_post_init_process_one_item(inner_types[1]) return lambda mapping: convert_dict(key_converter, value_converter, mapping) # type: ignore[arg-type] + if is_type_Enum(f_type): + if not hasattr(f_type, "_streamable_proxy"): + raise UnsupportedType(f"Using Enum ({f_type}) in streamable requires a 'streamable_enum' wrapper.") + process_inner_func = function_to_post_init_process_one_item(f_type._streamable_proxy) + return lambda item: convert_enum(f_type._streamable_proxy, item) # type: ignore[arg-type] return lambda item: post_init_process_item(f_type, item) @@ -307,11 +325,10 @@ def recurse_jsonify( val, None, **next_recursion_env ) return new_dict - + elif isinstance(d, Enum): + return next_recursion_step(d.value, None, **next_recursion_env) elif issubclass(type(d), bytes): return f"0x{bytes(d).hex()}" - elif isinstance(d, Enum): - return d.name elif isinstance(d, bool): return d elif isinstance(d, int): @@ -439,6 +456,10 @@ def function_to_parse_one_item(f_type: type[Any]) -> ParseFunctionType: key_parse_inner_type_f = function_to_parse_one_item(inner_types[0]) value_parse_inner_type_f = function_to_parse_one_item(inner_types[1]) return lambda f: parse_dict(f, key_parse_inner_type_f, value_parse_inner_type_f) + if is_type_Enum(f_type): + if not hasattr(f_type, "_streamable_proxy"): + raise UnsupportedType(f"Using Enum ({f_type}) in streamable requires a 'streamable_enum' wrapper.") + return function_to_parse_one_item(f_type._streamable_proxy) if f_type is str: return parse_str raise UnsupportedType(f"Type {f_type} does not have parse") @@ -529,6 +550,10 @@ def function_to_stream_one_item(f_type: type[Any]) -> StreamFunctionType: key_stream_inner_type_func = function_to_stream_one_item(inner_types[0]) value_stream_inner_type_func = function_to_stream_one_item(inner_types[1]) return lambda item, f: stream_dict(key_stream_inner_type_func, value_stream_inner_type_func, item, f) + elif is_type_Enum(f_type): + if not hasattr(f_type, "_streamable_proxy"): + raise UnsupportedType(f"Using Enum ({f_type}) in streamable requires a 'streamable_enum' wrapper.") + return function_to_stream_one_item(f_type._streamable_proxy) elif f_type is str: return stream_str elif f_type is bool: @@ -700,3 +725,15 @@ class UInt32Range(Streamable): class UInt64Range(Streamable): start: uint64 = uint64(0) stop: uint64 = uint64.MAXIMUM + + +_T_Enum = TypeVar("_T_Enum", bound=EnumType) + + +def streamable_enum(proxy: type[Any]) -> Callable[[_T_Enum], _T_Enum]: + def streamable_enum_wrapper(cls: _T_Enum) -> _T_Enum: + setattr(cls, "_streamable_proxy", proxy) + setattr(cls, "_ignore_", ["_streamable_proxy"]) + return cls + + return streamable_enum_wrapper From 4b55b7f2d46019d1a77fcdbb7f5a757d53742f00 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 21 Aug 2025 15:12:29 -0700 Subject: [PATCH 2/6] make sure to convert to enum on parse --- chia/util/streamable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chia/util/streamable.py b/chia/util/streamable.py index 1c13b158fe10..f3ec5d58bbd9 100644 --- a/chia/util/streamable.py +++ b/chia/util/streamable.py @@ -459,7 +459,7 @@ def function_to_parse_one_item(f_type: type[Any]) -> ParseFunctionType: if is_type_Enum(f_type): if not hasattr(f_type, "_streamable_proxy"): raise UnsupportedType(f"Using Enum ({f_type}) in streamable requires a 'streamable_enum' wrapper.") - return function_to_parse_one_item(f_type._streamable_proxy) + return lambda f: f_type(function_to_parse_one_item(f_type._streamable_proxy)(f)) if f_type is str: return parse_str raise UnsupportedType(f"Type {f_type} does not have parse") From f2cb6b138b7d5df3302fe02315b20737f7c15483 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 21 Aug 2025 15:15:28 -0700 Subject: [PATCH 3/6] use EnumMeta for python < 3.11 --- chia/util/streamable.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/chia/util/streamable.py b/chia/util/streamable.py index f3ec5d58bbd9..242b08c41231 100644 --- a/chia/util/streamable.py +++ b/chia/util/streamable.py @@ -8,7 +8,7 @@ import pprint import traceback from collections.abc import Collection -from enum import Enum, EnumType +from enum import Enum, EnumMeta from typing import TYPE_CHECKING, Any, BinaryIO, Callable, ClassVar, Optional, TypeVar, Union, get_type_hints from chia_rs.sized_bytes import bytes32 @@ -131,7 +131,7 @@ def is_type_Dict(f_type: object) -> bool: def is_type_Enum(f_type: object) -> bool: - return type(f_type) is EnumType + return type(f_type) is EnumMeta def convert_optional(convert_func: ConvertFunctionType, item: Any) -> Any: @@ -727,7 +727,7 @@ class UInt64Range(Streamable): stop: uint64 = uint64.MAXIMUM -_T_Enum = TypeVar("_T_Enum", bound=EnumType) +_T_Enum = TypeVar("_T_Enum", bound=EnumMeta) def streamable_enum(proxy: type[Any]) -> Callable[[_T_Enum], _T_Enum]: From f3e85301c24833a1542625496f16b1cb70cea620 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 21 Aug 2025 15:30:28 -0700 Subject: [PATCH 4/6] stop converting --- chia/_tests/core/util/test_streamable.py | 2 +- chia/util/streamable.py | 19 ++++--------------- 2 files changed, 5 insertions(+), 16 deletions(-) diff --git a/chia/_tests/core/util/test_streamable.py b/chia/_tests/core/util/test_streamable.py index 2c722dd440c0..a3dcc90c3378 100644 --- a/chia/_tests/core/util/test_streamable.py +++ b/chia/_tests/core/util/test_streamable.py @@ -496,7 +496,7 @@ def validate_item_type(type_in: type[Any], item: object) -> bool: value_type, next(iter(item.values())) ) if is_type_Enum(type_in): - return validate_item_type(type_in._streamable_proxy, item) + return validate_item_type(type_in._streamable_proxy, type_in._streamable_proxy(item.value)) # type: ignore[attr-defined] return isinstance(item, type_in) test_object = test_class(*args) diff --git a/chia/util/streamable.py b/chia/util/streamable.py index 242b08c41231..8263eb8b44ef 100644 --- a/chia/util/streamable.py +++ b/chia/util/streamable.py @@ -160,10 +160,6 @@ def convert_dict( return {key_converter(key): value_converter(value) for key, value in mapping.items()} -def convert_enum(convert_func: ConvertFunctionType, enum: Enum) -> Any: - return convert_func(enum.value) - - def convert_hex_string(item: str) -> bytes: if not isinstance(item, str): raise InvalidTypeError(str, type(item)) @@ -236,11 +232,6 @@ def function_to_convert_one_item( key_converter = function_to_convert_one_item(inner_types[0], json_parser) value_converter = function_to_convert_one_item(inner_types[1], json_parser) return lambda mapping: convert_dict(key_converter, value_converter, mapping) # type: ignore[arg-type] - elif is_type_Enum(f_type): - if not hasattr(f_type, "_streamable_proxy"): - raise UnsupportedType(f"Using Enum ({f_type}) in streamable requires a 'streamable_enum' wrapper.") - convert_func = function_to_convert_one_item(f_type._streamable_proxy, json_parser) - return lambda enum: convert_enum(convert_func, enum) # type: ignore[arg-type] elif hasattr(f_type, "from_json_dict"): if json_parser is None: json_parser = f_type.from_json_dict @@ -290,11 +281,6 @@ def function_to_post_init_process_one_item(f_type: type[object]) -> ConvertFunct key_converter = function_to_post_init_process_one_item(inner_types[0]) value_converter = function_to_post_init_process_one_item(inner_types[1]) return lambda mapping: convert_dict(key_converter, value_converter, mapping) # type: ignore[arg-type] - if is_type_Enum(f_type): - if not hasattr(f_type, "_streamable_proxy"): - raise UnsupportedType(f"Using Enum ({f_type}) in streamable requires a 'streamable_enum' wrapper.") - process_inner_func = function_to_post_init_process_one_item(f_type._streamable_proxy) - return lambda item: convert_enum(f_type._streamable_proxy, item) # type: ignore[arg-type] return lambda item: post_init_process_item(f_type, item) @@ -553,7 +539,10 @@ def function_to_stream_one_item(f_type: type[Any]) -> StreamFunctionType: elif is_type_Enum(f_type): if not hasattr(f_type, "_streamable_proxy"): raise UnsupportedType(f"Using Enum ({f_type}) in streamable requires a 'streamable_enum' wrapper.") - return function_to_stream_one_item(f_type._streamable_proxy) + return lambda item, f: function_to_stream_one_item(f_type._streamable_proxy)( + f_type._streamable_proxy(item.value), # type: ignore[attr-defined] + f, + ) elif f_type is str: return stream_str elif f_type is bool: From 4623d0d792ded7b8206ebba1651f1db74785d389 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 26 Aug 2025 10:16:12 -0700 Subject: [PATCH 5/6] Comments by @altendky --- chia/_tests/core/util/test_streamable.py | 34 ++++++++++++------------ chia/util/streamable.py | 2 +- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/chia/_tests/core/util/test_streamable.py b/chia/_tests/core/util/test_streamable.py index a3dcc90c3378..4617c3a42d86 100644 --- a/chia/_tests/core/util/test_streamable.py +++ b/chia/_tests/core/util/test_streamable.py @@ -446,13 +446,13 @@ class PostInitTestClassDict(Streamable): @streamable_enum(uint32) -class IntEnum(Enum): +class IntegerEnum(Enum): A = 1 B = 2 @streamable_enum(str) -class StrEnum(Enum): +class StringEnum(Enum): A = "foo" B = "bar" @@ -460,8 +460,8 @@ class StrEnum(Enum): @streamable @dataclass(frozen=True) class PostInitTestClassEnum(Streamable): - a: IntEnum - b: StrEnum + a: IntegerEnum + b: StringEnum @pytest.mark.parametrize( @@ -474,7 +474,7 @@ class PostInitTestClassEnum(Streamable): (PostInitTestClassTuple, ((1, "test"), ((200, "test_2"), b"\xba" * 32))), (PostInitTestClassDict, ({1: "bar"}, {bytes32.zeros: {1: "bar"}})), (PostInitTestClassOptional, (12, None, 13, None)), - (PostInitTestClassEnum, (IntEnum.A, StrEnum.B)), + (PostInitTestClassEnum, (IntegerEnum.A, StringEnum.B)), ], ) def test_post_init_valid(test_class: type[Any], args: tuple[Any, ...]) -> None: @@ -541,8 +541,8 @@ class TestClass(Streamable): f: Optional[uint32] g: tuple[uint32, str, bytes] h: dict[uint32, str] - i: IntEnum - j: StrEnum + i: IntegerEnum + j: StringEnum # we want to test invalid here, hence the ignore. a = TestClass( @@ -554,8 +554,8 @@ class TestClass(Streamable): None, (uint32(383), "hello", b"goodbye"), {uint32(1): "foo"}, - IntEnum.A, - StrEnum.B, + IntegerEnum.A, + StringEnum.B, ) b: bytes = bytes(a) @@ -667,19 +667,19 @@ class TestClassUint(Streamable): a: uint32 # Does not have the required uint size - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=re.escape("uint32.from_bytes() requires 4 bytes but got: 2")): TestClassUint.from_bytes(b"\x00\x00") def test_ambiguous_deserialization_int_enum() -> None: @streamable @dataclass(frozen=True) - class TestClassIntEnum(Streamable): - a: IntEnum + class TestClassIntegerEnum(Streamable): + a: IntegerEnum - # Does not have the required uint size - with pytest.raises(ValueError): - TestClassIntEnum.from_bytes(b"\x00\x00") + # passed bytes are incorrect size for serialization proxy + with pytest.raises(ValueError, match=re.escape("uint32.from_bytes() requires 4 bytes but got: 2")): + TestClassIntegerEnum.from_bytes(b"\x00\x00") def test_ambiguous_deserialization_list() -> None: @@ -719,9 +719,9 @@ def test_ambiguous_deserialization_str_enum() -> None: @streamable @dataclass(frozen=True) class TestClassStr(Streamable): - a: StrEnum + a: StringEnum - # Does not have the required str size + # passed bytes are incorrect size for serialization proxy with pytest.raises(AssertionError): TestClassStr.from_bytes(bytes([0, 0, 100, 24, 52])) diff --git a/chia/util/streamable.py b/chia/util/streamable.py index 8263eb8b44ef..ccca8ecb5d48 100644 --- a/chia/util/streamable.py +++ b/chia/util/streamable.py @@ -719,7 +719,7 @@ class UInt64Range(Streamable): _T_Enum = TypeVar("_T_Enum", bound=EnumMeta) -def streamable_enum(proxy: type[Any]) -> Callable[[_T_Enum], _T_Enum]: +def streamable_enum(proxy: type[object]) -> Callable[[_T_Enum], _T_Enum]: def streamable_enum_wrapper(cls: _T_Enum) -> _T_Enum: setattr(cls, "_streamable_proxy", proxy) setattr(cls, "_ignore_", ["_streamable_proxy"]) From 0f1d690246d7b6c6a64b0c69aa7968300360c8a8 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 26 Aug 2025 10:37:20 -0700 Subject: [PATCH 6/6] Add a test for invalid enum value --- chia/_tests/core/util/test_streamable.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/chia/_tests/core/util/test_streamable.py b/chia/_tests/core/util/test_streamable.py index 4617c3a42d86..665491d9793a 100644 --- a/chia/_tests/core/util/test_streamable.py +++ b/chia/_tests/core/util/test_streamable.py @@ -726,6 +726,17 @@ class TestClassStr(Streamable): TestClassStr.from_bytes(bytes([0, 0, 100, 24, 52])) +def test_deserialization_to_invalid_enum() -> None: + @streamable + @dataclass(frozen=True) + class TestClassStr(Streamable): + a: StringEnum + + # encodes the string "baz" which is not a valid value for StringEnum + with pytest.raises(ValueError, match=re.escape("'baz' is not a valid StringEnum")): + TestClassStr.from_bytes(bytes([0, 0, 0, 3, 98, 97, 122])) + + def test_ambiguous_deserialization_bytes() -> None: @streamable @dataclass(frozen=True)