Commit 420e207
feat: add LazyPartitioned mode for hash join to reduce RepartitionExec overhead
This commit adds a new PartitionMode::LazyPartitioned that avoids the full build-side
RepartitionExec when executing partitioned hash joins. Instead of pre-repartitioning all
columns of the build table, rows are filtered lazily during hash table construction using
hash(join_keys) % partition_count.

Key changes:
- Add LazyPartitioned variant to PartitionMode enum
- Build side requests UnspecifiedDistribution (merged, no repartition)
- Probe side still requests HashPartitioned distribution
- Add filter_batch_by_partition() to filter build rows per partition
- Update collect_left_input to accept optional partition filter
- Add protobuf serialization support for new mode
- Update optimizer to handle LazyPartitioned in key reordering

This optimization is beneficial for wide build tables where copying all columns in
RepartitionExec is expensive.

Closes #19789
1 parent 472a729 commit 420e207
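To make the routing rule from the commit message concrete, here is a minimal, self-contained sketch of the invariant LazyPartitioned relies on: as long as the build-side filter and the probe-side RepartitionExec use the same hash function and the same partition count, a build row and its matching probe row always land in the same partition. This is an illustration only; it uses std's RandomState rather than DataFusion's create_hashes seeded with REPARTITION_RANDOM_STATE, which is what the actual change uses.

use std::collections::hash_map::RandomState;
use std::hash::BuildHasher;

/// Partition routing used by both sides: hash(join_key) % partition_count.
/// std's RandomState stands in for the shared repartition hash state here;
/// the real code hashes join-key arrays with DataFusion's create_hashes.
fn route(key: i64, partition_count: u64, s: &RandomState) -> u64 {
    s.hash_one(key) % partition_count
}

fn main() {
    let s = RandomState::new();
    let partition_count = 4;
    let my_partition = 2;

    // Build side (LazyPartitioned): partition 2's stream keeps only the rows it owns.
    let build_keys = [1_i64, 2, 3, 4, 5, 6, 7, 8];
    let kept: Vec<i64> = build_keys
        .iter()
        .copied()
        .filter(|k| route(*k, partition_count, &s) == my_partition)
        .collect();

    // Probe side: RepartitionExec sends each probe row to route(key), so any probe
    // key equal to a kept build key arrives at this same partition and still finds
    // its match in the locally built hash table.
    for k in &kept {
        assert_eq!(route(*k, partition_count, &s), my_partition);
    }
    println!("partition {my_partition} keeps build keys {kept:?}");
}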

10 files changed, +262 -38 lines changed

datafusion/physical-optimizer/src/enforce_distribution.rs

Lines changed: 37 additions & 2 deletions
@@ -327,6 +327,34 @@ pub fn adjust_input_keys_ordering(
                 )
                 .map(Transformed::yes);
             }
+            PartitionMode::LazyPartitioned => {
+                // LazyPartitioned mode uses the same key reordering as Partitioned,
+                // but with LazyPartitioned mode preserved
+                let join_constructor = |new_conditions: (
+                    Vec<(PhysicalExprRef, PhysicalExprRef)>,
+                    Vec<SortOptions>,
+                )| {
+                    HashJoinExec::try_new(
+                        Arc::clone(left),
+                        Arc::clone(right),
+                        new_conditions.0,
+                        filter.clone(),
+                        join_type,
+                        projection.clone(),
+                        PartitionMode::LazyPartitioned,
+                        *null_equality,
+                        *null_aware,
+                    )
+                    .map(|e| Arc::new(e) as _)
+                };
+                return reorder_partitioned_join_keys(
+                    requirements,
+                    on,
+                    &[],
+                    &join_constructor,
+                )
+                .map(Transformed::yes);
+            }
             PartitionMode::CollectLeft => {
                 // Push down requirements to the right side
                 requirements.children[1].data = match join_type {
@@ -624,7 +652,10 @@ pub fn reorder_join_keys_to_inputs(
         ..
     }) = plan_any.downcast_ref::<HashJoinExec>()
     {
-        if matches!(mode, PartitionMode::Partitioned) {
+        if matches!(
+            mode,
+            PartitionMode::Partitioned | PartitionMode::LazyPartitioned
+        ) {
             let (join_keys, positions) = reorder_current_join_keys(
                 extract_join_keys(on),
                 Some(left.output_partitioning()),
@@ -645,7 +676,7 @@ pub fn reorder_join_keys_to_inputs(
                 filter.clone(),
                 join_type,
                 projection.clone(),
-                PartitionMode::Partitioned,
+                *mode,
                 *null_equality,
                 *null_aware,
             )?));
@@ -1257,6 +1288,10 @@ pub fn ensure_distribution(
     //
     // CollectLeft/CollectRight modes are safe because one side is collected
     // to a single partition which eliminates partition-to-partition mapping.
+    //
+    // LazyPartitioned mode is also safe from this issue because the build side
+    // is not pre-partitioned; instead, rows are filtered locally during hash
+    // table construction. Only the probe side is hash-partitioned.
     let is_partitioned_join = plan
         .as_any()
         .downcast_ref::<HashJoinExec>()

datafusion/physical-optimizer/src/join_selection.rs

Lines changed: 5 additions & 2 deletions
@@ -293,7 +293,7 @@ fn statistical_join_selection_subrule(
                 || partitioned_hash_join(hash_join).map(Some),
                 |v| Ok(Some(v)),
             )?,
-            PartitionMode::Partitioned => {
+            PartitionMode::Partitioned | PartitionMode::LazyPartitioned => {
                 let left = hash_join.left();
                 let right = hash_join.right();
                 // Don't swap null-aware anti joins as they have specific side requirements
@@ -302,7 +302,7 @@ fn statistical_join_selection_subrule(
                     && should_swap_join_order(&**left, &**right)?
                 {
                     hash_join
-                        .swap_inputs(PartitionMode::Partitioned)
+                        .swap_inputs(*hash_join.partition_mode())
                         .map(Some)?
                 } else {
                     None
@@ -540,6 +540,9 @@ pub(crate) fn swap_join_according_to_unboundedness(
         (PartitionMode::Partitioned, _) => {
            hash_join.swap_inputs(PartitionMode::Partitioned)
        }
+        (PartitionMode::LazyPartitioned, _) => {
+            hash_join.swap_inputs(PartitionMode::LazyPartitioned)
+        }
         (PartitionMode::CollectLeft, _) => {
             hash_join.swap_inputs(PartitionMode::CollectLeft)
         }

datafusion/physical-plan/src/joins/hash_join/exec.rs

Lines changed: 179 additions & 25 deletions
@@ -27,6 +27,7 @@ use crate::filter_pushdown::{
     ChildPushdownResult, FilterDescription, FilterPushdownPhase,
     FilterPushdownPropagation,
 };
+use crate::hash_utils::create_hashes;
 use crate::joins::Map;
 use crate::joins::array_map::ArrayMap;
 use crate::joins::hash_join::inlist_builder::build_struct_inlist_values;
@@ -49,6 +50,7 @@ use crate::projection::{
 };
 use crate::repartition::REPARTITION_RANDOM_STATE;
 use crate::spill::get_record_batch_memory_size;
+use crate::stream::RecordBatchReceiverStream;
 use crate::{
     DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning,
     PlanProperties, SendableRecordBatchStream, Statistics,
@@ -61,8 +63,8 @@ use crate::{
     metrics::{ExecutionPlanMetricsSet, MetricsSet},
 };

-use arrow::array::{ArrayRef, BooleanBufferBuilder};
-use arrow::compute::concat_batches;
+use arrow::array::{ArrayRef, BooleanArray, BooleanBufferBuilder};
+use arrow::compute::{concat_batches, filter_record_batch};
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
 use arrow::util::bit_util;
@@ -740,6 +742,12 @@ impl HashJoinExec {
             PartitionMode::Partitioned => {
                 symmetric_join_output_partitioning(left, right, &join_type)?
             }
+            PartitionMode::LazyPartitioned => {
+                // LazyPartitioned: output partitioning is determined by the probe (right) side,
+                // since each partition builds its own hash table from filtered build rows.
+                // This is similar to Partitioned mode but the build side isn't pre-partitioned.
+                symmetric_join_output_partitioning(left, right, &join_type)?
+            }
         };

         let emission_type = if left.boundedness().is_unbounded() {
@@ -958,6 +966,15 @@ impl ExecutionPlan for HashJoinExec {
                     Distribution::HashPartitioned(right_expr),
                 ]
             }
+            PartitionMode::LazyPartitioned => {
+                // LazyPartitioned mode: build side is NOT repartitioned (we read all
+                // partitions and filter locally), but probe side IS hash-partitioned.
+                let right_expr = self.on.iter().map(|(_, r)| Arc::clone(r)).collect();
+                vec![
+                    Distribution::UnspecifiedDistribution,
+                    Distribution::HashPartitioned(right_expr),
+                ]
+            }
             PartitionMode::Auto => vec![
                 Distribution::UnspecifiedDistribution,
                 Distribution::UnspecifiedDistribution,
@@ -1116,6 +1133,7 @@ impl ExecutionPlan for HashJoinExec {
                     Arc::clone(context.session_config().options()),
                     self.null_equality,
                     array_map_created_count,
+                    None, // No partition filtering for CollectLeft mode
                 ))
             })?,
             PartitionMode::Partitioned => {
@@ -1137,6 +1155,58 @@ impl ExecutionPlan for HashJoinExec {
                     Arc::clone(context.session_config().options()),
                     self.null_equality,
                     array_map_created_count,
+                    None, // No partition filtering - already pre-partitioned by RepartitionExec
+                ))
+            }
+            PartitionMode::LazyPartitioned => {
+                // LazyPartitioned mode: read ALL build partitions and filter locally
+                // by computing hash % num_partitions == current_partition.
+                // This avoids the overhead of RepartitionExec copying all columns.
+                let left_partitions = self.left.output_partitioning().partition_count();
+                let right_partitions = self.right.output_partitioning().partition_count();
+
+                let left_stream: SendableRecordBatchStream = if left_partitions == 1 {
+                    // Single input partition - use directly
+                    self.left.execute(0, Arc::clone(&context))?
+                } else {
+                    // Multiple input partitions - merge all into single stream
+                    let mut builder = RecordBatchReceiverStream::builder(
+                        self.left.schema(),
+                        left_partitions,
+                    );
+                    for part_i in 0..left_partitions {
+                        builder.run_input(
+                            Arc::clone(&self.left),
+                            part_i,
+                            Arc::clone(&context),
+                        );
+                    }
+                    builder.build()
+                };

+                let reservation =
+                    MemoryConsumer::new(format!("HashJoinInput[{partition}]"))
+                        .register(context.memory_pool());
+
+                // Create partition filter - filter rows to only keep those belonging to this partition
+                let partition_filter = Some(PartitionFilter {
+                    current_partition: partition,
+                    total_partitions: right_partitions,
+                });
+
+                OnceFut::new(collect_left_input(
+                    self.random_state.random_state().clone(),
+                    left_stream,
+                    on_left.clone(),
+                    join_metrics.clone(),
+                    reservation,
+                    need_produce_result_in_final(self.join_type),
+                    1, // Each partition has its own hash table
+                    enable_dynamic_filter_pushdown,
+                    Arc::clone(context.session_config().options()),
+                    self.null_equality,
+                    array_map_created_count,
+                    partition_filter,
                 ))
             }
             PartitionMode::Auto => {
@@ -1514,6 +1584,62 @@ fn should_collect_min_max_for_perfect_hash(
     Ok(ArrayMap::is_supported_type(&data_type))
 }

+/// Partition filter configuration for lazy partitioning.
+///
+/// When set, only rows where `hash(join_keys) % total_partitions == current_partition`
+/// are kept during hash table construction. This avoids the overhead of
+/// pre-partitioning with `RepartitionExec` by filtering rows locally.
+#[derive(Debug, Clone, Copy)]
+struct PartitionFilter {
+    /// The partition index this execution is responsible for
+    current_partition: usize,
+    /// Total number of partitions in the join
+    total_partitions: usize,
+}
+
+/// Filters a record batch to only include rows that belong to the specified partition.
+///
+/// Uses the same hash function and seeds as `RepartitionExec` to ensure that
+/// rows are routed consistently. Rows where `hash(join_keys) % total_partitions == current_partition`
+/// are kept; all others are filtered out.
+///
+/// This function is used in `LazyPartitioned` mode to avoid the overhead of
+/// `RepartitionExec` by filtering rows during hash table construction instead
+/// of pre-partitioning.
+fn filter_batch_by_partition(
+    batch: &RecordBatch,
+    on_left: &[PhysicalExprRef],
+    partition_filter: &PartitionFilter,
+) -> Result<RecordBatch> {
+    let num_rows = batch.num_rows();
+    if num_rows == 0 {
+        return Ok(batch.clone());
+    }
+
+    // Evaluate join key columns
+    let arrays = evaluate_expressions_to_arrays(on_left, batch)?;
+
+    // Compute hashes using the same random state as RepartitionExec
+    let mut hashes_buffer = vec![0u64; num_rows];
+    create_hashes(
+        &arrays,
+        REPARTITION_RANDOM_STATE.random_state(),
+        &mut hashes_buffer,
+    )?;
+
+    // Create a boolean mask for rows belonging to this partition
+    let mask: BooleanArray = hashes_buffer
+        .iter()
+        .map(|hash| {
+            *hash % partition_filter.total_partitions as u64
+                == partition_filter.current_partition as u64
+        })
+        .collect();
+
+    // Filter the batch
+    Ok(filter_record_batch(batch, &mask)?)
+}
+
 /// Collects all batches from the left (build) side stream and creates a hash map for joining.
 ///
 /// This function is responsible for:
@@ -1531,6 +1657,8 @@ fn should_collect_min_max_for_perfect_hash(
 /// * `with_visited_indices_bitmap` - Whether to track visited indices (for outer joins)
 /// * `probe_threads_count` - Number of threads that will probe this hash table
 /// * `should_compute_dynamic_filters` - Whether to compute min/max bounds for dynamic filtering
+/// * `partition_filter` - Optional partition filter for lazy partitioning mode.
+///   When set, only rows belonging to the specified partition are included in the hash table.
 ///
 /// # Dynamic Filter Coordination
 /// When `should_compute_dynamic_filters` is true, this function computes the min/max bounds
@@ -1555,6 +1683,7 @@ async fn collect_left_input(
     config: Arc<ConfigOptions>,
     null_equality: NullEquality,
     array_map_created_count: Count,
+    partition_filter: Option<PartitionFilter>,
 ) -> Result<JoinLeftData> {
     let schema = left_stream.schema();

@@ -1569,28 +1698,46 @@ async fn collect_left_input(
         should_compute_dynamic_filters || should_collect_min_max_for_phj,
     )?;

+    // Clone on_left for use in the closure
+    let on_left_for_filter = on_left.clone();
+
     let state = left_stream
-        .try_fold(initial, |mut state, batch| async move {
-            // Update accumulators if computing bounds
-            if let Some(ref mut accumulators) = state.bounds_accumulators {
-                for accumulator in accumulators {
-                    accumulator.update_batch(&batch)?;
+        .try_fold(initial, |mut state, batch| {
+            let on_left_clone = on_left_for_filter.clone();
+            async move {
+                // Apply partition filter if in lazy partitioning mode
+                let batch = if let Some(ref pf) = partition_filter {
+                    let filtered = filter_batch_by_partition(&batch, &on_left_clone, pf)?;
+                    // Skip empty batches after filtering
+                    if filtered.num_rows() == 0 {
+                        return Ok(state);
+                    }
+                    filtered
+                } else {
+                    batch
+                };
+
+                // Update accumulators if computing bounds
+                if let Some(ref mut accumulators) = state.bounds_accumulators {
+                    for accumulator in accumulators {
+                        accumulator.update_batch(&batch)?;
+                    }
                 }
-            }

-            // Decide if we spill or not
-            let batch_size = get_record_batch_memory_size(&batch);
-            // Reserve memory for incoming batch
-            state.reservation.try_grow(batch_size)?;
-            // Update metrics
-            state.metrics.build_mem_used.add(batch_size);
-            state.metrics.build_input_batches.add(1);
-            state.metrics.build_input_rows.add(batch.num_rows());
-            // Update row count
-            state.num_rows += batch.num_rows();
-            // Push batch to output
-            state.batches.push(batch);
-            Ok(state)
+                // Decide if we spill or not
+                let batch_size = get_record_batch_memory_size(&batch);
+                // Reserve memory for incoming batch
+                state.reservation.try_grow(batch_size)?;
+                // Update metrics
+                state.metrics.build_mem_used.add(batch_size);
+                state.metrics.build_input_batches.add(1);
+                state.metrics.build_input_rows.add(batch.num_rows());
+                // Update row count
+                state.num_rows += batch.num_rows();
+                // Push batch to output
+                state.batches.push(batch);
+                Ok(state)
+            }
         })
         .await?;

@@ -1983,6 +2130,10 @@ mod tests {
                 left,
                 Partitioning::Hash(left_expr, partition_count),
             )?),
+            PartitionMode::LazyPartitioned => {
+                // For LazyPartitioned, the build side is merged and filtered lazily per partition
+                Arc::new(CoalescePartitionsExec::new(left))
+            }
             PartitionMode::Auto => {
                 return internal_err!("Unexpected PartitionMode::Auto in join tests");
             }
@@ -2000,10 +2151,13 @@ mod tests {
                     Partitioning::Hash(partition_expr, partition_count),
                 )?) as _
             }
-            PartitionMode::Partitioned => Arc::new(RepartitionExec::try_new(
-                right,
-                Partitioning::Hash(right_expr, partition_count),
-            )?),
+            PartitionMode::Partitioned | PartitionMode::LazyPartitioned => {
+                // For both Partitioned and LazyPartitioned, probe side is hash partitioned
+                Arc::new(RepartitionExec::try_new(
+                    right,
+                    Partitioning::Hash(right_expr, partition_count),
+                )?)
+            }
             PartitionMode::Auto => {
                 return internal_err!("Unexpected PartitionMode::Auto in join tests");
            }
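The heart of filter_batch_by_partition above is arrow's filter kernel applied to a hash-derived boolean mask. The toy below shows just that mask-and-filter step on a small RecordBatch using the arrow crate alone; it is an illustration only, with the "hash" replaced by a plain modulo on the key value, so unlike the real code it would not route consistently with RepartitionExec (which requires create_hashes with the shared REPARTITION_RANDOM_STATE).

use std::sync::Arc;

use arrow::array::{ArrayRef, BooleanArray, Int32Array};
use arrow::compute::filter_record_batch;
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;

fn main() -> Result<(), ArrowError> {
    // A tiny build-side batch with a single join key column "k".
    let keys = Int32Array::from(vec![10, 11, 12, 13, 14, 15]);
    let batch = RecordBatch::try_from_iter(vec![("k", Arc::new(keys) as ArrayRef)])?;

    // Pretend "hash" = the key value itself; the real code hashes the evaluated
    // join-key arrays so both join sides agree on the routing.
    let total_partitions = 4_i32;
    let current_partition = 2_i32;
    let key_col = batch
        .column(0)
        .as_any()
        .downcast_ref::<Int32Array>()
        .expect("int32 key column");

    // Boolean mask: keep only rows owned by `current_partition` (nulls are dropped).
    let mask: BooleanArray = key_col
        .iter()
        .map(|k| k.map(|k| k.rem_euclid(total_partitions) == current_partition))
        .collect();

    // Filter the batch down to this partition's rows before building the hash table.
    let filtered = filter_record_batch(&batch, &mask)?;
    assert_eq!(filtered.num_rows(), 2); // keys 10 and 14
    Ok(())
}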
