Skip to content

Commit 9847bbe

Browse files
authored
Explicit ID for Enums (#15)
1 parent d68303f commit 9847bbe

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

jamcodec/types.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -371,8 +371,11 @@ def encode(self, value: Union[str, dict]) -> JamBytes:
371371

372372
for idx, (variant_name, variant_obj) in enumerate(self.variants.items()):
373373

374-
if enum_key == variant_name:
374+
if type(variant_obj) is dict:
375+
idx = variant_obj.get('id', idx)
376+
variant_obj = variant_obj.get('type', variant_obj)
375377

378+
if enum_key == variant_name:
376379
data = JamBytes(bytearray([idx]))
377380

378381
if variant_obj is not None:
@@ -387,9 +390,19 @@ def decode(self, data: JamBytes) -> tuple:
387390

388391
index = int.from_bytes(data.get_next_bytes(1), byteorder='little')
389392

390-
try:
391-
enum_key, enum_variant = list(self.variants.items())[index]
392-
except IndexError:
393+
enum_key = None
394+
enum_variant = None
395+
396+
for idx, (variant_key, variant_obj) in enumerate(self.variants.items()):
397+
if type(variant_obj) is dict:
398+
idx = variant_obj.get('id', idx)
399+
variant_obj = variant_obj.get('type', variant_obj)
400+
if index == idx:
401+
enum_key = variant_key
402+
enum_variant = variant_obj
403+
break
404+
405+
if enum_key is None:
393406
raise ScaleDecodeException(f"Index '{index}' not present in Enum type mapping")
394407

395408
if enum_variant is None:
@@ -434,6 +447,9 @@ def deserialize(self, value: Union[str, dict, tuple]) -> tuple:
434447

435448
for idx, (variant_name, variant_obj) in enumerate(self.variants.items()):
436449

450+
if type(variant_obj) is dict:
451+
variant_obj = variant_obj.get('type', variant_obj)
452+
437453
if enum_key == variant_name:
438454

439455
if variant_obj is not None:

test/test_enum.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import unittest
1919

20-
from jamcodec.types import Enum, Bool, U32
20+
from jamcodec.types import Enum, Bool, U32, String
2121

2222

2323
class TestEnum(unittest.TestCase):
@@ -60,6 +60,14 @@ def test_enum_deserialize(self):
6060
scale_obj.deserialize('None')
6161
self.assertEqual(('None', None), scale_obj.value_object)
6262

63+
def test_enum_explicit_id(self):
64+
scale_obj = Enum(Bool=Bool(), Number=U32, Error={'id': 255, 'type': String}).new()
65+
jam_bytes = scale_obj.encode({'Error': 'test'})
66+
self.assertEqual('0xff0474657374', jam_bytes.to_hex())
67+
68+
value = scale_obj.decode(jam_bytes)
69+
self.assertEqual({'Error': 'test'}, value)
70+
6371

6472
if __name__ == '__main__':
6573
unittest.main()

0 commit comments

Comments
 (0)