@@ -643,37 +643,38 @@ void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* sourc
643643 //
644644 int source_group_id_bits =
645645 SwissTable::num_groupid_bits_from_log_blocks (source->log_blocks ());
646- uint64_t source_group_id_mask = ~0ULL >> (64 - source_group_id_bits);
647- int64_t source_block_bytes = source_group_id_bits + 8 ;
646+ int source_block_bytes =
647+ SwissTable::num_block_bytes_from_num_groupid_bits (source_group_id_bits);
648+ uint32_t source_group_id_mask =
649+ SwissTable::group_id_mask_from_num_groupid_bits (source_group_id_bits);
648650 ARROW_DCHECK (source_block_bytes % sizeof (uint64_t ) == 0 );
649651
650652 // Compute index of the last block in target that corresponds to the given
651653 // partition.
652654 //
653655 ARROW_DCHECK (num_partition_bits <= target->log_blocks ());
654- int64_t target_max_block_id =
656+ uint32_t target_max_block_id =
655657 ((partition_id + 1 ) << (target->log_blocks () - num_partition_bits)) - 1 ;
656658
657659 overflow_group_ids->clear ();
658660 overflow_hashes->clear ();
659661
660662 // For each source block...
661- int64_t source_blocks = 1LL << source->log_blocks ();
662- for (int64_t block_id = 0 ; block_id < source_blocks; ++block_id) {
663- uint8_t * block_bytes = source->blocks () + block_id * source_block_bytes;
663+ uint32_t source_blocks = 1 << source->log_blocks ();
664+ for (uint32_t block_id = 0 ; block_id < source_blocks; ++block_id) {
665+ const uint8_t * block_bytes = source->block_data ( block_id, source_block_bytes) ;
664666 uint64_t block = *reinterpret_cast <const uint64_t *>(block_bytes);
665667
666668 // For each non-empty source slot...
667669 constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL ;
668- constexpr int kSlotsPerBlock = 8 ;
669- int num_full_slots =
670- kSlotsPerBlock - static_cast <int >(ARROW_POPCOUNT64 (block & kHighBitOfEachByte ));
670+ int num_full_slots = SwissTable::kSlotsPerBlock -
671+ static_cast <int >(ARROW_POPCOUNT64 (block & kHighBitOfEachByte ));
671672 for (int local_slot_id = 0 ; local_slot_id < num_full_slots; ++local_slot_id) {
672673 // Read group id and hash for this slot.
673674 //
674- uint64_t group_id =
675- source-> extract_group_id ( block_bytes, local_slot_id, source_group_id_mask);
676- int64_t global_slot_id = block_id * kSlotsPerBlock + local_slot_id;
675+ uint32_t group_id = SwissTable::extract_group_id (
676+ block_bytes, local_slot_id, source_group_id_bits , source_group_id_mask);
677+ uint32_t global_slot_id = SwissTable::global_slot_id ( block_id, local_slot_id) ;
677678 uint32_t hash = source->hashes ()[global_slot_id];
678679 // Insert partition id into the highest bits of hash, shifting the
679680 // remaining hash bits right.
@@ -696,17 +697,18 @@ void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* sourc
696697 }
697698}
698699
699- inline bool SwissTableMerge::InsertNewGroup (SwissTable* target, uint64_t group_id,
700- uint32_t hash, int64_t max_block_id) {
700+ inline bool SwissTableMerge::InsertNewGroup (SwissTable* target, uint32_t group_id,
701+ uint32_t hash, uint32_t max_block_id) {
701702 // Load the first block to visit for this hash
702703 //
703- int64_t block_id = hash >> ( SwissTable::bits_hash_ - target->log_blocks ());
704- int64_t block_id_mask = (( 1LL << target->log_blocks ()) - 1 ) ;
704+ uint32_t block_id = SwissTable::block_id_from_hash (hash, target->log_blocks ());
705+ uint32_t block_id_mask = (1 << target->log_blocks ()) - 1 ;
705706 int num_group_id_bits =
706707 SwissTable::num_groupid_bits_from_log_blocks (target->log_blocks ());
707- int64_t num_block_bytes = num_group_id_bits + sizeof (uint64_t );
708+ int num_block_bytes =
709+ SwissTable::num_block_bytes_from_num_groupid_bits (num_group_id_bits);
708710 ARROW_DCHECK (num_block_bytes % sizeof (uint64_t ) == 0 );
709- uint8_t * block_bytes = target->blocks () + block_id * num_block_bytes;
711+ const uint8_t * block_bytes = target->block_data ( block_id, num_block_bytes) ;
710712 uint64_t block = *reinterpret_cast <const uint64_t *>(block_bytes);
711713
712714 // Search for the first block with empty slots.
@@ -715,25 +717,23 @@ inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint64_t group_i
715717 constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL ;
716718 while ((block & kHighBitOfEachByte ) == 0 && block_id < max_block_id) {
717719 block_id = (block_id + 1 ) & block_id_mask;
718- block_bytes = target->blocks () + block_id * num_block_bytes;
720+ block_bytes = target->block_data ( block_id, num_block_bytes) ;
719721 block = *reinterpret_cast <const uint64_t *>(block_bytes);
720722 }
721723 if ((block & kHighBitOfEachByte ) == 0 ) {
722724 return false ;
723725 }
724- constexpr int kSlotsPerBlock = 8 ;
725- int local_slot_id =
726- kSlotsPerBlock - static_cast <int >(ARROW_POPCOUNT64 (block & kHighBitOfEachByte ));
727- int64_t global_slot_id = block_id * kSlotsPerBlock + local_slot_id;
728- target->insert_into_empty_slot (static_cast <uint32_t >(global_slot_id), hash,
729- static_cast <uint32_t >(group_id));
726+ int local_slot_id = SwissTable::kSlotsPerBlock -
727+ static_cast <int >(ARROW_POPCOUNT64 (block & kHighBitOfEachByte ));
728+ uint32_t global_slot_id = SwissTable::global_slot_id (block_id, local_slot_id);
729+ target->insert_into_empty_slot (global_slot_id, hash, group_id);
730730 return true ;
731731}
732732
733733void SwissTableMerge::InsertNewGroups (SwissTable* target,
734734 const std::vector<uint32_t >& group_ids,
735735 const std::vector<uint32_t >& hashes) {
736- int64_t num_blocks = 1LL << target->log_blocks ();
736+ uint32_t num_blocks = 1 << target->log_blocks ();
737737 for (size_t i = 0 ; i < group_ids.size (); ++i) {
738738 std::ignore = InsertNewGroup (target, group_ids[i], hashes[i], num_blocks);
739739 }
@@ -1191,7 +1191,7 @@ Status SwissTableForJoinBuild::PushNextBatch(int64_t thread_id,
11911191 // We want each partition to correspond to a range of block indices,
11921192 // so we also partition on the highest bits of the hash.
11931193 //
1194- return locals.batch_hashes [i] >> (31 - log_num_prtns_) >> 1 ;
1194+ return locals.batch_hashes [i] >> (SwissTable::bits_hash_ - 1 - log_num_prtns_) >> 1 ;  // two-step shift keeps each count below bits_hash_ even when log_num_prtns_ == 0
11951195 },
11961196 [&locals](int64_t i, int pos) {
11971197 locals.batch_prtn_row_ids [pos] = static_cast <uint16_t >(i);
0 commit comments