Skip to content

Commit a53a77c

Browse files
authored
GH-45506: [C++][Acero] More overflow-safe Swiss table (#45515)
### Rationale for this change

See #45506.

### What changes are included in this PR?

1. Abstract the current overflow-prone block data access into functions that do proper type promotion to avoid overflow. Also remove the old block base address accessor.
2. Unify the data types used for the various concepts to their natural types (i.e., without explicit promotion): `uint32_t` for `block_id`, `int` for `num_xxx_bits/bytes`, `uint32_t` for `group_id`, `int` for `local_slot_id`, and `uint32_t` for `global_slot_id`.
3. Abstract several constants and utility functions for readability and maintainability.

### Are these changes tested?

Existing tests should suffice. It is really hard (gosh I did try) to create a concrete test case that fails without this change and passes with this change.

### Are there any user-facing changes?

None.

* GitHub Issue: #45506

Authored-by: Rossi Sun <zanmato1984@gmail.com>
Signed-off-by: Antoine Pitrou <antoine@python.org>
1 parent 27c639b commit a53a77c

File tree

5 files changed

+256
-212
lines changed

5 files changed

+256
-212
lines changed

cpp/src/arrow/acero/swiss_join.cc

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -643,37 +643,38 @@ void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* sourc
643643
//
644644
int source_group_id_bits =
645645
SwissTable::num_groupid_bits_from_log_blocks(source->log_blocks());
646-
uint64_t source_group_id_mask = ~0ULL >> (64 - source_group_id_bits);
647-
int64_t source_block_bytes = source_group_id_bits + 8;
646+
int source_block_bytes =
647+
SwissTable::num_block_bytes_from_num_groupid_bits(source_group_id_bits);
648+
uint32_t source_group_id_mask =
649+
SwissTable::group_id_mask_from_num_groupid_bits(source_group_id_bits);
648650
ARROW_DCHECK(source_block_bytes % sizeof(uint64_t) == 0);
649651

650652
// Compute index of the last block in target that corresponds to the given
651653
// partition.
652654
//
653655
ARROW_DCHECK(num_partition_bits <= target->log_blocks());
654-
int64_t target_max_block_id =
656+
uint32_t target_max_block_id =
655657
((partition_id + 1) << (target->log_blocks() - num_partition_bits)) - 1;
656658

657659
overflow_group_ids->clear();
658660
overflow_hashes->clear();
659661

660662
// For each source block...
661-
int64_t source_blocks = 1LL << source->log_blocks();
662-
for (int64_t block_id = 0; block_id < source_blocks; ++block_id) {
663-
uint8_t* block_bytes = source->blocks() + block_id * source_block_bytes;
663+
uint32_t source_blocks = 1 << source->log_blocks();
664+
for (uint32_t block_id = 0; block_id < source_blocks; ++block_id) {
665+
const uint8_t* block_bytes = source->block_data(block_id, source_block_bytes);
664666
uint64_t block = *reinterpret_cast<const uint64_t*>(block_bytes);
665667

666668
// For each non-empty source slot...
667669
constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL;
668-
constexpr int kSlotsPerBlock = 8;
669-
int num_full_slots =
670-
kSlotsPerBlock - static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte));
670+
int num_full_slots = SwissTable::kSlotsPerBlock -
671+
static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte));
671672
for (int local_slot_id = 0; local_slot_id < num_full_slots; ++local_slot_id) {
672673
// Read group id and hash for this slot.
673674
//
674-
uint64_t group_id =
675-
source->extract_group_id(block_bytes, local_slot_id, source_group_id_mask);
676-
int64_t global_slot_id = block_id * kSlotsPerBlock + local_slot_id;
675+
uint32_t group_id = SwissTable::extract_group_id(
676+
block_bytes, local_slot_id, source_group_id_bits, source_group_id_mask);
677+
uint32_t global_slot_id = SwissTable::global_slot_id(block_id, local_slot_id);
677678
uint32_t hash = source->hashes()[global_slot_id];
678679
// Insert partition id into the highest bits of hash, shifting the
679680
// remaining hash bits right.
@@ -696,17 +697,18 @@ void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* sourc
696697
}
697698
}
698699

