Skip to content

Commit 32fcd18

Browse files
zanmato1984 and pitrou
authored
apacheGH-44513: [C++] Fix overflow issues for large build side in swiss join (apache#45108)
### Rationale for this change apache#44513 triggers two distinct overflow issues within swiss join, both happening when the build side table contains a large enough number of rows or distinct keys. (Cases of hash join build sides at this extent are rather rare, so we haven't seen them reported until now): 1. The first issue is that our swiss table implementation takes the higher `N` bits of a 32-bit hash value as the index into a buffer storing "block"s (a block contains `8` key - in some code also referred to as "group" - ids). This `N`-bit number is further multiplied by the size of a block, which is also related to `N`. The `N` in the case of apache#44513 is `26` and a block takes `40` bytes. So the multiplication can produce a number over `1 << 31` (negative when interpreted as signed 32-bit). In our AVX2 specialization of accessing the block buffer https://github.com/apache/arrow/blob/0a00e25f2f6fb927fb555b69038d0be9b9d9f265/cpp/src/arrow/compute/key_map_internal_avx2.cc#L404 , an issue like apache#41813 (comment) shows up. This is the actual issue that directly produced the segfault in apache#44513. 2. The other issue is that we take the `7` bits of the 32-bit hash value after the `N` bits as a "stamp" (to quickly fail the hash comparison). But when `N` is greater than `25`, some arithmetic code like https://github.com/apache/arrow/blob/0a00e25f2f6fb927fb555b69038d0be9b9d9f265/cpp/src/arrow/compute/key_map_internal.cc#L397 (`bits_hash_` is `constexpr 32`, `log_blocks_` is `N`, `bits_stamp_` is `constexpr 7`; this is to retrieve the stamp from a hash) produces `hash >> -1`, aka `hash >> 0xFFFFFFFF`, aka `hash >> 31` (the leading `1`s are trimmed), so the stamp value is wrong and results in falsely mismatched rows. This is the reason for my false-positive run in apache#44513 (comment) . ### What changes are included in this PR? For issue 1, use the 64-bit index gather intrinsic to avoid the offset overflow. For issue 2, do not right-shift the hash if `N + 7 >= 32`. 
This actually allows the bits to overlap between the block id (the `N` bits) and the stamp (the `7` bits). Though this may introduce more false-positive hash comparisons (thus worsening performance), I think this is still more reasonable than brutally failing for `N > 25`. I introduce two members, `bits_shift_for_block_and_stamp_` and `bits_shift_for_block_`, which are derived from `log_blocks_` - esp. set to `0` and `32 - N` when `N + 7 >= 32`; this is to avoid branching like `if (log_blocks_ + bits_stamp_ > bits_hash_)` in tight loops. ### Are these changes tested? The fix was manually tested with the original case on my local machine. (I do have a concrete C++ UT to verify the fix, but it requires too many resources and takes too long to run, so it is impractical to run in any reasonable CI environment.) ### Are there any user-facing changes? None. * GitHub Issue: apache#44513 Lead-authored-by: Rossi Sun <zanmato1984@gmail.com> Co-authored-by: Antoine Pitrou <pitrou@free.fr> Signed-off-by: Rossi Sun <zanmato1984@gmail.com>
1 parent 1f63646 commit 32fcd18

File tree

3 files changed

+69
-32
lines changed

3 files changed

+69
-32
lines changed

cpp/src/arrow/compute/key_map_internal.cc

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -254,9 +254,9 @@ void SwissTable::early_filter_imp(const int num_keys, const uint32_t* hashes,
254254
// Extract from hash: block index and stamp
255255
//
256256
uint32_t hash = hashes[i];
257-
uint32_t iblock = hash >> (bits_hash_ - bits_stamp_ - log_blocks_);
257+
uint32_t iblock = hash >> bits_shift_for_block_and_stamp_;
258258
uint32_t stamp = iblock & stamp_mask;
259-
iblock >>= bits_stamp_;
259+
iblock >>= bits_shift_for_block_;
260260

261261
uint32_t num_block_bytes = num_groupid_bits + 8;
262262
const uint8_t* blockbase =
@@ -399,7 +399,7 @@ bool SwissTable::find_next_stamp_match(const uint32_t hash, const uint32_t in_sl
399399
const uint64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
400400
constexpr uint64_t stamp_mask = 0x7f;
401401
const int stamp =
402-
static_cast<int>((hash >> (bits_hash_ - log_blocks_ - bits_stamp_)) & stamp_mask);
402+
static_cast<int>((hash >> bits_shift_for_block_and_stamp_) & stamp_mask);
403403
uint64_t start_slot_id = wrap_global_slot_id(in_slot_id);
404404
int match_found;
405405
int local_slot;
@@ -659,6 +659,9 @@ Status SwissTable::grow_double() {
659659
int num_group_id_bits_after = num_groupid_bits_from_log_blocks(log_blocks_ + 1);
660660
uint64_t group_id_mask_before = ~0ULL >> (64 - num_group_id_bits_before);
661661
int log_blocks_after = log_blocks_ + 1;
662+
int bits_shift_for_block_and_stamp_after =
663+
ComputeBitsShiftForBlockAndStamp(log_blocks_after);
664+
int bits_shift_for_block_after = ComputeBitsShiftForBlock(log_blocks_after);
662665
uint64_t block_size_before = (8 + num_group_id_bits_before);
663666
uint64_t block_size_after = (8 + num_group_id_bits_after);
664667
uint64_t block_size_total_after = (block_size_after << log_blocks_after) + padding_;
@@ -701,8 +704,7 @@ Status SwissTable::grow_double() {
701704
}
702705

703706
int ihalf = block_id_new & 1;
704-
uint8_t stamp_new =
705-
hash >> ((bits_hash_ - log_blocks_after - bits_stamp_)) & stamp_mask;
707+
uint8_t stamp_new = (hash >> bits_shift_for_block_and_stamp_after) & stamp_mask;
706708
uint64_t group_id_bit_offs = j * num_group_id_bits_before;
707709
uint64_t group_id =
708710
(util::SafeLoadAs<uint64_t>(block_base + 8 + (group_id_bit_offs >> 3)) >>
@@ -744,8 +746,7 @@ Status SwissTable::grow_double() {
744746
(util::SafeLoadAs<uint64_t>(block_base + 8 + (group_id_bit_offs >> 3)) >>
745747
(group_id_bit_offs & 7)) &
746748
group_id_mask_before;
747-
uint8_t stamp_new =
748-
hash >> ((bits_hash_ - log_blocks_after - bits_stamp_)) & stamp_mask;
749+
uint8_t stamp_new = (hash >> bits_shift_for_block_and_stamp_after) & stamp_mask;
749750

750751
uint8_t* block_base_new =
751752
blocks_new->mutable_data() + block_id_new * block_size_after;
@@ -773,6 +774,8 @@ Status SwissTable::grow_double() {
773774
blocks_ = std::move(blocks_new);
774775
hashes_ = std::move(hashes_new_buffer);
775776
log_blocks_ = log_blocks_after;
777+
bits_shift_for_block_and_stamp_ = bits_shift_for_block_and_stamp_after;
778+
bits_shift_for_block_ = bits_shift_for_block_after;
776779

777780
return Status::OK();
778781
}
@@ -784,6 +787,8 @@ Status SwissTable::init(int64_t hardware_flags, MemoryPool* pool, int log_blocks
784787
log_minibatch_ = util::MiniBatch::kLogMiniBatchLength;
785788

786789
log_blocks_ = log_blocks;
790+
bits_shift_for_block_and_stamp_ = ComputeBitsShiftForBlockAndStamp(log_blocks_);
791+
bits_shift_for_block_ = ComputeBitsShiftForBlock(log_blocks_);
787792
int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
788793
num_inserted_ = 0;
789794

@@ -820,6 +825,8 @@ void SwissTable::cleanup() {
820825
hashes_ = nullptr;
821826
}
822827
log_blocks_ = 0;
828+
bits_shift_for_block_and_stamp_ = ComputeBitsShiftForBlockAndStamp(log_blocks_);
829+
bits_shift_for_block_ = ComputeBitsShiftForBlock(log_blocks_);
823830
num_inserted_ = 0;
824831
}
825832

cpp/src/arrow/compute/key_map_internal.h

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,23 @@ class ARROW_EXPORT SwissTable {
203203
// Resize large hash tables when 75% full.
204204
Status grow_double();
205205

206+
// When log_blocks is greater than 25, there will be overlapping bits between block id
207+
// and stamp within a 32-bit hash value. So we must check if this is the case when
208+
// right shifting a hash value to retrieve block id and stamp. The following two
209+
// functions derive the number of bits to right shift from the given log_blocks.
210+
static int ComputeBitsShiftForBlockAndStamp(int log_blocks) {
211+
if (ARROW_PREDICT_FALSE(log_blocks + bits_stamp_ > bits_hash_)) {
212+
return 0;
213+
}
214+
return bits_hash_ - log_blocks - bits_stamp_;
215+
}
216+
static int ComputeBitsShiftForBlock(int log_blocks) {
217+
if (ARROW_PREDICT_FALSE(log_blocks + bits_stamp_ > bits_hash_)) {
218+
return bits_hash_ - log_blocks;
219+
}
220+
return bits_stamp_;
221+
}
222+
206223
// Number of hash bits stored in slots in a block.
207224
// The highest bits of hash determine block id.
208225
// The next set of highest bits is a "stamp" stored in a slot in a block.
@@ -214,6 +231,11 @@ class ARROW_EXPORT SwissTable {
214231
int log_minibatch_;
215232
// Base 2 log of the number of blocks
216233
int log_blocks_ = 0;
234+
// The following two variables are derived from log_blocks_ as log_blocks_ changes, and
235+
// used in tight loops to avoid calling the ComputeXXX functions (introducing a
236+
// branching on whether log_blocks_ + bits_stamp_ > bits_hash_).
237+
int bits_shift_for_block_and_stamp_ = ComputeBitsShiftForBlockAndStamp(log_blocks_);
238+
int bits_shift_for_block_ = ComputeBitsShiftForBlock(log_blocks_);
217239
// Number of keys inserted into hash table
218240
uint32_t num_inserted_ = 0;
219241

@@ -271,8 +293,7 @@ void SwissTable::insert_into_empty_slot(uint32_t slot_id, uint32_t hash,
271293
constexpr uint64_t stamp_mask = 0x7f;
272294

273295
int start_slot = (slot_id & 7);
274-
int stamp =
275-
static_cast<int>((hash >> (bits_hash_ - log_blocks_ - bits_stamp_)) & stamp_mask);
296+
int stamp = static_cast<int>((hash >> bits_shift_for_block_and_stamp_) & stamp_mask);
276297
uint64_t block_id = slot_id >> 3;
277298
uint8_t* blockbase = blocks_->mutable_data() + num_block_bytes * block_id;
278299

cpp/src/arrow/compute/key_map_internal_avx2.cc

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,9 @@ int SwissTable::early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* h
4545
// Calculate block index and hash stamp for a byte in a block
4646
//
4747
__m256i vhash = _mm256_loadu_si256(vhash_ptr + i);
48-
__m256i vblock_id = _mm256_srlv_epi32(
49-
vhash, _mm256_set1_epi32(bits_hash_ - bits_stamp_ - log_blocks_));
48+
__m256i vblock_id = _mm256_srli_epi32(vhash, bits_shift_for_block_and_stamp_);
5049
__m256i vstamp = _mm256_and_si256(vblock_id, vstamp_mask);
51-
vblock_id = _mm256_srli_epi32(vblock_id, bits_stamp_);
50+
vblock_id = _mm256_srli_epi32(vblock_id, bits_shift_for_block_);
5251

5352
// We now split inputs and process 4 at a time,
5453
// in order to process 64-bit blocks
@@ -301,19 +300,15 @@ int SwissTable::early_filter_imp_avx2_x32(const int num_hashes, const uint32_t*
301300
_mm256_and_si256(vhash2, _mm256_set1_epi32(0xffff0000)));
302301
vhash1 = _mm256_or_si256(_mm256_srli_epi32(vhash1, 16),
303302
_mm256_and_si256(vhash3, _mm256_set1_epi32(0xffff0000)));
304-
__m256i vstamp_A = _mm256_and_si256(
305-
_mm256_srlv_epi32(vhash0, _mm256_set1_epi32(16 - log_blocks_ - 7)),
306-
_mm256_set1_epi16(0x7f));
307-
__m256i vstamp_B = _mm256_and_si256(
308-
_mm256_srlv_epi32(vhash1, _mm256_set1_epi32(16 - log_blocks_ - 7)),
309-
_mm256_set1_epi16(0x7f));
303+
__m256i vstamp_A = _mm256_and_si256(_mm256_srli_epi32(vhash0, 16 - log_blocks_ - 7),
304+
_mm256_set1_epi16(0x7f));
305+
__m256i vstamp_B = _mm256_and_si256(_mm256_srli_epi32(vhash1, 16 - log_blocks_ - 7),
306+
_mm256_set1_epi16(0x7f));
310307
__m256i vstamp = _mm256_or_si256(vstamp_A, _mm256_slli_epi16(vstamp_B, 8));
311-
__m256i vblock_id_A =
312-
_mm256_and_si256(_mm256_srlv_epi32(vhash0, _mm256_set1_epi32(16 - log_blocks_)),
313-
_mm256_set1_epi16(block_id_mask));
314-
__m256i vblock_id_B =
315-
_mm256_and_si256(_mm256_srlv_epi32(vhash1, _mm256_set1_epi32(16 - log_blocks_)),
316-
_mm256_set1_epi16(block_id_mask));
308+
__m256i vblock_id_A = _mm256_and_si256(_mm256_srli_epi32(vhash0, 16 - log_blocks_),
309+
_mm256_set1_epi16(block_id_mask));
310+
__m256i vblock_id_B = _mm256_and_si256(_mm256_srli_epi32(vhash1, 16 - log_blocks_),
311+
_mm256_set1_epi16(block_id_mask));
317312
__m256i vblock_id = _mm256_or_si256(vblock_id_A, _mm256_slli_epi16(vblock_id_B, 8));
318313

319314
// Visit all block bytes in reverse order (overwriting data on multiple matches)
@@ -392,16 +387,30 @@ int SwissTable::extract_group_ids_avx2(const int num_keys, const uint32_t* hashe
392387
} else {
393388
for (int i = 0; i < num_keys / unroll; ++i) {
394389
__m256i hash = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(hashes) + i);
390+
// Extend hash and local_slot to 64-bit to compute 64-bit group id offsets to
391+
// gather from. This is to prevent index overflow issues in GH-44513.
392+
// NB: Use zero-extend conversion for unsigned hash.
393+
__m256i hash_lo = _mm256_cvtepu32_epi64(_mm256_castsi256_si128(hash));
394+
__m256i hash_hi = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(hash, 1));
395395
__m256i local_slot =
396396
_mm256_set1_epi64x(reinterpret_cast<const uint64_t*>(local_slots)[i]);
397-
local_slot = _mm256_shuffle_epi8(
398-
local_slot, _mm256_setr_epi32(0x80808000, 0x80808001, 0x80808002, 0x80808003,
399-
0x80808004, 0x80808005, 0x80808006, 0x80808007));
400-
local_slot = _mm256_mullo_epi32(local_slot, _mm256_set1_epi32(byte_size));
401-
__m256i pos = _mm256_srlv_epi32(hash, _mm256_set1_epi32(bits_hash_ - log_blocks_));
402-
pos = _mm256_mullo_epi32(pos, _mm256_set1_epi32(byte_multiplier));
403-
pos = _mm256_add_epi32(pos, local_slot);
404-
__m256i group_id = _mm256_i32gather_epi32(elements, pos, 1);
397+
__m256i local_slot_lo = _mm256_shuffle_epi8(
398+
local_slot, _mm256_setr_epi32(0x80808000, 0x80808080, 0x80808001, 0x80808080,
399+
0x80808002, 0x80808080, 0x80808003, 0x80808080));
400+
__m256i local_slot_hi = _mm256_shuffle_epi8(
401+
local_slot, _mm256_setr_epi32(0x80808004, 0x80808080, 0x80808005, 0x80808080,
402+
0x80808006, 0x80808080, 0x80808007, 0x80808080));
403+
local_slot_lo = _mm256_mul_epu32(local_slot_lo, _mm256_set1_epi32(byte_size));
404+
local_slot_hi = _mm256_mul_epu32(local_slot_hi, _mm256_set1_epi32(byte_size));
405+
__m256i pos_lo = _mm256_srli_epi64(hash_lo, bits_hash_ - log_blocks_);
406+
__m256i pos_hi = _mm256_srli_epi64(hash_hi, bits_hash_ - log_blocks_);
407+
pos_lo = _mm256_mul_epu32(pos_lo, _mm256_set1_epi32(byte_multiplier));
408+
pos_hi = _mm256_mul_epu32(pos_hi, _mm256_set1_epi32(byte_multiplier));
409+
pos_lo = _mm256_add_epi64(pos_lo, local_slot_lo);
410+
pos_hi = _mm256_add_epi64(pos_hi, local_slot_hi);
411+
__m128i group_id_lo = _mm256_i64gather_epi32(elements, pos_lo, 1);
412+
__m128i group_id_hi = _mm256_i64gather_epi32(elements, pos_hi, 1);
413+
__m256i group_id = _mm256_set_m128i(group_id_hi, group_id_lo);
405414
group_id = _mm256_and_si256(group_id, _mm256_set1_epi32(mask));
406415
_mm256_storeu_si256(reinterpret_cast<__m256i*>(out_group_ids) + i, group_id);
407416
}

0 commit comments

Comments
 (0)