Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions chia/_tests/core/util/test_streamable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io
import re
from dataclasses import dataclass, field, fields
from enum import Enum
from typing import Any, Callable, ClassVar, Optional, get_type_hints

import pytest
Expand All @@ -27,6 +28,7 @@
function_to_parse_one_item,
function_to_stream_one_item,
is_type_Dict,
is_type_Enum,
is_type_List,
is_type_SpecificOptional,
is_type_Tuple,
Expand All @@ -39,6 +41,7 @@
parse_uint32,
recurse_jsonify,
streamable,
streamable_enum,
streamable_from_dict,
write_uint32,
)
Expand Down Expand Up @@ -376,6 +379,25 @@ def test_basic_optional() -> None:
assert not is_type_SpecificOptional(list[int])


class BasicEnum(Enum):
A = 1
B = 2


def test_basic_enum() -> None:
assert is_type_Enum(BasicEnum)
assert not is_type_Enum(list[int])


def test_enum_needs_proxy() -> None:
with pytest.raises(UnsupportedType):

@streamable
@dataclass(frozen=True)
class EnumStreamable(Streamable):
enum: BasicEnum


@streamable
@dataclass(frozen=True)
class PostInitTestClassBasic(Streamable):
Expand Down Expand Up @@ -423,6 +445,25 @@ class PostInitTestClassDict(Streamable):
b: dict[bytes32, dict[uint8, str]]


@streamable_enum(uint32)
class IntEnum(Enum):
A = 1
B = 2


@streamable_enum(str)
class StrEnum(Enum):
A = "foo"
B = "bar"


@streamable
@dataclass(frozen=True)
class PostInitTestClassEnum(Streamable):
a: IntEnum
b: StrEnum


@pytest.mark.parametrize(
"test_class, args",
[
Expand All @@ -433,6 +474,7 @@ class PostInitTestClassDict(Streamable):
(PostInitTestClassTuple, ((1, "test"), ((200, "test_2"), b"\xba" * 32))),
(PostInitTestClassDict, ({1: "bar"}, {bytes32.zeros: {1: "bar"}})),
(PostInitTestClassOptional, (12, None, 13, None)),
(PostInitTestClassEnum, (IntEnum.A, StrEnum.B)),
],
)
def test_post_init_valid(test_class: type[Any], args: tuple[Any, ...]) -> None:
Expand All @@ -453,6 +495,8 @@ def validate_item_type(type_in: type[Any], item: object) -> bool:
return validate_item_type(key_type, next(iter(item.keys()))) and validate_item_type(
value_type, next(iter(item.values()))
)
if is_type_Enum(type_in):
return validate_item_type(type_in._streamable_proxy, type_in._streamable_proxy(item.value)) # type: ignore[attr-defined]
return isinstance(item, type_in)

test_object = test_class(*args)
Expand Down Expand Up @@ -497,6 +541,8 @@ class TestClass(Streamable):
f: Optional[uint32]
g: tuple[uint32, str, bytes]
h: dict[uint32, str]
i: IntEnum
j: StrEnum

# we want to test invalid here, hence the ignore.
a = TestClass(
Expand All @@ -508,6 +554,8 @@ class TestClass(Streamable):
None,
(uint32(383), "hello", b"goodbye"),
{uint32(1): "foo"},
IntEnum.A,
StrEnum.B,
)

b: bytes = bytes(a)
Expand Down Expand Up @@ -623,6 +671,17 @@ class TestClassUint(Streamable):
TestClassUint.from_bytes(b"\x00\x00")


def test_ambiguous_deserialization_int_enum() -> None:
@streamable
@dataclass(frozen=True)
class TestClassIntEnum(Streamable):
a: IntEnum

# Does not have the required uint size
with pytest.raises(ValueError):
TestClassIntEnum.from_bytes(b"\x00\x00")


def test_ambiguous_deserialization_list() -> None:
@streamable
@dataclass(frozen=True)
Expand Down Expand Up @@ -656,6 +715,17 @@ class TestClassStr(Streamable):
TestClassStr.from_bytes(bytes([0, 0, 100, 24, 52]))


