Skip to content

Commit 69c31bb

Browse files
committed
Use get_unpack_fn in RleBitPacker
1 parent d123368 commit 69c31bb

File tree

2 files changed

+93
-35
lines changed

2 files changed

+93
-35
lines changed

cpp/src/arrow/util/rle_encoding_internal.h

Lines changed: 84 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -316,13 +316,23 @@ class RleRunDecoder {
316316
};
317317

318318
/// Decoder class for single run of bit-packed encoded data.
319+
///
320+
/// The value_bit_width parameter is not part of this class, as it turned oput to be more
321+
/// performant but it must remain constant for as long as this decoder is tied to a run
322+
/// (in between calls to ``Reset``).
319323
template <typename T>
320324
class BitPackedRunDecoder {
321325
public:
322326
/// The type in which the data should be decoded.
323327
using value_type = T;
324328
/// The type of run that can be decoded.
325329
using RunType = BitPackedRun;
330+
/// The function that can extract packed integer for the ``BitReader``.
331+
using UnpackFn = ::arrow::bit_util::BitReader::UnpackFn<value_type>;
332+
333+
static UnpackFn get_unpack_fn(int value_bit_width) {
334+
return bit_util::BitReader::get_unpack_fn<value_type>(value_bit_width);
335+
}
326336

327337
BitPackedRunDecoder() noexcept = default;
328338

@@ -353,21 +363,22 @@ class BitPackedRunDecoder {
353363
}
354364

355365
/// Get the next value and return false if there are no more or an error occurred.
356-
[[nodiscard]] constexpr bool Get(value_type* out_value, rle_size_t value_bit_width) {
357-
return GetBatch(out_value, 1, value_bit_width) == 1;
366+
[[nodiscard]] constexpr bool Get(value_type* out_value, rle_size_t value_bit_width,
367+
UnpackFn unpack) {
368+
return GetBatch(out_value, 1, value_bit_width, unpack) == 1;
358369
}
359370

360371
/// Get a batch of values return the number of decoded elements.
361372
/// May write fewer elements to the output than requested if there are not enough values
362373
/// left or if an error occurred.
363374
[[nodiscard]] rle_size_t GetBatch(value_type* out, rle_size_t batch_size,
364-
rle_size_t value_bit_width) {
375+
rle_size_t value_bit_width, UnpackFn unpack) {
365376
if (ARROW_PREDICT_FALSE(remaining_count_ == 0)) {
366377
return 0;
367378
}
368379

369380
const auto to_read = std::min(remaining_count_, batch_size);
370-
const auto actual_read = bit_reader_.GetBatch(value_bit_width, out, to_read);
381+
const auto actual_read = bit_reader_.GetBatch(value_bit_width, out, to_read, unpack);
371382
// There should not be any reason why the actual read would be different
372383
// but this is error resistant.
373384
remaining_count_ -= actual_read;
@@ -408,6 +419,7 @@ class RleBitPackedDecoder {
408419
parser_.Reset(data, data_size, value_bit_width);
409420
decoder_ = {};
410421
value_bit_width_ = value_bit_width;
422+
unpack_ = BitPackedRunDecoder<value_type>::get_unpack_fn(value_bit_width);
411423
}
412424

413425
/// Whether there is still runs to iterate over.
@@ -463,6 +475,17 @@ class RleBitPackedDecoder {
463475
RleBitPackedParser parser_ = {};
464476
std::variant<RleRunDecoder<value_type>, BitPackedRunDecoder<value_type>> decoder_ = {};
465477
rle_size_t value_bit_width_;
478+
typename BitPackedRunDecoder<value_type>::UnpackFn unpack_;
479+
480+
template <typename Decoder>
481+
rle_size_t DecoderGetBatch(Decoder* decoder, value_type* out, rle_size_t batch_size) {
482+
if constexpr (std::is_same_v<std::decay_t<Decoder>,
483+
BitPackedRunDecoder<value_type>>) {
484+
return decoder->GetBatch(out, batch_size, value_bit_width_, unpack_);
485+
} else {
486+
return decoder->GetBatch(out, batch_size, value_bit_width_);
487+
}
488+
}
466489

467490
/// Return the number of values that are remaining in the current run.
468491
rle_size_t run_remaining() const {
@@ -471,9 +494,8 @@ class RleBitPackedDecoder {
471494

472495
/// Get a batch of values from the current run and return the number elements read.
473496
[[nodiscard]] rle_size_t RunGetBatch(value_type* out, rle_size_t batch_size) {
474-
return std::visit(
475-
[&](auto& dec) { return dec.GetBatch(out, batch_size, value_bit_width_); },
476-
decoder_);
497+
return std::visit([&](auto& dec) { return DecoderGetBatch(&dec, out, batch_size); },
498+
decoder_);
477499
}
478500

479501
/// Call the parser with a single callable for all event types.
@@ -746,7 +768,7 @@ auto RleBitPackedDecoder<T>::GetBatch(value_type* out, rle_size_t batch_size)
746768

747769
ARROW_DCHECK_LT(values_read, batch_size);
748770
RunDecoder decoder(run, value_bit_width_);
749-
const auto read = decoder.GetBatch(out, batch_size - values_read, value_bit_width_);
771+
const auto read = DecoderGetBatch(&decoder, out, batch_size - values_read);
750772
ARROW_DCHECK_LE(read, batch_size - values_read);
751773
values_read += read;
752774
out += read;
@@ -833,7 +855,7 @@ struct GetSpacedResult {
833855

834856
/// Overload for GetSpaced for a single run in a RleDecoder
835857
template <typename Converter, typename BitRunReader, typename BitRun, typename value_type>
836-
auto RunGetSpaced(Converter* converter, typename Converter::out_type* out,
858+
auto RleGetSpaced(Converter* converter, typename Converter::out_type* out,
837859
rle_size_t batch_size, rle_size_t null_count,
838860
rle_size_t value_bit_width, BitRunReader* validity_reader,
839861
BitRun* validity_run, RleRunDecoder<value_type>* decoder)
@@ -894,10 +916,12 @@ auto RunGetSpaced(Converter* converter, typename Converter::out_type* out,
894916
}
895917

896918
template <typename Converter, typename BitRunReader, typename BitRun, typename value_type>
897-
auto RunGetSpaced(Converter* converter, typename Converter::out_type* out,
898-
rle_size_t batch_size, rle_size_t null_count,
899-
rle_size_t value_bit_width, BitRunReader* validity_reader,
900-
BitRun* validity_run, BitPackedRunDecoder<value_type>* decoder)
919+
auto BitPackedGetSpaced(Converter* converter, typename Converter::out_type* out,
920+
rle_size_t batch_size, rle_size_t null_count,
921+
rle_size_t value_bit_width,
922+
typename BitPackedRunDecoder<value_type>::UnpackFn unpack,
923+
BitRunReader* validity_reader, BitRun* validity_run,
924+
BitPackedRunDecoder<value_type>* decoder)
901925
-> GetSpacedResult<rle_size_t> {
902926
ARROW_DCHECK_GT(batch_size, 0);
903927
// The equality case is handled in the main loop in GetSpaced
@@ -932,7 +956,7 @@ auto RunGetSpaced(Converter* converter, typename Converter::out_type* out,
932956
// buffer_start is 0 at this point so size is end
933957
buffer_end = std::min(std::min(run_values_remaining(), batch.values_remaining()),
934958
kBufferCapacity);
935-
buffer_end = decoder->GetBatch(buffer.data(), buffer_size(), value_bit_width);
959+
buffer_end = decoder->GetBatch(buffer.data(), buffer_size(), value_bit_width, unpack);
936960
ARROW_DCHECK_LE(buffer_size(), kBufferCapacity);
937961

938962
if (ARROW_PREDICT_FALSE(!converter->InputIsValid(buffer.data(), buffer_size()))) {
@@ -980,19 +1004,47 @@ auto RunGetSpaced(Converter* converter, typename Converter::out_type* out,
9801004
return {/* .values_read= */ batch.values_read(), /* .null_read= */ batch.null_read()};
9811005
}
9821006

1007+
/// Overload for GetSpaced that dispatch the Decoder type
1008+
template <typename Decoder, typename Converter, typename BitRunReader, typename BitRun,
1009+
typename UnpackFunc>
1010+
auto RunGetSpaced(Converter* converter, typename Converter::out_type* out,
1011+
rle_size_t batch_size, rle_size_t null_count,
1012+
rle_size_t value_bit_width, UnpackFunc unpack,
1013+
BitRunReader* validity_reader, BitRun* validity_run, Decoder* decoder)
1014+
-> GetSpacedResult<rle_size_t> {
1015+
using value_type = typename Decoder::value_type;
1016+
1017+
if constexpr (std::is_same_v<Decoder, BitPackedRunDecoder<value_type>>) {
1018+
return BitPackedGetSpaced(converter, out, batch_size, null_count, value_bit_width,
1019+
unpack, validity_reader, validity_run, decoder);
1020+
} else {
1021+
return RleGetSpaced(converter, out, batch_size, null_count, value_bit_width,
1022+
validity_reader, validity_run, decoder);
1023+
}
1024+
}
1025+
9831026
/// Overload for GetSpaced for a single run in a decoder variant
9841027
template <typename Converter, typename BitRunReader, typename BitRun, typename value_type>
9851028
auto RunGetSpaced(
9861029
Converter* converter, typename Converter::out_type* out, rle_size_t batch_size,
987-
rle_size_t null_count, rle_size_t value_bit_width, BitRunReader* validity_reader,
988-
BitRun* validity_run,
1030+
rle_size_t null_count, rle_size_t value_bit_width,
1031+
typename BitPackedRunDecoder<value_type>::UnpackFn unpack,
1032+
BitRunReader* validity_reader, BitRun* validity_run,
9891033
std::variant<RleRunDecoder<value_type>, BitPackedRunDecoder<value_type>>* decoder)
9901034
-> GetSpacedResult<rle_size_t> {
9911035
return std::visit(
9921036
[&](auto& dec) {
9931037
ARROW_DCHECK_GT(dec.remaining(), 0);
994-
return RunGetSpaced(converter, out, batch_size, null_count, value_bit_width,
995-
validity_reader, validity_run, &dec);
1038+
if constexpr (std::is_same_v<std::decay_t<decltype(dec)>,
1039+
BitPackedRunDecoder<value_type>>) {
1040+
return BitPackedGetSpaced(converter, out, batch_size, null_count,
1041+
value_bit_width, unpack, validity_reader,
1042+
validity_run, &dec);
1043+
1044+
} else {
1045+
return RleGetSpaced(converter, out, batch_size, null_count, value_bit_width,
1046+
validity_reader, validity_run, &dec);
1047+
}
9961048
},
9971049
*decoder);
9981050
}
@@ -1035,9 +1087,9 @@ auto RleBitPackedDecoder<T>::GetSpaced(Converter converter,
10351087

10361088
// Remaining from a previous call that would have left some unread data from a run.
10371089
if (ARROW_PREDICT_FALSE(run_remaining() > 0)) {
1038-
const auto read = internal::RunGetSpaced(&converter, out, batch.total_remaining(),
1039-
batch.null_remaining(), value_bit_width_,
1040-
&validity_reader, &validity_run, &decoder_);
1090+
const auto read = internal::RunGetSpaced(
1091+
&converter, out, batch.total_remaining(), batch.null_remaining(),
1092+
value_bit_width_, unpack_, &validity_reader, &validity_run, &decoder_);
10411093

10421094
batch.AccrueReadNulls(read.null_read);
10431095
batch.AccrueReadValues(read.values_read);
@@ -1059,9 +1111,9 @@ auto RleBitPackedDecoder<T>::GetSpaced(Converter converter,
10591111

10601112
RunDecoder decoder(run, value_bit_width_);
10611113

1062-
const auto read = internal::RunGetSpaced(&converter, out, batch.total_remaining(),
1063-
batch.null_remaining(), value_bit_width_,
1064-
&validity_reader, &validity_run, &decoder);
1114+
const auto read = internal::RunGetSpaced(
1115+
&converter, out, batch.total_remaining(), batch.null_remaining(),
1116+
value_bit_width_, unpack_, &validity_reader, &validity_run, &decoder);
10651117

10661118
batch.AccrueReadNulls(read.null_read);
10671119
batch.AccrueReadValues(read.values_read);
@@ -1218,9 +1270,10 @@ auto RleBitPackedDecoder<T>::GetBatchWithDict(const V* dictionary,
12181270
};
12191271

12201272
if (ARROW_PREDICT_FALSE(run_remaining() > 0)) {
1221-
const auto read = internal::RunGetSpaced(&converter, out, batch_size,
1222-
/* null_count= */ 0, value_bit_width_,
1223-
&validity_reader, &validity_run, &decoder_);
1273+
const auto read =
1274+
internal::RunGetSpaced(&converter, out, batch_size,
1275+
/* null_count= */ 0, value_bit_width_, unpack_,
1276+
&validity_reader, &validity_run, &decoder_);
12241277

12251278
ARROW_DCHECK_EQ(read.null_read, 0);
12261279
values_read += read.values_read;
@@ -1241,9 +1294,10 @@ auto RleBitPackedDecoder<T>::GetBatchWithDict(const V* dictionary,
12411294

12421295
RunDecoder decoder(run, value_bit_width_);
12431296

1244-
const auto read = internal::RunGetSpaced(&converter, out, batch_values_remaining(),
1245-
/* null_count= */ 0, value_bit_width_,
1246-
&validity_reader, &validity_run, &decoder);
1297+
const auto read =
1298+
internal::RunGetSpaced(&converter, out, batch_values_remaining(),
1299+
/* null_count= */ 0, value_bit_width_, unpack_,
1300+
&validity_reader, &validity_run, &decoder);
12471301

12481302
ARROW_DCHECK_EQ(read.null_read, 0);
12491303
values_read += read.values_read;

cpp/src/arrow/util/rle_encoding_test.cc

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,8 @@ TEST(Rle, RleDecoder) {
334334
template <typename T>
335335
void TestBitPackedDecoder(std::vector<uint8_t> bytes, rle_size_t value_count,
336336
rle_size_t bit_width, std::vector<T> expected) {
337+
auto unpack = BitPackedRunDecoder<T>::get_unpack_fn(bit_width);
338+
337339
// Pre-requisite for this test
338340
EXPECT_GT(value_count, 6);
339341

@@ -345,7 +347,7 @@ void TestBitPackedDecoder(std::vector<uint8_t> bytes, rle_size_t value_count,
345347
EXPECT_EQ(decoder.remaining(), value_count);
346348

347349
rle_size_t read = 0;
348-
EXPECT_EQ(decoder.Get(vals.data(), bit_width), 1);
350+
EXPECT_EQ(decoder.Get(vals.data(), bit_width, unpack), 1);
349351
EXPECT_EQ(vals.at(0), expected.at(0 + read));
350352
read += 1;
351353
EXPECT_EQ(decoder.remaining(), value_count - read);
@@ -355,7 +357,7 @@ void TestBitPackedDecoder(std::vector<uint8_t> bytes, rle_size_t value_count,
355357
EXPECT_EQ(decoder.remaining(), value_count - read);
356358

357359
vals = {0, 0};
358-
EXPECT_EQ(decoder.GetBatch(vals.data(), 2, bit_width), vals.size());
360+
EXPECT_EQ(decoder.GetBatch(vals.data(), 2, bit_width, unpack), vals.size());
359361
EXPECT_EQ(vals.at(0), expected.at(0 + read));
360362
EXPECT_EQ(vals.at(1), expected.at(1 + read));
361363
read += static_cast<decltype(read)>(vals.size());
@@ -366,15 +368,15 @@ void TestBitPackedDecoder(std::vector<uint8_t> bytes, rle_size_t value_count,
366368
EXPECT_EQ(decoder.remaining(), 0);
367369
EXPECT_EQ(decoder.Advance(1, bit_width), 0);
368370
vals = {0, 0};
369-
EXPECT_EQ(decoder.Get(vals.data(), bit_width), 0);
371+
EXPECT_EQ(decoder.Get(vals.data(), bit_width, unpack), 0);
370372
EXPECT_EQ(vals.at(0), 0);
371373

372374
// Reset the decoder
373375
decoder.Reset(run, bit_width);
374376
read = 0;
375377
EXPECT_EQ(decoder.remaining(), value_count);
376378
vals = {0, 0};
377-
EXPECT_EQ(decoder.GetBatch(vals.data(), 2, bit_width), vals.size());
379+
EXPECT_EQ(decoder.GetBatch(vals.data(), 2, bit_width, unpack), vals.size());
378380
EXPECT_EQ(vals.at(0), expected.at(0 + read));
379381
EXPECT_EQ(vals.at(1), expected.at(1 + read));
380382
}
@@ -436,13 +438,15 @@ void TestRleBitPackedParser(std::vector<uint8_t> bytes, rle_size_t bit_width,
436438
}
437439

438440
auto OnBitPackedRun(BitPackedRun run) {
441+
auto unpack = BitPackedRunDecoder<T>::get_unpack_fn(bit_width_);
442+
439443
bit_packed_decoder_ptr_->Reset(run, bit_width_);
440444

441445
const auto n_decoded = decoded_ptr_->size();
442446
const auto n_to_decode = bit_packed_decoder_ptr_->remaining();
443447
decoded_ptr_->resize(n_decoded + n_to_decode);
444448
EXPECT_EQ(bit_packed_decoder_ptr_->GetBatch(decoded_ptr_->data() + n_decoded,
445-
n_to_decode, bit_width_),
449+
n_to_decode, bit_width_, unpack),
446450
n_to_decode);
447451
EXPECT_EQ(bit_packed_decoder_ptr_->remaining(), 0);
448452

0 commit comments

Comments
 (0)