Skip to content

Commit 2f2450a

Browse files
authored
feat(python): Ensure that buffer produced by CBufferView.unpack_bits() has a boolean type (#457)
This is a small change to ensure that `np.array(some_buffer.unpack_bits())` "just works" without nanoarrow having to know about numpy dtypes. Basically we just need to ensure that we can create/export a buffer with a `"?"` format string.

```python
import nanoarrow as na
import numpy as np

bool_array = na.Array([True, True, True, False, False, True], na.bool_())
np.array(bool_array.buffer(1).unpack_bits(0, len(bool_array)))
#> array([ True, True, True, False, False, True])
```
1 parent f47e830 commit 2f2450a

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

python/src/nanoarrow/_lib.pyx

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ cdef c_arrow_type_from_format(format):
335335
return item_size, NANOARROW_TYPE_DOUBLE
336336

337337
# Check for signed integers
338-
if format in ("b", "?", "h", "i", "l", "q", "n"):
338+
if format in ("b", "h", "i", "l", "q", "n"):
339339
if item_size == 1:
340340
return item_size, NANOARROW_TYPE_INT8
341341
elif item_size == 2:
@@ -346,7 +346,7 @@ cdef c_arrow_type_from_format(format):
346346
return item_size, NANOARROW_TYPE_INT64
347347

348348
# Check for unsigned integers
349-
if format in ("B", "H", "I", "L", "Q", "N"):
349+
if format in ("B", "?", "H", "I", "L", "Q", "N"):
350350
if item_size == 1:
351351
return item_size, NANOARROW_TYPE_UINT8
352352
elif item_size == 2:
@@ -1988,7 +1988,7 @@ cdef class CBufferView:
19881988
if length is None:
19891989
length = self.n_elements
19901990

1991-
out = CBufferBuilder().set_data_type(NANOARROW_TYPE_UINT8)
1991+
out = CBufferBuilder().set_format("?")
19921992
out.reserve_bytes(length)
19931993
self.unpack_bits_into(out, offset, length)
19941994
out.advance(length)
@@ -2108,6 +2108,8 @@ cdef class CBuffer:
21082108
self._device
21092109
)
21102110

2111+
snprintf(self._view._format, sizeof(self._view._format), "%s", self._format)
2112+
21112113
@staticmethod
21122114
def empty():
21132115
cdef CBuffer out = CBuffer()
@@ -2272,6 +2274,13 @@ cdef class CBufferBuilder:
22722274
self._buffer._set_data_type(type_id, element_size_bits)
22732275
return self
22742276

2277+
def set_format(self, str format):
2278+
"""Set the Python buffer format used to interpret elements in
2279+
:meth:`write_elements`.
2280+
"""
2281+
self._buffer._set_format(format)
2282+
return self
2283+
22752284
@property
22762285
def format(self):
22772286
"""The ``struct`` format code of the underlying buffer"""

python/tests/test_c_buffer_view.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def test_buffer_view_bool_unpack():
7272
unpacked_all = view.unpack_bits()
7373
assert len(unpacked_all) == view.n_elements
7474
assert unpacked_all.data_type == "uint8"
75+
assert unpacked_all.format == "?"
7576
assert list(unpacked_all) == [1, 0, 0, 1, 0, 0, 0, 0]
7677

7778
unpacked_some = view.unpack_bits(1, 4)

0 commit comments

Comments
 (0)