Patch summary (free-form text; everything before the first "diff --git" header is commentary).

  * pyroaring/abstract_bitmap.pxi: deserialize_ptr and deserialize64_ptr take a
    bytes object, pass len(buff) to the *_portable_deserialize_safe functions,
    raise ValueError("Could not deserialize bitmap") when NULL is returned, then
    run *_internal_validate and, on failure, free the bitmap and raise
    ValueError with the reported reason. AbstractBitMap.deserialize now takes
    bytes instead of char *.
  * pyroaring/croaring.pxd: declares roaring_bitmap_portable_deserialize_safe,
    roaring_bitmap_internal_validate and roaring64_bitmap_internal_validate.
  * setup.py: setup_requires becomes cython>=3.0.2.
  * test.py: adds test_impossible_deserialization and
    test_invalid_deserialization (plus the base64 import they use).
  * tox.ini: removes the cython_pre3 environment.

NOTE(review): the line structure below was rebuilt so that every hunk matches
the line counts in its @@ header. Indentation width and the exact position of
blank lines are reconstructed, not recorded; confirm against the target tree
or apply with a whitespace-tolerant option.

diff --git a/pyroaring/abstract_bitmap.pxi b/pyroaring/abstract_bitmap.pxi
index ab30abc..786be9a 100644
--- a/pyroaring/abstract_bitmap.pxi
+++ b/pyroaring/abstract_bitmap.pxi
@@ -12,18 +12,32 @@ try:
 except NameError: # python 3
     pass
 
-cdef croaring.roaring_bitmap_t *deserialize_ptr(char *buff):
+cdef croaring.roaring_bitmap_t *deserialize_ptr(bytes buff):
     cdef croaring.roaring_bitmap_t *ptr
-    ptr = croaring.roaring_bitmap_portable_deserialize(buff)
+    cdef const char *reason_failure = NULL
+    buff_size = len(buff)
+    ptr = croaring.roaring_bitmap_portable_deserialize_safe(buff, buff_size)
+    if ptr == NULL:
+        raise ValueError("Could not deserialize bitmap")
+    # Validate the bitmap
+    if not croaring.roaring_bitmap_internal_validate(ptr, &reason_failure):
+        # If validation fails, free the bitmap and raise an exception
+        croaring.roaring_bitmap_free(ptr)
+        raise ValueError(f"Invalid bitmap after deserialization: {reason_failure.decode('utf-8')}")
     return ptr
 
 cdef croaring.roaring64_bitmap_t *deserialize64_ptr(bytes buff):
     cdef croaring.roaring64_bitmap_t *ptr
+    cdef const char *reason_failure = NULL
     buff_size = len(buff)
-    bm_size = croaring.roaring64_bitmap_portable_deserialize_size(buff, buff_size)
-    if bm_size == 0:
-        raise ValueError("Invalid bitmap serialization")
-    ptr = croaring.roaring64_bitmap_portable_deserialize_safe(buff, bm_size)
+    ptr = croaring.roaring64_bitmap_portable_deserialize_safe(buff, buff_size)
+    if ptr == NULL:
+        raise ValueError("Could not deserialize bitmap")
+    # Validate the bitmap
+    if not croaring.roaring64_bitmap_internal_validate(ptr, &reason_failure):
+        # If validation fails, free the bitmap and raise an exception
+        croaring.roaring64_bitmap_free(ptr)
+        raise ValueError(f"Invalid bitmap after deserialization: {reason_failure.decode('utf-8')}")
     return ptr
 
 def _string_rep(bm):
@@ -744,7 +758,7 @@ cdef class AbstractBitMap:
 
 
     @classmethod
