Skip to content

Commit 7b81e3b

Browse files
fix: handle case where join keys are different for sort-merge multi-partition join (#6243)
## Changes Made the current sort merge join (multi partition) implementation does not correctly handle the case where the join keys in the left and right dataframes are different. this PR fixes this issue by doing the following: - aliasing the right keys when generating the samples for determining boundaries - renames materialized `boundaries` with right keys when applying the boundaries to create range partition tasks - regression test added to ensure fix works --------- Co-authored-by: gmweaver <gmweaver.usc@gmail.com>
1 parent 37b352a commit 7b81e3b

File tree

2 files changed

+78
-6
lines changed

2 files changed

+78
-6
lines changed

src/daft-distributed/src/pipeline_node/join/sort_merge_join.rs

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use common_metrics::ops::{NodeCategory, NodeType};
55
use daft_dsl::expr::bound_expr::BoundExpr;
66
use daft_local_plan::{LocalNodeContext, LocalPhysicalPlan};
77
use daft_logical_plan::{JoinType, stats::StatsState};
8+
use daft_recordbatch::RecordBatch;
89
use daft_schema::schema::SchemaRef;
910
use futures::{TryStreamExt, future::try_join_all};
1011

@@ -174,18 +175,35 @@ impl SortMergeJoinNode {
174175
scheduler_handle,
175176
)?;
176177

178+
let left_boundary_key_names = self
179+
.left_on
180+
.iter()
181+
.map(|expr| {
182+
expr.inner()
183+
.to_field(&self.left.config().schema)
184+
.map(|f| f.name)
185+
})
186+
.collect::<DaftResult<Vec<_>>>()?;
187+
188+
let right_sample_by_aliased = self
189+
.right_on
190+
.iter()
191+
.zip(left_boundary_key_names.into_iter())
192+
.map(|(expr, key_name)| BoundExpr::new_unchecked(expr.inner().alias(key_name)))
193+
.collect::<Vec<_>>();
194+
177195
// Sample right side
178196
let right_sample_tasks = create_sample_tasks(
179197
right_materialized.clone(),
180198
self.right.config().schema.clone(),
181-
self.right_on.clone(),
199+
right_sample_by_aliased,
182200
self.as_ref(),
183201
task_id_counter,
184202
scheduler_handle,
185203
)?;
186204

187205
// Collect all samples
188-
let sampled_outputs = try_join_all(
206+
let combined_sampled_outputs = try_join_all(
189207
left_sample_tasks
190208
.into_iter()
191209
.chain(right_sample_tasks.into_iter()),
@@ -196,8 +214,8 @@ impl SortMergeJoinNode {
196214
.collect::<Vec<_>>();
197215

198216
// Compute partition boundaries from combined samples
199-
let boundaries = get_partition_boundaries_from_samples(
200-
sampled_outputs,
217+
let left_partition_boundaries = get_partition_boundaries_from_samples(
218+
combined_sampled_outputs,
201219
&self.left_on,
202220
descending.clone(),
203221
nulls_first,
@@ -212,21 +230,40 @@ impl SortMergeJoinNode {
212230
left_schema,
213231
self.left_on.clone(),
214232
descending.clone(),
215-
boundaries.clone(),
233+
left_partition_boundaries.clone(),
216234
num_partitions,
217235
self.as_ref(),
218236
task_id_counter,
219237
scheduler_handle,
220238
)?;
221239

240+
let right_boundary_names = self
241+
.right_on
242+
.iter()
243+
.map(|expr| {
244+
expr.inner()
245+
.to_field(&self.right.config().schema)
246+
.map(|f| f.name)
247+
})
248+
.collect::<DaftResult<Vec<_>>>()?;
249+
250+
let right_partition_boundaries = RecordBatch::from_nonempty_columns(
251+
left_partition_boundaries
252+
.columns()
253+
.iter()
254+
.zip(right_boundary_names)
255+
.map(|(series, name)| series.clone().rename(name))
256+
.collect::<Vec<_>>(),
257+
)?;
258+
222259
// Range repartition right side
223260
let right_schema = self.right.config().schema.clone();
224261
let right_partition_tasks = create_range_repartition_tasks(
225262
right_materialized,
226263
right_schema,
227264
self.right_on.clone(),
228265
descending,
229-
boundaries,
266+
right_partition_boundaries,
230267
num_partitions,
231268
self.as_ref(),
232269
task_id_counter,

tests/dataframe/test_joins.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1273,6 +1273,41 @@ def test_sort_merge_join_small_partitions(make_df, with_default_morsel_size):
12731273
assert pd["rv"] == [200, 300]
12741274

12751275

1276+
@pytest.mark.parametrize("left_partitions,right_partitions", [(2, 2), (4, 3), (8, 4)])
1277+
def test_sort_merge_join_different_left_right_keys(
1278+
left_partitions, right_partitions, make_df, with_default_morsel_size
1279+
):
1280+
if get_tests_daft_runner_name() == "native":
1281+
pytest.skip("Sort-merge joins are not supported on native runner")
1282+
1283+
left = make_df(
1284+
{"left_k": [1, 2, 3], "lv": [10, 20, 30]},
1285+
repartition=left_partitions,
1286+
repartition_columns=["left_k"],
1287+
)
1288+
right = make_df(
1289+
{"right_k": [2, 3, 4], "rv": [200, 300, 400]},
1290+
repartition=right_partitions,
1291+
repartition_columns=["right_k"],
1292+
)
1293+
1294+
out = left.join(
1295+
right,
1296+
left_on="left_k",
1297+
right_on="right_k",
1298+
how="inner",
1299+
strategy="sort_merge",
1300+
).sort("left_k")
1301+
1302+
pd = out.to_pydict()
1303+
assert pd == {
1304+
"left_k": [2, 3],
1305+
"lv": [20, 30],
1306+
"right_k": [2, 3],
1307+
"rv": [200, 300],
1308+
}
1309+
1310+
12761311
@pytest.mark.parametrize(
12771312
"suffix,prefix,expected",
12781313
[

0 commit comments

Comments
 (0)