diff --git a/src/meta_memcache/serializer.py b/src/meta_memcache/serializer.py index 816ad7c..2a7f37d 100644 --- a/src/meta_memcache/serializer.py +++ b/src/meta_memcache/serializer.py @@ -72,7 +72,6 @@ class ZstdSerializer(BaseSerializer): BINARY = 16 ZSTD_COMPRESSED = 32 - ZSTD_MAGIC = b"(\xb5/\xfd" DEFAULT_PICKLE_PROTOCOL = 5 DEFAULT_COMPRESSION_LEVEL = 9 DEFAULT_COMPRESSION_THRESHOLD = 128 @@ -141,7 +140,7 @@ def __init__( else: self._default_zstd_compressor = None - self._zstd_decompressors[0] = zstd.ZstdDecompressor() + self._zstd_decompressors[0] = zstd.ZstdDecompressor(format=zstd.FORMAT_ZSTD1_MAGICLESS) def _build_dict(self, dictionary: bytes) -> Tuple[int, zstd.ZstdCompressionDict]: zstd_dict = zstd.ZstdCompressionDict(dictionary) @@ -151,7 +150,7 @@ def _build_dict(self, dictionary: bytes) -> Tuple[int, zstd.ZstdCompressionDict] def _add_dict_decompressor( self, dict_id: int, zstd_dict: zstd.ZstdCompressionDict ) -> zstd.ZstdDecompressor: - self._zstd_decompressors[dict_id] = zstd.ZstdDecompressor(dict_data=zstd_dict) + self._zstd_decompressors[dict_id] = zstd.ZstdDecompressor(dict_data=zstd_dict, format=zstd.FORMAT_ZSTD1_MAGICLESS) return self._zstd_decompressors[dict_id] def _add_dict_compressor( @@ -174,11 +173,10 @@ def _compress(self, key: Key, data: bytes) -> Tuple[bytes, int]: return zlib.compress(data), self.ZLIB_COMPRESSED def _decompress(self, data: bytes) -> bytes: - data = self.ZSTD_MAGIC + data - dict_id = zstd.get_frame_parameters(data).dict_id - if decompressor := self._zstd_decompressors.get(dict_id): + params = zstd.get_frame_parameters(data, format=zstd.FORMAT_ZSTD1_MAGICLESS) + if decompressor := self._zstd_decompressors.get(params.dict_id): return decompressor.decompress(data) - raise ValueError(f"Unknown dictionary id: {dict_id}") + raise ValueError(f"Unknown dictionary id: {params.dict_id}") def _should_compress(self, key: Key, data: bytes) -> bool: data_len = len(data) diff --git a/tests/serializer_test.py b/tests/serializer_test.py index e9957cd..f31ac6d 100644 --- a/tests/serializer_test.py +++ b/tests/serializer_test.py @@ -34,7 +34,7 @@ def get_compression_dict(compressor: zstd.ZstdCompressor) -> int: def get_data_compression_dict(data: bytes) -> int: - return zstd.get_frame_parameters(ZstdSerializer.ZSTD_MAGIC + data).dict_id + return zstd.get_frame_parameters(data, format=zstd.FORMAT_ZSTD1_MAGICLESS).dict_id def test_zstd_serializer_initialization(