3
3
import io
4
4
import re
5
5
from dataclasses import dataclass , field , fields
6
+ from enum import Enum
6
7
from typing import Any , Callable , ClassVar , Optional , get_type_hints
7
8
8
9
import pytest
27
28
function_to_parse_one_item ,
28
29
function_to_stream_one_item ,
29
30
is_type_Dict ,
31
+ is_type_Enum ,
30
32
is_type_List ,
31
33
is_type_SpecificOptional ,
32
34
is_type_Tuple ,
39
41
parse_uint32 ,
40
42
recurse_jsonify ,
41
43
streamable ,
44
+ streamable_enum ,
42
45
streamable_from_dict ,
43
46
write_uint32 ,
44
47
)
@@ -376,6 +379,25 @@ def test_basic_optional() -> None:
376
379
assert not is_type_SpecificOptional (list [int ])
377
380
378
381
382
+ class BasicEnum (Enum ):
383
+ A = 1
384
+ B = 2
385
+
386
+
387
+ def test_basic_enum () -> None :
388
+ assert is_type_Enum (BasicEnum )
389
+ assert not is_type_Enum (list [int ])
390
+
391
+
392
+ def test_enum_needs_proxy () -> None :
393
+ with pytest .raises (UnsupportedType ):
394
+
395
+ @streamable
396
+ @dataclass (frozen = True )
397
+ class EnumStreamable (Streamable ):
398
+ enum : BasicEnum
399
+
400
+
379
401
@streamable
380
402
@dataclass (frozen = True )
381
403
class PostInitTestClassBasic (Streamable ):
@@ -423,6 +445,25 @@ class PostInitTestClassDict(Streamable):
423
445
b : dict [bytes32 , dict [uint8 , str ]]
424
446
425
447
448
+ @streamable_enum (uint32 )
449
+ class IntegerEnum (Enum ):
450
+ A = 1
451
+ B = 2
452
+
453
+
454
+ @streamable_enum (str )
455
+ class StringEnum (Enum ):
456
+ A = "foo"
457
+ B = "bar"
458
+
459
+
460
+ @streamable
461
+ @dataclass (frozen = True )
462
+ class PostInitTestClassEnum (Streamable ):
463
+ a : IntegerEnum
464
+ b : StringEnum
465
+
466
+
426
467
@pytest .mark .parametrize (
427
468
"test_class, args" ,
428
469
[
@@ -433,6 +474,7 @@ class PostInitTestClassDict(Streamable):
433
474
(PostInitTestClassTuple , ((1 , "test" ), ((200 , "test_2" ), b"\xba " * 32 ))),
434
475
(PostInitTestClassDict , ({1 : "bar" }, {bytes32 .zeros : {1 : "bar" }})),
435
476
(PostInitTestClassOptional , (12 , None , 13 , None )),
477
+ (PostInitTestClassEnum , (IntegerEnum .A , StringEnum .B )),
436
478
],
437
479
)
438
480
def test_post_init_valid (test_class : type [Any ], args : tuple [Any , ...]) -> None :
@@ -453,6 +495,8 @@ def validate_item_type(type_in: type[Any], item: object) -> bool:
453
495
return validate_item_type (key_type , next (iter (item .keys ()))) and validate_item_type (
454
496
value_type , next (iter (item .values ()))
455
497
)
498
+ if is_type_Enum (type_in ):
499
+ return validate_item_type (type_in ._streamable_proxy , type_in ._streamable_proxy (item .value )) # type: ignore[attr-defined]
456
500
return isinstance (item , type_in )
457
501
458
502
test_object = test_class (* args )
@@ -497,6 +541,8 @@ class TestClass(Streamable):
497
541
f : Optional [uint32 ]
498
542
g : tuple [uint32 , str , bytes ]
499
543
h : dict [uint32 , str ]
544
+ i : IntegerEnum
545
+ j : StringEnum
500
546
501
547
# we want to test invalid here, hence the ignore.
502
548
a = TestClass (
@@ -508,6 +554,8 @@ class TestClass(Streamable):
508
554
None ,
509
555
(uint32 (383 ), "hello" , b"goodbye" ),
510
556
{uint32 (1 ): "foo" },
557
+ IntegerEnum .A ,
558
+ StringEnum .B ,
511
559
)
512
560
513
561
b : bytes = bytes (a )
@@ -619,10 +667,21 @@ class TestClassUint(Streamable):
619
667
a : uint32
620
668
621
669
# Does not have the required uint size
622
- with pytest .raises (ValueError ):
670
+ with pytest .raises (ValueError , match = re . escape ( "uint32.from_bytes() requires 4 bytes but got: 2" ) ):
623
671
TestClassUint .from_bytes (b"\x00 \x00 " )
624
672
625
673
674
+ def test_ambiguous_deserialization_int_enum () -> None :
675
+ @streamable
676
+ @dataclass (frozen = True )
677
+ class TestClassIntegerEnum (Streamable ):
678
+ a : IntegerEnum
679
+
680
+ # passed bytes are incorrect size for serialization proxy
681
+ with pytest .raises (ValueError , match = re .escape ("uint32.from_bytes() requires 4 bytes but got: 2" )):
682
+ TestClassIntegerEnum .from_bytes (b"\x00 \x00 " )
683
+
684
+
626
685
def test_ambiguous_deserialization_list () -> None :
627
686
@streamable
628
687
@dataclass (frozen = True )
@@ -656,6 +715,28 @@ class TestClassStr(Streamable):
656
715
TestClassStr .from_bytes (bytes ([0 , 0 , 100 , 24 , 52 ]))
657
716
658
717
718
+ def test_ambiguous_deserialization_str_enum () -> None :
719
+ @streamable
720
+ @dataclass (frozen = True )
721
+ class TestClassStr (Streamable ):
722
+ a : StringEnum
723
+
724
+ # passed bytes are incorrect size for serialization proxy
725
+ with pytest .raises (AssertionError ):
726
+ TestClassStr .from_bytes (bytes ([0 , 0 , 100 , 24 , 52 ]))
727
+
728
+
729
+ def test_deserialization_to_invalid_enum () -> None :
730
+ @streamable
731
+ @dataclass (frozen = True )
732
+ class TestClassStr (Streamable ):
733
+ a : StringEnum
734
+
735
+ # encodes the string "baz" which is not a valid value for StringEnum
736
+ with pytest .raises (ValueError , match = re .escape ("'baz' is not a valid StringEnum" )):
737
+ TestClassStr .from_bytes (bytes ([0 , 0 , 0 , 3 , 98 , 97 , 122 ]))
738
+
739
+
659
740
def test_ambiguous_deserialization_bytes () -> None :
660
741
@streamable
661
742
@dataclass (frozen = True )
0 commit comments