Skip to content

Commit d123368

Browse files
committed
Refactor BitReader to accomodate unpack fn
1 parent 615dd80 commit d123368

File tree

1 file changed

+43
-6
lines changed

1 file changed

+43
-6
lines changed

cpp/src/arrow/util/bit_stream_utils_internal.h

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,34 @@ inline uint64_t ReadLittleEndianWord(const uint8_t* buffer, int bytes_remaining)
123123
/// bytes in one read (e.g. encoded int).
124124
class BitReader {
125125
public:
126+
template <typename T, typename = void>
127+
struct UnpackFnDetect;
128+
129+
template <typename T>
130+
struct UnpackFnDetect<T, std::enable_if_t<(sizeof(T) >= sizeof(int))>> {
131+
using type = internal::UnpackFn<std::make_unsigned_t<T>>;
132+
};
133+
134+
template <typename T>
135+
struct UnpackFnDetect<T, std::enable_if_t<(sizeof(T) < sizeof(int))>> {
136+
using type = internal::UnpackFn<uint32_t>;
137+
};
138+
139+
/// The type for a function that can extract bit-packed integers.
140+
template <typename T>
141+
using UnpackFn = typename UnpackFnDetect<T>::type;
142+
143+
/// Get the unack function most appropriated for this type and bit width.
144+
template <typename T>
145+
static UnpackFn<T> get_unpack_fn(int num_bits) {
146+
// This is intimately linked to the GetBatch implementation
147+
if constexpr (sizeof(T) >= sizeof(int)) {
148+
return internal::get_unpack_fn<std::make_unsigned_t<T>>(num_bits);
149+
} else {
150+
return internal::get_unpack_fn<uint32_t>(num_bits);
151+
}
152+
}
153+
126154
BitReader() noexcept = default;
127155

128156
/// 'buffer' is the buffer to read from. The buffer's length is 'buffer_len'.
@@ -148,6 +176,11 @@ class BitReader {
148176
template <typename T>
149177
int GetBatch(int num_bits, T* v, int batch_size);
150178

179+
/// Get a number of values from the buffer. Return the number of values actually read.
180+
/// @param unpack Function pointer to the unpack function for the correct bit width.
181+
template <typename T>
182+
int GetBatch(int num_bits, T* v, int batch_size, UnpackFn<T> unpack);
183+
151184
/// Reads a 'num_bytes'-sized value from the buffer and stores it in 'v'. T
152185
/// needs to be a little-endian native type and big enough to store
153186
/// 'num_bytes'. The value is assumed to be byte-aligned so the stream will
@@ -297,7 +330,7 @@ inline bool BitReader::GetValue(int num_bits, T* v) {
297330
}
298331

299332
template <typename T>
300-
inline int BitReader::GetBatch(int num_bits, T* v, int batch_size) {
333+
int BitReader::GetBatch(int num_bits, T* v, int batch_size, UnpackFn<T> unpack) {
301334
ARROW_DCHECK(buffer_ != NULL);
302335
ARROW_DCHECK_LE(num_bits, static_cast<int>(sizeof(T) * 8)) << "num_bits: " << num_bits;
303336

@@ -325,9 +358,9 @@ inline int BitReader::GetBatch(int num_bits, T* v, int batch_size) {
325358

326359
// unpack for uint16_t not as fast as unpack for uint32_t + memcpy.
327360
if constexpr (sizeof(T) >= sizeof(32)) {
328-
int num_unpacked = internal::unpack(buffer + byte_offset,
329-
reinterpret_cast<std::make_unsigned_t<T>*>(v + i),
330-
batch_size - i, num_bits);
361+
int num_unpacked =
362+
unpack(buffer + byte_offset, reinterpret_cast<std::make_unsigned_t<T>*>(v + i),
363+
batch_size - i);
331364
i += num_unpacked;
332365
byte_offset += num_unpacked * num_bits / 8;
333366
} else {
@@ -337,8 +370,7 @@ inline int BitReader::GetBatch(int num_bits, T* v, int batch_size) {
337370
uint32_t unpack_buffer[buffer_size];
338371
while (i < batch_size) {
339372
int unpack_size = std::min(buffer_size, batch_size - i);
340-
int num_unpacked =
341-
internal::unpack(buffer + byte_offset, unpack_buffer, unpack_size, num_bits);
373+
int num_unpacked = unpack(buffer + byte_offset, unpack_buffer, unpack_size);
342374
if (num_unpacked == 0) {
343375
break;
344376
}
@@ -372,6 +404,11 @@ inline int BitReader::GetBatch(int num_bits, T* v, int batch_size) {
372404
return batch_size;
373405
}
374406

407+
template <typename T>
408+
inline int BitReader::GetBatch(int num_bits, T* v, int batch_size) {
409+
return GetBatch(num_bits, v, batch_size, get_unpack_fn<T>(num_bits));
410+
}
411+
375412
template <typename T>
376413
inline bool BitReader::GetAligned(int num_bytes, T* v) {
377414
if (ARROW_PREDICT_FALSE(num_bytes > static_cast<int>(sizeof(T)))) {

0 commit comments

Comments
 (0)