def test_ambiguous_deserialization_str_enum() -> None:
@streamable
@dataclass(frozen=True)
class TestClassStr(Streamable):
a: StrEnum

# Does not have the required str size
with pytest.raises(AssertionError):
TestClassStr.from_bytes(bytes([0, 0, 100, 24, 52]))


def test_ambiguous_deserialization_bytes() -> None:
@streamable
@dataclass(frozen=True)
Expand Down
34 changes: 30 additions & 4 deletions chia/util/streamable.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pprint
import traceback
from collections.abc import Collection
from enum import Enum
from enum import Enum, EnumMeta
from typing import TYPE_CHECKING, Any, BinaryIO, Callable, ClassVar, Optional, TypeVar, Union, get_type_hints

from chia_rs.sized_bytes import bytes32
Expand Down Expand Up @@ -130,6 +130,10 @@ def is_type_Dict(f_type: object) -> bool:
return get_origin(f_type) is dict or f_type is dict


def is_type_Enum(f_type: object) -> bool:
return type(f_type) is EnumMeta


def convert_optional(convert_func: ConvertFunctionType, item: Any) -> Any:
if item is None:
return None
Expand Down Expand Up @@ -307,11 +311,10 @@ def recurse_jsonify(
val, None, **next_recursion_env
)
return new_dict

elif isinstance(d, Enum):
return next_recursion_step(d.value, None, **next_recursion_env)
elif issubclass(type(d), bytes):
return f"0x{bytes(d).hex()}"
elif isinstance(d, Enum):
return d.name
elif isinstance(d, bool):
return d
elif isinstance(d, int):
Expand Down Expand Up @@ -439,6 +442,10 @@ def function_to_parse_one_item(f_type: type[Any]) -> ParseFunctionType:
key_parse_inner_type_f = function_to_parse_one_item(inner_types[0])
value_parse_inner_type_f = function_to_parse_one_item(inner_types[1])
return lambda f: parse_dict(f, key_parse_inner_type_f, value_parse_inner_type_f)
if is_type_Enum(f_type):
if not hasattr(f_type, "_streamable_proxy"):
raise UnsupportedType(f"Using Enum ({f_type}) in streamable requires a 'streamable_enum' wrapper.")
return lambda f: f_type(function_to_parse_one_item(f_type._streamable_proxy)(f))
if f_type is str:
return parse_str
raise UnsupportedType(f"Type {f_type} does not have parse")
Expand Down Expand Up @@ -529,6 +536,13 @@ def function_to_stream_one_item(f_type: type[Any]) -> StreamFunctionType:
key_stream_inner_type_func = function_to_stream_one_item(inner_types[0])
value_stream_inner_type_func = function_to_stream_one_item(inner_types[1])
return lambda item, f: stream_dict(key_stream_inner_type_func, value_stream_inner_type_func, item, f)
elif is_type_Enum(f_type):
if not hasattr(f_type, "_streamable_proxy"):
raise UnsupportedType(f"Using Enum ({f_type}) in streamable requires a 'streamable_enum' wrapper.")
return lambda item, f: function_to_stream_one_item(f_type._streamable_proxy)(
f_type._streamable_proxy(item.value), # type: ignore[attr-defined]
f,
)
elif f_type is str:
return stream_str
elif f_type is bool:
Expand Down Expand Up @@ -700,3 +714,15 @@ class UInt32Range(Streamable):
class UInt64Range(Streamable):
start: uint64 = uint64(0)
stop: uint64 = uint64.MAXIMUM


_T_Enum = TypeVar("_T_Enum", bound=EnumMeta)


def streamable_enum(proxy: type[Any]) -> Callable[[_T_Enum], _T_Enum]:
def streamable_enum_wrapper(cls: _T_Enum) -> _T_Enum:
setattr(cls, "_streamable_proxy", proxy)
setattr(cls, "_ignore_", ["_streamable_proxy"])
return cls

return streamable_enum_wrapper
Loading