Skip to content

Commit 059a4d6

Browse files
committed
add has_enough_room and unit test
1 parent 7429210 commit 059a4d6

File tree

2 files changed

+177
-5
lines changed

2 files changed

+177
-5
lines changed

native-engine/datafusion-ext-plans/src/joins/smj/full_join.rs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ impl<const L_OUTER: bool, const R_OUTER: bool> FullJoiner<L_OUTER, R_OUTER> {
5656
self.lindices.len() >= self.join_params.batch_size
5757
}
5858

59+
fn has_enough_room(&self, new_size: usize) -> bool {
60+
self.lindices.len() + new_size < self.join_params.batch_size
61+
}
62+
5963
async fn flush(
6064
mut self: Pin<&mut Self>,
6165
cur1: &mut StreamCursor,
@@ -160,12 +164,22 @@ impl<const L_OUTER: bool, const R_OUTER: bool> Joiner for FullJoiner<L_OUTER, R_
160164
continue;
161165
}
162166

163-
for &lidx in &equal_lindices {
164-
for &ridx in &equal_rindices {
167+
let new_size = equal_lindices.len() * equal_rindices.len();
168+
if self.has_enough_room(new_size) {
169+
// all pairs fit below batch_size: push the whole cartesian product without per-pair flush checks
170+
for (&lidx, &ridx) in equal_lindices.iter().cartesian_product(&equal_rindices) {
165171
self.lindices.push(lidx);
166172
self.rindices.push(ridx);
167-
if self.should_flush() {
168-
self.as_mut().flush(cur1, cur2).await?;
173+
}
174+
} else {
175+
// batch_size may be reached part-way through: check for a flush after every pushed pair
176+
for &lidx in &equal_lindices {
177+
for &ridx in &equal_rindices {
178+
self.lindices.push(lidx);
179+
self.rindices.push(ridx);
180+
if self.should_flush() {
181+
self.as_mut().flush(cur1, cur2).await?;
182+
}
169183
}
170184
}
171185
}

native-engine/datafusion-ext-plans/src/joins/test.rs

Lines changed: 159 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ mod tests {
3434
physical_plan::{ExecutionPlan, common, joins::utils::*, test::TestMemoryExec},
3535
prelude::SessionContext,
3636
};
37-
37+
use datafusion::prelude::SessionConfig;
3838
use crate::{
3939
broadcast_join_build_hash_map_exec::BroadcastJoinBuildHashMapExec,
4040
broadcast_join_exec::BroadcastJoinExec,
@@ -264,6 +264,92 @@ mod tests {
264264
Ok((columns, batches))
265265
}
266266

267+
async fn join_collect_with_batch_size(
268+
test_type: TestType,
269+
left: Arc<dyn ExecutionPlan>,
270+
right: Arc<dyn ExecutionPlan>,
271+
on: JoinOn,
272+
join_type: JoinType,
273+
batch_size: usize
274+
) -> Result<(Vec<String>, Vec<RecordBatch>)> {
275+
MemManager::init(1000000);
276+
let session_config = SessionConfig::new().with_batch_size(batch_size);
277+
let session_ctx = SessionContext::new_with_config(session_config);
278+
let session_ctx = SessionContext::new();
279+
let task_ctx = session_ctx.task_ctx();
280+
let schema = build_join_schema_for_test(&left.schema(), &right.schema(), join_type)?;
281+
282+
let join: Arc<dyn ExecutionPlan> = match test_type {
283+
SMJ => {
284+
let sort_options = vec![SortOptions::default(); on.len()];
285+
Arc::new(SortMergeJoinExec::try_new(
286+
schema,
287+
left,
288+
right,
289+
on,
290+
join_type,
291+
sort_options,
292+
)?)
293+
}
294+
BHJLeftProbed => {
295+
let right = Arc::new(BroadcastJoinBuildHashMapExec::new(
296+
right,
297+
on.iter().map(|(_, right_key)| right_key.clone()).collect(),
298+
));
299+
Arc::new(BroadcastJoinExec::try_new(
300+
schema,
301+
left,
302+
right,
303+
on,
304+
join_type,
305+
JoinSide::Right,
306+
true,
307+
None,
308+
)?)
309+
}
310+
BHJRightProbed => {
311+
let left = Arc::new(BroadcastJoinBuildHashMapExec::new(
312+
left,
313+
on.iter().map(|(left_key, _)| left_key.clone()).collect(),
314+
));
315+
Arc::new(BroadcastJoinExec::try_new(
316+
schema,
317+
left,
318+
right,
319+
on,
320+
join_type,
321+
JoinSide::Left,
322+
true,
323+
None,
324+
)?)
325+
}
326+
SHJLeftProbed => Arc::new(BroadcastJoinExec::try_new(
327+
schema,
328+
left,
329+
right,
330+
on,
331+
join_type,
332+
JoinSide::Right,
333+
false,
334+
None,
335+
)?),
336+
SHJRightProbed => Arc::new(BroadcastJoinExec::try_new(
337+
schema,
338+
left,
339+
right,
340+
on,
341+
join_type,
342+
JoinSide::Left,
343+
false,
344+
None,
345+
)?),
346+
};
347+
let columns = columns(&join.schema());
348+
let stream = join.execute(0, task_ctx)?;
349+
let batches = common::collect(stream).await?;
350+
Ok((columns, batches))
351+
}
352+
267353
const ALL_TEST_TYPE: [TestType; 5] = [
268354
SMJ,
269355
BHJLeftProbed,
@@ -428,6 +514,78 @@ mod tests {
428514
Ok(())
429515
}
430516

517+
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
518+
async fn join_inner_batchsize() -> Result<()> {
519+
for test_type in ALL_TEST_TYPE {
520+
let left = build_table(
521+
("a1", &vec![1, 1, 1, 1, 1]),
522+
("b1", &vec![1, 2, 3, 4, 5]),
523+
("c1", &vec![1, 2, 3, 4, 5]),
524+
);
525+
let right = build_table(
526+
("a2", &vec![1, 1, 1, 1, 1, 1, 1]),
527+
("b2", &vec![1, 2, 3, 4, 5, 6, 7]),
528+
("c2", &vec![1, 2, 3, 4, 5, 6, 7]),
529+
);
530+
let on: JoinOn = vec![(
531+
Arc::new(Column::new_with_schema("a1", &left.schema())?),
532+
Arc::new(Column::new_with_schema("a2", &right.schema())?),
533+
)];
534+
let expected = vec![
535+
"+----+----+----+----+----+----+",
536+
"| a1 | b1 | c1 | a2 | b2 | c2 |",
537+
"+----+----+----+----+----+----+",
538+
"| 1 | 1 | 1 | 1 | 1 | 1 |",
539+
"| 1 | 1 | 1 | 1 | 2 | 2 |",
540+
"| 1 | 1 | 1 | 1 | 3 | 3 |",
541+
"| 1 | 1 | 1 | 1 | 4 | 4 |",
542+
"| 1 | 1 | 1 | 1 | 5 | 5 |",
543+
"| 1 | 1 | 1 | 1 | 6 | 6 |",
544+
"| 1 | 1 | 1 | 1 | 7 | 7 |",
545+
"| 1 | 2 | 2 | 1 | 1 | 1 |",
546+
"| 1 | 2 | 2 | 1 | 2 | 2 |",
547+
"| 1 | 2 | 2 | 1 | 3 | 3 |",
548+
"| 1 | 2 | 2 | 1 | 4 | 4 |",
549+
"| 1 | 2 | 2 | 1 | 5 | 5 |",
550+
"| 1 | 2 | 2 | 1 | 6 | 6 |",
551+
"| 1 | 2 | 2 | 1 | 7 | 7 |",
552+
"| 1 | 3 | 3 | 1 | 1 | 1 |",
553+
"| 1 | 3 | 3 | 1 | 2 | 2 |",
554+
"| 1 | 3 | 3 | 1 | 3 | 3 |",
555+
"| 1 | 3 | 3 | 1 | 4 | 4 |",
556+
"| 1 | 3 | 3 | 1 | 5 | 5 |",
557+
"| 1 | 3 | 3 | 1 | 6 | 6 |",
558+
"| 1 | 3 | 3 | 1 | 7 | 7 |",
559+
"| 1 | 4 | 4 | 1 | 1 | 1 |",
560+
"| 1 | 4 | 4 | 1 | 2 | 2 |",
561+
"| 1 | 4 | 4 | 1 | 3 | 3 |",
562+
"| 1 | 4 | 4 | 1 | 4 | 4 |",
563+
"| 1 | 4 | 4 | 1 | 5 | 5 |",
564+
"| 1 | 4 | 4 | 1 | 6 | 6 |",
565+
"| 1 | 4 | 4 | 1 | 7 | 7 |",
566+
"| 1 | 5 | 5 | 1 | 1 | 1 |",
567+
"| 1 | 5 | 5 | 1 | 2 | 2 |",
568+
"| 1 | 5 | 5 | 1 | 3 | 3 |",
569+
"| 1 | 5 | 5 | 1 | 4 | 4 |",
570+
"| 1 | 5 | 5 | 1 | 5 | 5 |",
571+
"| 1 | 5 | 5 | 1 | 6 | 6 |",
572+
"| 1 | 5 | 5 | 1 | 7 | 7 |",
573+
"+----+----+----+----+----+----+",
574+
];
575+
let (_, batches) = join_collect_with_batch_size(test_type, left.clone(), right.clone(), on.clone(), Inner, 2).await?;
576+
assert_batches_sorted_eq!(expected, &batches);
577+
let (_, batches) = join_collect_with_batch_size(test_type, left.clone(), right.clone(), on.clone(), Inner, 3).await?;
578+
assert_batches_sorted_eq!(expected, &batches);
579+
let (_, batches) = join_collect_with_batch_size(test_type, left.clone(), right.clone(), on.clone(), Inner, 4).await?;
580+
assert_batches_sorted_eq!(expected, &batches);
581+
let (_, batches) = join_collect_with_batch_size(test_type, left.clone(), right.clone(), on.clone(), Inner, 5).await?;
582+
assert_batches_sorted_eq!(expected, &batches);
583+
let (_, batches) = join_collect_with_batch_size(test_type, left.clone(), right.clone(), on.clone(), Inner, 7).await?;
584+
assert_batches_sorted_eq!(expected, &batches);
585+
}
586+
Ok(())
587+
}
588+
431589
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
432590
async fn join_left_one() -> Result<()> {
433591
for test_type in ALL_TEST_TYPE {

0 commit comments

Comments
 (0)