@@ -123,6 +123,34 @@ inline uint64_t ReadLittleEndianWord(const uint8_t* buffer, int bytes_remaining)
123123// / bytes in one read (e.g. encoded int).
124124class BitReader {
125125 public:
126+ template <typename T, typename = void >  // Maps T to its unpack function type.
127+ struct UnpackFnDetect ;  // Primary template is only declared; the specializations below define `type`.
128+
129+ template <typename T>
130+ struct UnpackFnDetect <T, std::enable_if_t <(sizeof (T) >= sizeof (int ))>> {  // T is at least as wide as int.
131+ using type = internal::UnpackFn<std::make_unsigned_t <T>>;  // Unpack function for the unsigned counterpart of T.
132+ };
133+
134+ template <typename T>
135+ struct UnpackFnDetect <T, std::enable_if_t <(sizeof (T) < sizeof (int ))>> {  // T is narrower than int.
136+ using type = internal::UnpackFn<uint32_t >;  // Narrow types use the uint32_t unpack function.
137+ };
138+
139+ // / The type for a function that can extract bit-packed integers for values of type T.
140+ template <typename T>  // Resolved via UnpackFnDetect, which switches on sizeof(T) versus sizeof(int).
141+ using UnpackFn = typename UnpackFnDetect<T>::type;
142+
143+ // / Get the unpack function most appropriate for this type and bit width.
144+ template <typename T>
145+ static UnpackFn<T> get_unpack_fn (int num_bits) {
146+ // This is intimately linked to the GetBatch implementation: the branch taken here must match the buffer type GetBatch hands to the unpack function.
147+ if constexpr (sizeof (T) >= sizeof (int )) {  // Same size threshold as UnpackFnDetect.
148+ return internal::get_unpack_fn<std::make_unsigned_t <T>>(num_bits);  // Unsigned counterpart of T.
149+ } else {
150+ return internal::get_unpack_fn<uint32_t >(num_bits);  // Types narrower than int use the uint32_t function.
151+ }
152+ }
153+
126154 BitReader () noexcept = default;
127155
128156 // / 'buffer' is the buffer to read from. The buffer's length is 'buffer_len'.
@@ -148,6 +176,11 @@ class BitReader {
148176 template <typename T>
149177 int GetBatch (int num_bits, T* v, int batch_size);  // Uses get_unpack_fn<T>(num_bits) as the unpack function.
150178
179+ // / Get a number of values from the buffer. Return the number of values actually read.
180+ // / @param unpack Function pointer to the unpack function for the correct bit width, e.g. the result of get_unpack_fn<T>(num_bits).
181+ template <typename T>
182+ int GetBatch (int num_bits, T* v, int batch_size, UnpackFn<T> unpack);
183+
151184 // / Reads a 'num_bytes'-sized value from the buffer and stores it in 'v'. T
152185 // / needs to be a little-endian native type and big enough to store
153186 // / 'num_bytes'. The value is assumed to be byte-aligned so the stream will
@@ -297,7 +330,7 @@ inline bool BitReader::GetValue(int num_bits, T* v) {
297330}
298331
299332template <typename T>
300- inline int BitReader::GetBatch (int num_bits, T* v, int batch_size) {
333+ int BitReader::GetBatch (int num_bits, T* v, int batch_size, UnpackFn<T> unpack ) {
301334 ARROW_DCHECK (buffer_ != NULL );
302335 ARROW_DCHECK_LE (num_bits, static_cast <int >(sizeof (T) * 8 )) << " num_bits: " << num_bits;
303336
@@ -325,9 +358,9 @@ inline int BitReader::GetBatch(int num_bits, T* v, int batch_size) {
325358
326359 // unpack for uint16_t not as fast as unpack for uint32_t + memcpy.
327360 if constexpr (sizeof (T) >= sizeof (32 )) {
328- int num_unpacked = internal::unpack (buffer + byte_offset,
329- reinterpret_cast <std::make_unsigned_t <T>*>(v + i),
330- batch_size - i, num_bits );
361+ int num_unpacked =
362+ unpack (buffer + byte_offset, reinterpret_cast <std::make_unsigned_t <T>*>(v + i),
363+ batch_size - i);
331364 i += num_unpacked;
332365 byte_offset += num_unpacked * num_bits / 8 ;
333366 } else {
@@ -337,8 +370,7 @@ inline int BitReader::GetBatch(int num_bits, T* v, int batch_size) {
337370 uint32_t unpack_buffer[buffer_size];
338371 while (i < batch_size) {
339372 int unpack_size = std::min (buffer_size, batch_size - i);
340- int num_unpacked =
341- internal::unpack (buffer + byte_offset, unpack_buffer, unpack_size, num_bits);
373+ int num_unpacked = unpack (buffer + byte_offset, unpack_buffer, unpack_size);
342374 if (num_unpacked == 0 ) {
343375 break ;
344376 }
@@ -372,6 +404,11 @@ inline int BitReader::GetBatch(int num_bits, T* v, int batch_size) {
372404 return batch_size;
373405}
374406
407+ template <typename T>  // Overload without an explicit unpack function.
408+ inline int BitReader::GetBatch (int num_bits, T* v, int batch_size) {
409+ return GetBatch (num_bits, v, batch_size, get_unpack_fn<T>(num_bits));  // Look up the unpack function for num_bits, then forward to the four-argument overload.
410+ }
411+
375412template <typename T>
376413inline bool BitReader::GetAligned (int num_bytes, T* v) {
377414 if (ARROW_PREDICT_FALSE (num_bytes > static_cast <int >(sizeof (T)))) {
0 commit comments