-    def deserialize(cls, char *buff):
+    def deserialize(cls, bytes buff):
         """
         Generate a bitmap from the given serialization.
         See AbstractBitMap.serialize for the reverse operation.
diff --git a/pyroaring/croaring.pxd b/pyroaring/croaring.pxd
index ed5a33b..9b94de5 100644
--- a/pyroaring/croaring.pxd
+++ b/pyroaring/croaring.pxd
@@ -100,6 +100,8 @@ cdef extern from "roaring.h":
     size_t roaring_bitmap_portable_size_in_bytes(const roaring_bitmap_t *ra)
     size_t roaring_bitmap_portable_serialize(const roaring_bitmap_t *ra, char *buf)
     roaring_bitmap_t *roaring_bitmap_portable_deserialize(const char *buf)
+    roaring_bitmap_t *roaring_bitmap_portable_deserialize_safe(const char *buf, size_t maxbytes)
+    bool roaring_bitmap_internal_validate(const roaring_bitmap_t *r, const char **reason)
     roaring_uint32_iterator_t *roaring_iterator_create(const roaring_bitmap_t *ra)
     bool roaring_uint32_iterator_advance(roaring_uint32_iterator_t *it)
     uint32_t roaring_uint32_iterator_read(roaring_uint32_iterator_t *it, uint32_t* buf, uint32_t count)
@@ -163,6 +165,7 @@ cdef extern from "roaring.h":
     size_t roaring64_bitmap_portable_serialize(const roaring64_bitmap_t *r, char *buf)
     size_t roaring64_bitmap_portable_deserialize_size(const char *buf, size_t maxbytes)
     roaring64_bitmap_t *roaring64_bitmap_portable_deserialize_safe(const char *buf, size_t maxbytes)
+    bool roaring64_bitmap_internal_validate(const roaring64_bitmap_t *r, const char **reason)
     roaring64_iterator_t *roaring64_iterator_create(const roaring64_bitmap_t *r)
     void roaring64_iterator_free(roaring64_iterator_t *it)
     bool roaring64_iterator_has_value(const roaring64_iterator_t *it)
diff --git a/setup.py b/setup.py
index 9d1ae83..68e663b 100755
--- a/setup.py
+++ b/setup.py
@@ -88,7 +88,7 @@
     version=VERSION,
     description='Library for handling efficiently sorted integer sets.',
     long_description=long_description,
-    setup_requires=['cython'],
+    setup_requires=['cython>=3.0.2'],
     url='https://github.com/Ezibenroc/PyRoaringBitMap',
     author='Tom Cornebize',
     author_email='tom.cornebize@gmail.com',
diff --git a/test.py b/test.py
index 0374831..1d297da 100755
--- a/test.py
+++ b/test.py
@@ -12,6 +12,7 @@
 import operator
 import unittest
 import functools
+import base64
 from typing import TYPE_CHECKING
 from collections.abc import Set, Callable, Iterable, Iterator
 
@@ -886,6 +887,27 @@ def test_pickle_protocol(
         assert old_bm == new_bm
         self.assert_is_not(old_bm, new_bm)
 
+    @given(bitmap_cls)
+    def test_impossible_deserialization(
+        self,
+        cls: type[EitherBitMap],
+    ) -> None:
+        wrong_input = base64.b64decode('aGVsbG8gd29ybGQ=')
+        with pytest.raises(ValueError, match='Could not deserialize bitmap'):
+            bitmap = cls.deserialize(wrong_input)
+
+    @given(bitmap_cls)
+    def test_invalid_deserialization(
+        self,
+        cls: type[EitherBitMap],
+    ) -> None:
+        wrong_input = base64.b64decode('aGVsbG8gd29ybGQ=')
+        bm = cls(list(range(0, 1000000, 3)))
+        bitmap_bytes = bm.serialize()
+        bitmap_bytes = bitmap_bytes[:42] + wrong_input + bitmap_bytes[42:]
+        with pytest.raises(ValueError, match='Invalid bitmap after deserialization'):
+            bitmap = cls.deserialize(bitmap_bytes)
+
 
 class TestStatistics(Util):
 
diff --git a/tox.ini b/tox.ini
index 58c58cf..f4def40 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,6 +1,5 @@
 [tox]
 envlist =
-    cython_pre3
     cython3
     test_wheel
 
@@ -10,19 +9,6 @@ setenv =
     PYTHONFAULTHANDLER=1
 
 
-[testenv:cython_pre3]
-deps =
-    hypothesis
-    pytest
-    cython<3.0.0
-passenv =
-    HYPOTHESIS_PROFILE
-    ROARING_BITSIZE
-commands =
-    py.test -v test.py test_state_machine.py
-    python cydoctest.py
-
-
 [testenv:cython3]
 deps =
     hypothesis