Skip to content

Commit 136ad9a

Browse files
authored
GH-45551: [C++][Acero] Release temp states of Swiss join building hash table to reduce memory consumption (#45552)
### Rationale for this change

#45551 describes the basic idea. Some profiling from a real case follows. Take https://github.com/apache/arrow/blob/a53a77c93217399c4fda8c6328db2c492a30b0b0/cpp/src/arrow/acero/hash_join_node_test.cc#L3368 and print the memory pool stats at the end.

Before this change:

```
heap stats:    peak       total      freed      current    unit    count
  reserved:    22.6 GiB   30.3 GiB   8.5 GiB    21.8 GiB                   not all freed!
 committed:    22.9 GiB   30.6 GiB   8.4 GiB    22.1 GiB                   not all freed!
```

After this change:

```
heap stats:    peak       total      freed      current    unit    count
  reserved:    17.5 GiB   30.3 GiB   16.0 GiB   14.3 GiB                   not all freed!
 committed:    17.8 GiB   30.5 GiB   15.8 GiB   14.7 GiB                   not all freed!
```

The peak committed memory is reduced from `22.9 GiB` to `17.8 GiB`. Although the size of the reduction varies case by case, IMO this can be considered a good improvement for most general cases at zero cost.

### What changes are included in this PR?

Make `hash_table_build_`, which only holds temporary state for building the final hash table, transient (held via a pointer), and release it as early as possible.

### Are these changes tested?

Existing tests should suffice.

### Are there any user-facing changes?

None, except that users will see a good reduction in peak memory usage :)

* GitHub Issue: #45551

Authored-by: Rossi Sun <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
1 parent a53a77c commit 136ad9a

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

cpp/src/arrow/acero/swiss_join.cc

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2590,7 +2590,8 @@ class SwissJoin : public HashJoinImpl {
25902590
ColumnMetadataFromDataType(schema->data_type(HashJoinProjection::PAYLOAD, i)));
25912591
payload_types.push_back(metadata);
25922592
}
2593-
RETURN_NOT_OK(CancelIfNotOK(hash_table_build_.Init(
2593+
hash_table_build_ = std::make_unique<SwissTableForJoinBuild>();
2594+
RETURN_NOT_OK(CancelIfNotOK(hash_table_build_->Init(
25942595
&hash_table_, num_threads_, build_side_batches_.row_count(),
25952596
reject_duplicate_keys, no_payload, key_types, payload_types, pool_,
25962597
hardware_flags_)));
@@ -2609,7 +2610,8 @@ class SwissJoin : public HashJoinImpl {
26092610
DCHECK_GT(build_side_batches_[batch_id].length, 0);
26102611

26112612
const HashJoinProjectionMaps* schema = schema_[1];
2612-
bool no_payload = hash_table_build_.no_payload();
2613+
DCHECK_NE(hash_table_build_, nullptr);
2614+
bool no_payload = hash_table_build_->no_payload();
26132615

26142616
ExecBatch input_batch;
26152617
ARROW_ASSIGN_OR_RAISE(
@@ -2639,7 +2641,8 @@ class SwissJoin : public HashJoinImpl {
26392641
}
26402642
}
26412643
arrow::util::TempVectorStack* temp_stack = &local_states_[thread_id].stack;
2642-
RETURN_NOT_OK(CancelIfNotOK(hash_table_build_.PushNextBatch(
2644+
DCHECK_NE(hash_table_build_, nullptr);
2645+
RETURN_NOT_OK(CancelIfNotOK(hash_table_build_->PushNextBatch(
26432646
static_cast<int64_t>(thread_id), key_batch, no_payload ? nullptr : &payload_batch,
26442647
temp_stack)));
26452648

@@ -2654,23 +2657,26 @@ class SwissJoin : public HashJoinImpl {
26542657
// On a single thread prepare for merging partitions of the resulting hash
26552658
// table.
26562659
//
2657-
RETURN_NOT_OK(CancelIfNotOK(hash_table_build_.PreparePrtnMerge()));
2660+
DCHECK_NE(hash_table_build_, nullptr);
2661+
RETURN_NOT_OK(CancelIfNotOK(hash_table_build_->PreparePrtnMerge()));
26582662
return CancelIfNotOK(
2659-
start_task_group_callback_(task_group_merge_, hash_table_build_.num_prtns()));
2663+
start_task_group_callback_(task_group_merge_, hash_table_build_->num_prtns()));
26602664
}
26612665

26622666
Status MergeTask(size_t /*thread_id*/, int64_t prtn_id) {
26632667
if (IsCancelled()) {
26642668
return Status::OK();
26652669
}
2666-
hash_table_build_.PrtnMerge(static_cast<int>(prtn_id));
2670+
DCHECK_NE(hash_table_build_, nullptr);
2671+
hash_table_build_->PrtnMerge(static_cast<int>(prtn_id));
26672672
return Status::OK();
26682673
}
26692674

26702675
Status MergeFinished(size_t thread_id) {
26712676
RETURN_NOT_OK(status());
26722677
arrow::util::TempVectorStack* temp_stack = &local_states_[thread_id].stack;
2673-
hash_table_build_.FinishPrtnMerge(temp_stack);
2678+
DCHECK_NE(hash_table_build_, nullptr);
2679+
hash_table_build_->FinishPrtnMerge(temp_stack);
26742680
return CancelIfNotOK(OnBuildHashTableFinished(static_cast<int64_t>(thread_id)));
26752681
}
26762682

@@ -2679,6 +2685,9 @@ class SwissJoin : public HashJoinImpl {
26792685
return status();
26802686
}
26812687

2688+
DCHECK_NE(hash_table_build_, nullptr);
2689+
hash_table_build_.reset();
2690+
26822691
for (int i = 0; i < num_threads_; ++i) {
26832692
local_states_[i].materialize.SetBuildSide(hash_table_.keys()->keys(),
26842693
hash_table_.payloads(),
@@ -2910,7 +2919,8 @@ class SwissJoin : public HashJoinImpl {
29102919
SwissTableForJoin hash_table_;
29112920
JoinProbeProcessor probe_processor_;
29122921
JoinResidualFilter residual_filter_;
2913-
SwissTableForJoinBuild hash_table_build_;
2922+
// Temporarily used during build phase, and released afterward.
2923+
std::unique_ptr<SwissTableForJoinBuild> hash_table_build_;
29142924
AccumulationQueue build_side_batches_;
29152925

29162926
// Atomic state flags.

0 commit comments

Comments
 (0)