699-
inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint64_t group_id,
700-
uint32_t hash, int64_t max_block_id) {
700+
inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint32_t group_id,
701+
uint32_t hash, uint32_t max_block_id) {
701702
// Load the first block to visit for this hash
702703
//
703-
int64_t block_id = hash >> (SwissTable::bits_hash_ - target->log_blocks());
704-
int64_t block_id_mask = ((1LL << target->log_blocks()) - 1);
704+
uint32_t block_id = SwissTable::block_id_from_hash(hash, target->log_blocks());
705+
uint32_t block_id_mask = (1 << target->log_blocks()) - 1;
705706
int num_group_id_bits =
706707
SwissTable::num_groupid_bits_from_log_blocks(target->log_blocks());
707-
int64_t num_block_bytes = num_group_id_bits + sizeof(uint64_t);
708+
int num_block_bytes =
709+
SwissTable::num_block_bytes_from_num_groupid_bits(num_group_id_bits);
708710
ARROW_DCHECK(num_block_bytes % sizeof(uint64_t) == 0);
709-
uint8_t* block_bytes = target->blocks() + block_id * num_block_bytes;
711+
const uint8_t* block_bytes = target->block_data(block_id, num_block_bytes);
710712
uint64_t block = *reinterpret_cast<const uint64_t*>(block_bytes);
711713

712714
// Search for the first block with empty slots.
@@ -715,25 +717,23 @@ inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint64_t group_i
715717
constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL;
716718
while ((block & kHighBitOfEachByte) == 0 && block_id < max_block_id) {
717719
block_id = (block_id + 1) & block_id_mask;
718-
block_bytes = target->blocks() + block_id * num_block_bytes;
720+
block_bytes = target->block_data(block_id, num_block_bytes);
719721
block = *reinterpret_cast<const uint64_t*>(block_bytes);
720722
}
721723
if ((block & kHighBitOfEachByte) == 0) {
722724
return false;
723725
}
724-
constexpr int kSlotsPerBlock = 8;
725-
int local_slot_id =
726-
kSlotsPerBlock - static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte));
727-
int64_t global_slot_id = block_id * kSlotsPerBlock + local_slot_id;
728-
target->insert_into_empty_slot(static_cast<uint32_t>(global_slot_id), hash,
729-
static_cast<uint32_t>(group_id));
726+
int local_slot_id = SwissTable::kSlotsPerBlock -
727+
static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte));
728+
uint32_t global_slot_id = SwissTable::global_slot_id(block_id, local_slot_id);
729+
target->insert_into_empty_slot(global_slot_id, hash, group_id);
730730
return true;
731731
}
732732

733733
void SwissTableMerge::InsertNewGroups(SwissTable* target,
734734
const std::vector<uint32_t>& group_ids,
735735
const std::vector<uint32_t>& hashes) {
736-
int64_t num_blocks = 1LL << target->log_blocks();
736+
uint32_t num_blocks = 1 << target->log_blocks();
737737
for (size_t i = 0; i < group_ids.size(); ++i) {
738738
std::ignore = InsertNewGroup(target, group_ids[i], hashes[i], num_blocks);
739739
}
@@ -1191,7 +1191,7 @@ Status SwissTableForJoinBuild::PushNextBatch(int64_t thread_id,
11911191
// We want each partition to correspond to a range of block indices,
11921192
// so we also partition on the highest bits of the hash.
11931193
//
1194-
return locals.batch_hashes[i] >> (31 - log_num_prtns_) >> 1;
1194+
return locals.batch_hashes[i] >> (SwissTable::bits_hash_ - log_num_prtns_);
11951195
},
11961196
[&locals](int64_t i, int pos) {
11971197
locals.batch_prtn_row_ids[pos] = static_cast<uint16_t>(i);

cpp/src/arrow/acero/swiss_join_internal.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,8 +380,8 @@ class SwissTableMerge {
380380
// Max block id value greater or equal to the number of blocks guarantees that
381381
// the search will not be stopped.
382382
//
383-
static inline bool InsertNewGroup(SwissTable* target, uint64_t group_id, uint32_t hash,
384-
int64_t max_block_id);
383+
static inline bool InsertNewGroup(SwissTable* target, uint32_t group_id, uint32_t hash,
384+
uint32_t max_block_id);
385385
};
386386

387387
struct SwissTableWithKeys {

0 commit comments

Comments
 (0)