diff --git a/.github/workflows/unittests.yml b/.github/workflows/unittests.yml index 7328b3f..55ffc2f 100644 --- a/.github/workflows/unittests.yml +++ b/.github/workflows/unittests.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] + python-version: ['3.9', '3.10', '3.11', '3.12'] steps: - uses: actions/checkout@v2 diff --git a/jamcodec/types.py b/jamcodec/types.py index 8ba065f..2908529 100644 --- a/jamcodec/types.py +++ b/jamcodec/types.py @@ -19,6 +19,7 @@ import math import struct import typing +from collections.abc import Mapping from typing import Union, Optional from jamcodec.base import JamCodecType, JamBytes, JamCodecPrimitive, JamCodecTypeDef @@ -949,9 +950,9 @@ def __init__(self, key_def: JamCodecTypeDef, value_def: JamCodecTypeDef): self.key_def = key_def self.value_def = value_def - def encode(self, value: Union[dict, list]) -> JamBytes: + def encode(self, value: Union[typing.Mapping, list]) -> JamBytes: - if type(value) is dict: + if isinstance(value, Mapping): value = value.items() # Encode length of Vec @@ -981,8 +982,8 @@ def decode(self, data: JamBytes) -> list: def serialize(self, value: list) -> typing.List[typing.Tuple]: return [(k.value_serialized, v.value_serialized) for k, v in value] - def deserialize(self, value: Union[dict, list]) -> list: - if type(value) is dict: + def deserialize(self, value: Union[typing.Mapping, list]) -> list: + if isinstance(value, Mapping): value = value.items() result = [] diff --git a/test/test_map.py b/test/test_map.py index ef24f1c..eedb4fd 100644 --- a/test/test_map.py +++ b/test/test_map.py @@ -16,6 +16,7 @@ # import unittest +from collections.abc import Mapping from jamcodec.base import JamBytes from jamcodec.types import U32, Map, H256, Bytes @@ -48,5 +49,39 @@ def test_map_decode(self): self.assertEqual({2: b'test'}, obj.to_serializable_obj()) + def test_mapping_type(self): + + class StorageMap(Mapping): + def __init__(self, initial_data: dict = None): + if initial_data is None: + initial_data = {} + self.cache = initial_data + + def __getitem__(self, *args): + + if len(args) == 1: + args = args[0] + + return self.cache.get(args) + + def __iter__(self): + return iter(self.cache) + + def __len__(self): + return len(self.cache) + + obj = Map(U32, Bytes).new() + + data = obj.encode(StorageMap({2: b'test'})) + self.assertEqual(JamBytes('0x01020000000474657374'), data) + + obj = Map(U32, Bytes).new() + + obj.deserialize(StorageMap({2: b'test'})) + + self.assertEqual(obj.value_object[0][0], 2) + + + if __name__ == '__main__': unittest.main()