Skip to content

Commit fcb3c55

Browse files
committed
Add unit tests
1 parent b362c3a commit fcb3c55

File tree

2 files changed

+129
-6
lines changed

2 files changed

+129
-6
lines changed

cpp/src/arrow/util/rle_encoding_internal.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -661,11 +661,6 @@ auto RleBitPackedParser::PeekImpl(Handler&& handler) const
661661
uint32_t run_len_type = 0;
662662
const auto header_bytes = bit_util::ParseLeadingLEB128(data_, kMaxSize, &run_len_type);
663663

664-
if (ARROW_PREDICT_FALSE(header_bytes == 0)) {
665-
// Malformed LEB128 data
666-
return {0, ControlFlow::Break};
667-
}
668-
669664
const bool is_bit_packed = run_len_type & 1;
670665
const uint32_t count = run_len_type >> 1;
671666
if (is_bit_packed) {
@@ -691,7 +686,9 @@ auto RleBitPackedParser::PeekImpl(Handler&& handler) const
691686
bytes_read = data_size_;
692687
values_count =
693688
static_cast<rle_size_t>((bytes_read - header_bytes) * 8 / value_bit_width_);
694-
if (values_count < 1) {
689+
// Only allow errors where the bit-packed run is not padded to a multiple
690+
// of 8 values. Larger truncation should not occur.
691+
if (values_count <= static_cast<rle_size_t>((count - 1) * 8)) {
695692
return {0, ControlFlow::Break};
696693
}
697694
}

cpp/src/arrow/util/rle_encoding_test.cc

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "arrow/util/bit_util.h"
3636
#include "arrow/util/io_util.h"
3737
#include "arrow/util/rle_encoding_internal.h"
38+
#include "arrow/util/span.h"
3839

3940
namespace arrow::util {
4041

@@ -458,6 +459,29 @@ void TestRleBitPackedParser(std::vector<uint8_t> bytes, rle_size_t bit_width,
458459
EXPECT_EQ(decoded, expected);
459460
}
460461

462+
void TestRleBitPackedParserError(span<const uint8_t> bytes, rle_size_t bit_width) {
463+
auto parser =
464+
RleBitPackedParser(bytes.data(), static_cast<rle_size_t>(bytes.size()), bit_width);
465+
EXPECT_FALSE(parser.exhausted());
466+
467+
struct {
468+
auto OnRleRun(RleRun run) { return RleBitPackedParser::ControlFlow::Continue; }
469+
auto OnBitPackedRun(BitPackedRun run) {
470+
return RleBitPackedParser::ControlFlow::Continue;
471+
}
472+
} handler;
473+
474+
// Iterate over all runs
475+
parser.Parse(handler);
476+
// Non-exhaustion despite ControlFlow::Continue signals an error occurred.
477+
EXPECT_FALSE(parser.exhausted());
478+
}
479+
480+
void TestRleBitPackedParserError(const std::vector<uint8_t>& bytes,
481+
rle_size_t bit_width) {
482+
TestRleBitPackedParserError(span(bytes), bit_width);
483+
}
484+
461485
TEST(RleBitPacked, RleBitPackedParser) {
462486
TestRleBitPackedParser<uint16_t>(
463487
/* bytes= */
@@ -500,6 +524,108 @@ TEST(RleBitPacked, RleBitPackedParser) {
500524
}
501525
}
502526

527+
TEST(RleBitPacked, RleBitPackedParserInvalidNonPadded) {
528+
// GH-47981: a non-padded trailing bit-packed, produced by some non-compliant
529+
// encoders, should still be decoded successfully.
530+
531+
TestRleBitPackedParser<uint16_t>(
532+
/* bytes= */
533+
{/* LEB128 for 8 values bit packed marker */ 0x3,
534+
/* Bitpacked run */ 0x88, 0xc6},
535+
/* bit_width= */ 3,
536+
/* expected= */ {0, 1, 2, 3, 4});
537+
TestRleBitPackedParser<uint16_t>(
538+
/* bytes= */
539+
{/* LEB128 for 8 values bit packed marker */ 0x3,
540+
/* Bitpacked run */ 0x88},
541+
/* bit_width= */ 3,
542+
/* expected= */ {0, 1});
543+
TestRleBitPackedParser<uint16_t>(
544+
/* bytes= */
545+
{/* LEB128 for 8 values bit packed marker */ 0x3,
546+
/* Bitpacked run */ 0x1, 0x2, 0x3},
547+
/* bit_width= */ 8,
548+
/* expected= */ {1, 2, 3});
549+
TestRleBitPackedParser<uint16_t>(
550+
/* bytes= */
551+
{/* LEB128 for 8 values bit packed marker */ 0x3,
552+
/* Bitpacked run */ 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7},
553+
/* bit_width= */ 8,
554+
/* expected= */ {1, 2, 3, 4, 5, 6, 7});
555+
TestRleBitPackedParser<uint16_t>(
556+
/* bytes= */
557+
{/* LEB128 for 16 values bit packed marker */ 0x5,
558+
/* Bitpacked run */ 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9},
559+
/* bit_width= */ 8,
560+
/* expected= */ {1, 2, 3, 4, 5, 6, 7, 8, 9});
561+
562+
// If the trailing bit-packed declares more values than padding allows,
563+
// it's an error.
564+
565+
// 2 values encoded, 16 values declared (8 would be ok)
566+
TestRleBitPackedParserError(
567+
/* bytes= */
568+
{/* LEB128 for 16 values bit packed marker */ 0x5,
569+
/* Bitpacked run */ 0x88},
570+
/* bit_width= */ 3);
571+
// 8 values encoded, 16 values declared (8 would be ok)
572+
TestRleBitPackedParserError(
573+
/* bytes= */
574+
{/* LEB128 for 16 values bit packed marker */ 0x5,
575+
/* Bitpacked run */ 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8},
576+
/* bit_width= */ 8);
577+
578+
// If the trailing bit-packed run does not have room for at least 1 value,
579+
// it's an error.
580+
581+
TestRleBitPackedParserError(
582+
/* bytes= */
583+
{/* LEB128 for 8 values bit packed marker */ 0x3},
584+
/* bit_width= */ 3);
585+
TestRleBitPackedParserError(
586+
/* bytes= */
587+
{/* LEB128 for 8 values bit packed marker */ 0x3,
588+
/* Bitpacked run */ 0x1},
589+
/* bit_width= */ 9);
590+
}
591+
592+
TEST(RleBitPacked, RleBitPackedParserErrors) {
593+
// Truncated LEB128 header
594+
TestRleBitPackedParserError(
595+
/* bytes= */
596+
{0x81},
597+
/* bit_width= */ 3);
598+
599+
// Invalid LEB128 header for a 32-bit value
600+
TestRleBitPackedParserError(
601+
/* bytes= */
602+
{0xFF, 0xFF, 0xFF, 0xFF, 0x7f},
603+
/* bit_width= */ 3);
604+
605+
// Zero-length repeated run
606+
TestRleBitPackedParserError(
607+
/* bytes= */
608+
{0x00},
609+
/* bit_width= */ 3);
610+
TestRleBitPackedParserError(
611+
/* bytes= */
612+
{0x80, 0x00},
613+
/* bit_width= */ 3);
614+
615+
// Zero-length bit-packed run
616+
TestRleBitPackedParserError(
617+
/* bytes= */
618+
{0x01},
619+
/* bit_width= */ 3);
620+
621+
// Bit-packed run too large
622+
// (we pass a span<> on invalid memory, but only the reachable part should be read)
623+
std::vector<uint8_t> bytes = {0x80, 0x80, 0x80, 0x80, 0x01};
624+
TestRleBitPackedParserError(
625+
/* bytes= */ span(bytes.data(), 1ULL << 30),
626+
/* bit_width= */ 1);
627+
}
628+
503629
// Validates encoding of values by encoding and decoding them. If
504630
// expected_encoding != NULL, also validates that the encoded buffer is
505631
// exactly 'expected_encoding'.

0 commit comments

Comments
 (0)