diff --git a/jamcodec/types.py b/jamcodec/types.py index a704418..160e209 100644 --- a/jamcodec/types.py +++ b/jamcodec/types.py @@ -371,8 +371,11 @@ def encode(self, value: Union[str, dict]) -> JamBytes: for idx, (variant_name, variant_obj) in enumerate(self.variants.items()): - if enum_key == variant_name: + if type(variant_obj) is dict: + idx = variant_obj.get('id', idx) + variant_obj = variant_obj.get('type', variant_obj) + if enum_key == variant_name: data = JamBytes(bytearray([idx])) if variant_obj is not None: @@ -387,9 +390,19 @@ def decode(self, data: JamBytes) -> tuple: index = int.from_bytes(data.get_next_bytes(1), byteorder='little') - try: - enum_key, enum_variant = list(self.variants.items())[index] - except IndexError: + enum_key = None + enum_variant = None + + for idx, (variant_key, variant_obj) in enumerate(self.variants.items()): + if type(variant_obj) is dict: + idx = variant_obj.get('id', idx) + variant_obj = variant_obj.get('type', variant_obj) + if index == idx: + enum_key = variant_key + enum_variant = variant_obj + break + + if enum_key is None: raise ScaleDecodeException(f"Index '{index}' not present in Enum type mapping") if enum_variant is None: @@ -434,6 +447,9 @@ def deserialize(self, value: Union[str, dict, tuple]) -> tuple: for idx, (variant_name, variant_obj) in enumerate(self.variants.items()): + if type(variant_obj) is dict: + variant_obj = variant_obj.get('type', variant_obj) + if enum_key == variant_name: if variant_obj is not None: diff --git a/test/test_enum.py b/test/test_enum.py index 358aca7..82b5b12 100644 --- a/test/test_enum.py +++ b/test/test_enum.py @@ -17,7 +17,7 @@ import unittest -from jamcodec.types import Enum, Bool, U32 +from jamcodec.types import Enum, Bool, U32, String class TestEnum(unittest.TestCase): @@ -60,6 +60,14 @@ def test_enum_deserialize(self): scale_obj.deserialize('None') self.assertEqual(('None', None), scale_obj.value_object) + def test_enum_explicit_id(self): + scale_obj = Enum(Bool=Bool(), Number=U32, Error={'id': 255, 'type': String}).new() + jam_bytes = scale_obj.encode({'Error': 'test'}) + self.assertEqual('0xff0474657374', jam_bytes.to_hex()) + + value = scale_obj.decode(jam_bytes) + self.assertEqual({'Error': 'test'}, value) + if __name__ == '__main__': unittest.main()