Skip to content

Commit a08d981

Browse files
committed
chore(cubestore): Upgrade DF: Push down Sort,fetch at logical plan phase
1 parent 3d6fcb5 commit a08d981

File tree

2 files changed

+106
-63
lines changed

2 files changed

+106
-63
lines changed

rust/cubestore/cubestore-sql-tests/src/tests.rs

Lines changed: 80 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -8325,12 +8325,13 @@ async fn build_range_end(service: Box<dyn SqlClient>) {
83258325
]
83268326
);
83278327
}
8328-
async fn assert_limit_pushdown(
8328+
8329+
async fn assert_limit_pushdown_using_search_string(
83298330
service: &Box<dyn SqlClient>,
83308331
query: &str,
83318332
expected_index: Option<&str>,
83328333
is_limit_expected: bool,
8333-
is_tail_limit: bool,
8334+
search_string: &str,
83348335
) -> Result<Vec<Row>, String> {
83358336
let res = service
83368337
.exec_query(&format!("EXPLAIN ANALYZE {}", query))
@@ -8347,11 +8348,7 @@ async fn assert_limit_pushdown(
83478348
));
83488349
}
83498350
}
8350-
let expected_limit = if is_tail_limit {
8351-
"TailLimit"
8352-
} else {
8353-
"GlobalLimit"
8354-
};
8351+
let expected_limit = search_string;
83558352
if is_limit_expected {
83568353
if s.find(expected_limit).is_none() {
83578354
return Err(format!("{} expected but not found", expected_limit));
@@ -8369,6 +8366,27 @@ async fn assert_limit_pushdown(
83698366
Ok(res.get_rows().clone())
83708367
}
83718368

8369+
async fn assert_limit_pushdown(
8370+
service: &Box<dyn SqlClient>,
8371+
query: &str,
8372+
expected_index: Option<&str>,
8373+
is_limit_expected: bool,
8374+
is_tail_limit: bool,
8375+
) -> Result<Vec<Row>, String> {
8376+
assert_limit_pushdown_using_search_string(
8377+
service,
8378+
query,
8379+
expected_index,
8380+
is_limit_expected,
8381+
if is_tail_limit {
8382+
"TailLimit"
8383+
} else {
8384+
"GlobalLimit"
8385+
},
8386+
)
8387+
.await
8388+
}
8389+
83728390
async fn cache_incr(service: Box<dyn SqlClient>) {
83738391
service.note_non_idempotent_migration_test();
83748392
let query = r#"CACHE INCR "prefix:key""#;
@@ -9393,7 +9411,7 @@ async fn limit_pushdown_without_group(service: Box<dyn SqlClient>) {
93939411
.await
93949412
.unwrap();
93959413
// ====================================
9396-
let res = assert_limit_pushdown(
9414+
let res = assert_limit_pushdown_using_search_string(
93979415
&service,
93989416
"SELECT a aaa, b bbbb, c FROM (
93999417
SELECT * FROM foo.pushdown_where_group1
@@ -9404,39 +9422,46 @@ async fn limit_pushdown_without_group(service: Box<dyn SqlClient>) {
94049422
ORDER BY 2 LIMIT 4",
94059423
Some("ind1"),
94069424
true,
9407-
false,
9425+
"Sort, fetch: 4",
94089426
)
94099427
.await
94109428
.unwrap();
94119429

9412-
assert_eq!(
9413-
res,
9414-
vec![
9415-
Row::new(vec![
9416-
TableValue::Int(12),
9417-
TableValue::Int(20),
9418-
TableValue::Int(4)
9419-
]),
9420-
Row::new(vec![
9421-
TableValue::Int(12),
9422-
TableValue::Int(25),
9423-
TableValue::Int(5)
9424-
]),
9425-
Row::new(vec![
9426-
TableValue::Int(12),
9427-
TableValue::Int(25),
9428-
TableValue::Int(6)
9429-
]),
9430-
Row::new(vec![
9431-
TableValue::Int(12),
9432-
TableValue::Int(30),
9433-
TableValue::Int(7)
9434-
]),
9435-
]
9436-
);
9430+
let mut expected = vec![
9431+
Row::new(vec![
9432+
TableValue::Int(12),
9433+
TableValue::Int(20),
9434+
TableValue::Int(4),
9435+
]),
9436+
Row::new(vec![
9437+
TableValue::Int(12),
9438+
TableValue::Int(25),
9439+
TableValue::Int(5),
9440+
]),
9441+
Row::new(vec![
9442+
TableValue::Int(12),
9443+
TableValue::Int(25),
9444+
TableValue::Int(6),
9445+
]),
9446+
Row::new(vec![
9447+
TableValue::Int(12),
9448+
TableValue::Int(30),
9449+
TableValue::Int(7),
9450+
]),
9451+
];
9452+
if res != expected {
9453+
// Given the query, there are two valid orderings -- (12, 25, 5) and (12, 25, 6) can be swapped.
9454+
9455+
let mut values1 = expected[1].values().clone();
9456+
let mut values2 = expected[2].values().clone();
9457+
std::mem::swap(&mut values1[2], &mut values2[2]);
9458+
expected[1] = Row::new(values1);
9459+
expected[2] = Row::new(values2);
9460+
assert_eq!(res, expected);
9461+
}
94379462

94389463
// ====================================
9439-
let res = assert_limit_pushdown(
9464+
let res = assert_limit_pushdown_using_search_string(
94409465
&service,
94419466
"SELECT a, b, c FROM (
94429467
SELECT * FROM foo.pushdown_where_group1
@@ -9446,7 +9471,7 @@ async fn limit_pushdown_without_group(service: Box<dyn SqlClient>) {
94469471
ORDER BY 3 LIMIT 3",
94479472
Some("ind2"),
94489473
true,
9449-
false,
9474+
"Sort, fetch: 3",
94509475
)
94519476
.await
94529477
.unwrap();
@@ -9473,7 +9498,7 @@ async fn limit_pushdown_without_group(service: Box<dyn SqlClient>) {
94739498
);
94749499
//
94759500
// ====================================
9476-
let res = assert_limit_pushdown(
9501+
let res = assert_limit_pushdown_using_search_string(
94779502
&service,
94789503
"SELECT a, b, c FROM (
94799504
SELECT * FROM foo.pushdown_where_group1
@@ -9483,7 +9508,7 @@ async fn limit_pushdown_without_group(service: Box<dyn SqlClient>) {
94839508
ORDER BY 3 DESC LIMIT 3",
94849509
Some("ind2"),
94859510
true,
9486-
true,
9511+
"Sort, fetch: 3",
94879512
)
94889513
.await
94899514
.unwrap();
@@ -9510,7 +9535,7 @@ async fn limit_pushdown_without_group(service: Box<dyn SqlClient>) {
95109535
);
95119536
//
95129537
// ====================================
9513-
let res = assert_limit_pushdown(
9538+
let res = assert_limit_pushdown_using_search_string(
95149539
&service,
95159540
"SELECT a, b FROM (SELECT a, b, c FROM (
95169541
SELECT * FROM foo.pushdown_where_group1
@@ -9520,7 +9545,7 @@ async fn limit_pushdown_without_group(service: Box<dyn SqlClient>) {
95209545
ORDER BY 1, 2 LIMIT 3) x",
95219546
Some("ind1"),
95229547
true,
9523-
false,
9548+
"Sort, fetch: 3",
95249549
)
95259550
.await
95269551
.unwrap();
@@ -9546,7 +9571,7 @@ async fn limit_pushdown_without_group(service: Box<dyn SqlClient>) {
95469571
]
95479572
);
95489573
// ====================================
9549-
let res = assert_limit_pushdown(
9574+
let res = assert_limit_pushdown_using_search_string(
95509575
&service,
95519576
"SELECT a, b FROM (SELECT a, b, c FROM (
95529577
SELECT * FROM foo.pushdown_where_group1
@@ -9556,7 +9581,7 @@ async fn limit_pushdown_without_group(service: Box<dyn SqlClient>) {
95569581
ORDER BY 1, 2 LIMIT 2 OFFSET 1) x",
95579582
Some("ind1"),
95589583
true,
9559-
false,
9584+
"Sort, fetch: 3",
95609585
)
95619586
.await
95629587
.unwrap();
@@ -9577,7 +9602,7 @@ async fn limit_pushdown_without_group(service: Box<dyn SqlClient>) {
95779602
]
95789603
);
95799604
// ====================================
9580-
let res = assert_limit_pushdown(
9605+
let res = assert_limit_pushdown_using_search_string(
95819606
&service,
95829607
"SELECT a, b, c FROM (
95839608
SELECT * FROM foo.pushdown_where_group1
@@ -9588,7 +9613,7 @@ async fn limit_pushdown_without_group(service: Box<dyn SqlClient>) {
95889613
ORDER BY 1 LIMIT 3",
95899614
Some("ind1"),
95909615
true,
9591-
false,
9616+
"Sort, fetch: 3",
95929617
)
95939618
.await
95949619
.unwrap();
@@ -9609,7 +9634,7 @@ async fn limit_pushdown_without_group(service: Box<dyn SqlClient>) {
96099634
]
96109635
);
96119636
// ====================================
9612-
let res = assert_limit_pushdown(
9637+
let res = assert_limit_pushdown_using_search_string(
96139638
&service,
96149639
"SELECT a, b, c FROM (
96159640
SELECT * FROM foo.pushdown_where_group1
@@ -9620,7 +9645,7 @@ async fn limit_pushdown_without_group(service: Box<dyn SqlClient>) {
96209645
ORDER BY 1, 3 LIMIT 3",
96219646
Some("ind1"),
96229647
true,
9623-
false,
9648+
"Sort, fetch: 3",
96249649
)
96259650
.await
96269651
.unwrap();
@@ -9683,7 +9708,7 @@ async fn limit_pushdown_without_group_resort(service: Box<dyn SqlClient>) {
96839708
.await
96849709
.unwrap();
96859710
// ====================================
9686-
let res = assert_limit_pushdown(
9711+
let res = assert_limit_pushdown_using_search_string(
96879712
&service,
96889713
"SELECT a aaa, b bbbb, c FROM (
96899714
SELECT * FROM foo.pushdown_where_group1
@@ -9694,7 +9719,7 @@ async fn limit_pushdown_without_group_resort(service: Box<dyn SqlClient>) {
96949719
ORDER BY 2 desc LIMIT 4",
96959720
Some("ind1"),
96969721
true,
9697-
true,
9722+
"Sort, fetch: 4",
96989723
)
96999724
.await
97009725
.unwrap();
@@ -9726,7 +9751,7 @@ async fn limit_pushdown_without_group_resort(service: Box<dyn SqlClient>) {
97269751
);
97279752

97289753
// ====================================
9729-
let res = assert_limit_pushdown(
9754+
let res = assert_limit_pushdown_using_search_string(
97309755
&service,
97319756
"SELECT a aaa, b bbbb, c FROM (
97329757
SELECT * FROM foo.pushdown_where_group1
@@ -9736,7 +9761,7 @@ async fn limit_pushdown_without_group_resort(service: Box<dyn SqlClient>) {
97369761
ORDER BY 1 desc, 2 desc LIMIT 3",
97379762
Some("ind1"),
97389763
true,
9739-
true,
9764+
"Sort, fetch: 3",
97409765
)
97419766
.await
97429767
.unwrap();
@@ -9836,7 +9861,7 @@ async fn limit_pushdown_unique_key(service: Box<dyn SqlClient>) {
98369861
.await
98379862
.unwrap();
98389863
// ====================================
9839-
let res = assert_limit_pushdown(
9864+
let res = assert_limit_pushdown_using_search_string(
98409865
&service,
98419866
"SELECT a, b, c FROM (
98429867
SELECT * FROM foo.pushdown_where_group1
@@ -9847,7 +9872,7 @@ async fn limit_pushdown_unique_key(service: Box<dyn SqlClient>) {
98479872
ORDER BY 2 LIMIT 4",
98489873
Some("ind1"),
98499874
true,
9850-
false,
9875+
"Sort, fetch: 4",
98519876
)
98529877
.await
98539878
.unwrap();
@@ -9874,7 +9899,7 @@ async fn limit_pushdown_unique_key(service: Box<dyn SqlClient>) {
98749899
);
98759900

98769901
// ====================================
9877-
let res = assert_limit_pushdown(
9902+
let res = assert_limit_pushdown_using_search_string(
98789903
&service,
98799904
"SELECT a, b, c FROM (
98809905
SELECT * FROM foo.pushdown_where_group1
@@ -9883,8 +9908,8 @@ async fn limit_pushdown_unique_key(service: Box<dyn SqlClient>) {
98839908
) as `tb`
98849909
ORDER BY 3 LIMIT 3",
98859910
Some("ind1"),
9886-
false,
9887-
false,
9911+
true,
9912+
"Sort, fetch: 3",
98889913
)
98899914
.await
98909915
.unwrap();

rust/cubestore/cubestore/src/queryplanner/planning.rs

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1597,20 +1597,38 @@ fn pull_up_cluster_send(mut p: LogicalPlan) -> Result<LogicalPlan, DataFusionErr
15971597
LogicalPlan::Extension { .. } => return Ok(p),
15981598
// These nodes collect results from multiple partitions, return unchanged.
15991599
LogicalPlan::Aggregate { .. }
1600-
| LogicalPlan::Sort { .. }
1601-
| LogicalPlan::Limit { .. }
1602-
| LogicalPlan::Repartition { .. } => return Ok(p),
1600+
| LogicalPlan::Repartition { .. }
1601+
| LogicalPlan::Limit { .. } => return Ok(p),
1602+
// Collects results but let's push sort,fetch underneath the input.
1603+
LogicalPlan::Sort(Sort { expr, input, fetch }) => {
1604+
let Some(send) = try_extract_cluster_send(input) else {
1605+
return Ok(p);
1606+
};
1607+
let Some(fetch) = fetch else {
1608+
return Ok(p);
1609+
};
1610+
let id = send.id;
1611+
snapshots = send.snapshots.clone();
1612+
let under_sort = LogicalPlan::Sort(Sort {
1613+
expr: expr.clone(),
1614+
input: send.input.clone(),
1615+
fetch: Some(*fetch),
1616+
});
1617+
// We discard limit_and_reverse, because we add a Sort node into the plan right here.
1618+
let limit_and_reverse = None;
1619+
let new_send =
1620+
ClusterSendNode::new(id, Arc::new(under_sort), snapshots, limit_and_reverse);
1621+
*input = Arc::new(new_send.into_plan());
1622+
return Ok(p);
1623+
}
16031624
// We can always pull cluster send for these nodes.
16041625
LogicalPlan::Projection(Projection { input, .. })
16051626
| LogicalPlan::Filter(Filter { input, .. })
16061627
| LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. })
16071628
| LogicalPlan::Unnest(Unnest { input, .. }) => {
1608-
let send;
1609-
if let Some(s) = try_extract_cluster_send(input) {
1610-
send = s;
1611-
} else {
1629+
let Some(send) = try_extract_cluster_send(input) else {
16121630
return Ok(p);
1613-
}
1631+
};
16141632
let id = send.id;
16151633
snapshots = send.snapshots.clone();
16161634
let limit = send.limit_and_reverse.clone();

0 commit comments

Comments
 (0)