Skip to content

Commit 331242e

Browse files
authored
fix(cubestore): Panic when nested Union All (#6010)
1 parent b90b3f2 commit 331242e

File tree

3 files changed

+149
-1
lines changed

3 files changed

+149
-1
lines changed
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
use datafusion::error::DataFusionError;
2+
use datafusion::execution::context::ExecutionProps;
3+
use datafusion::logical_plan::{DFSchema, LogicalPlan};
4+
use datafusion::optimizer::optimizer::OptimizerRule;
5+
use datafusion::optimizer::utils;
6+
use std::sync::Arc;
7+
8+
/// Logical-plan optimizer rule that flattens nested `Union` nodes.
///
/// When the inputs of a `LogicalPlan::Union` are themselves `Union` nodes with
/// the same schema as the parent, their children are moved up into the parent,
/// so the plan ends up with a single `Union` over all the leaf inputs.
pub struct FlattenUnion;
9+
impl OptimizerRule for FlattenUnion {
10+
fn optimize(
11+
&self,
12+
plan: &LogicalPlan,
13+
execution_props: &ExecutionProps,
14+
) -> Result<LogicalPlan, DataFusionError> {
15+
match plan {
16+
LogicalPlan::Union { inputs, schema, .. } => {
17+
let new_inputs = inputs
18+
.iter()
19+
.map(|p| self.optimize(p, execution_props))
20+
.collect::<Result<Vec<_>, _>>()?;
21+
22+
let result_inputs = try_remove_sub_union(&new_inputs, schema.clone());
23+
24+
let expr = plan.expressions().clone();
25+
26+
utils::from_plan(plan, &expr, &result_inputs)
27+
}
28+
// Rest: recurse into plan, apply optimization where possible
29+
LogicalPlan::Filter { .. }
30+
| LogicalPlan::Projection { .. }
31+
| LogicalPlan::Window { .. }
32+
| LogicalPlan::Aggregate { .. }
33+
| LogicalPlan::Repartition { .. }
34+
| LogicalPlan::CreateExternalTable { .. }
35+
| LogicalPlan::Extension { .. }
36+
| LogicalPlan::Sort { .. }
37+
| LogicalPlan::Explain { .. }
38+
| LogicalPlan::Limit { .. }
39+
| LogicalPlan::Skip { .. }
40+
| LogicalPlan::Join { .. }
41+
| LogicalPlan::CrossJoin { .. } => {
42+
// apply the optimization to all inputs of the plan
43+
let inputs = plan.inputs();
44+
let new_inputs = inputs
45+
.iter()
46+
.map(|p| self.optimize(p, execution_props))
47+
.collect::<Result<Vec<_>, _>>()?;
48+
49+
let expr = plan.expressions().clone();
50+
51+
utils::from_plan(plan, &expr, &new_inputs)
52+
}
53+
LogicalPlan::TableScan { .. } | LogicalPlan::EmptyRelation { .. } => Ok(plan.clone()),
54+
}
55+
}
56+
57+
fn name(&self) -> &str {
58+
"flatten_union"
59+
}
60+
}
61+
62+
fn try_remove_sub_union(
63+
parent_inputs: &Vec<LogicalPlan>,
64+
parent_schema: Arc<DFSchema>,
65+
) -> Vec<LogicalPlan> {
66+
let mut may_be_result = Vec::new();
67+
for inp in parent_inputs.iter() {
68+
match inp {
69+
LogicalPlan::Union { inputs, schema, .. } => {
70+
if *schema == *&parent_schema {
71+
may_be_result.extend(inputs.iter().cloned());
72+
} else {
73+
return parent_inputs.clone();
74+
}
75+
}
76+
_ => {
77+
return parent_inputs.clone();
78+
}
79+
}
80+
}
81+
return may_be_result;
82+
}

rust/cubestore/cubestore/src/queryplanner/mod.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ mod topk;
1111
pub use topk::MIN_TOPK_STREAM_ROWS;
1212
mod coalesce;
1313
mod filter_by_key_range;
14+
mod flatten_union;
1415
pub mod info_schema;
1516
pub mod now;
1617
pub mod udfs;
@@ -21,6 +22,7 @@ use crate::config::ConfigObj;
2122
use crate::metastore::multi_index::MultiPartition;
2223
use crate::metastore::table::{Table, TablePath};
2324
use crate::metastore::{IdRow, MetaStore};
25+
use crate::queryplanner::flatten_union::FlattenUnion;
2426
use crate::queryplanner::info_schema::{
2527
SchemataInfoSchemaTableDef, SystemCacheTableDef, SystemChunksTableDef, SystemIndexesTableDef,
2628
SystemJobsTableDef, SystemPartitionsTableDef, SystemReplayHandlesTableDef,
@@ -171,7 +173,9 @@ impl QueryPlannerImpl {
171173
impl QueryPlannerImpl {
172174
async fn execution_context(&self) -> Result<Arc<ExecutionContext>, CubeError> {
173175
Ok(Arc::new(ExecutionContext::with_config(
174-
ExecutionConfig::new().add_optimizer_rule(Arc::new(MaterializeNow {})),
176+
ExecutionConfig::new()
177+
.add_optimizer_rule(Arc::new(MaterializeNow {}))
178+
.add_optimizer_rule(Arc::new(FlattenUnion {})),
175179
)))
176180
}
177181
}

rust/cubestore/cubestore/src/sql/mod.rs

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2293,6 +2293,68 @@ mod tests {
22932293
}).await;
22942294
}
22952295

2296+
// Checks that nested `UNION ALL` subqueries are planned as a single flat Union.
//
// Creates schema `foo` with tables `a`, `b`, `a1` and `b1`, inserts one row into
// each, then runs EXPLAIN on an aggregate over an outer `UNION ALL` of two
// `UNION ALL` subqueries and compares the plan text with the expected string.
#[tokio::test]
2297+
async fn flatten_union() {
2298+
Config::test("flatten_union").start_test(async move |services| {
2299+
let service = services.sql_service;
2300+
2301+
let _ = service.exec_query("CREATE SCHEMA foo").await.unwrap();
2302+
2303+
let _ = service.exec_query("CREATE TABLE foo.a (a int, b int, c int)").await.unwrap();
2304+
let _ = service.exec_query("CREATE TABLE foo.b (a int, b int, c int)").await.unwrap();
2305+
2306+
let _ = service.exec_query("CREATE TABLE foo.a1 (a int, b int, c int)").await.unwrap();
2307+
let _ = service.exec_query("CREATE TABLE foo.b1 (a int, b int, c int)").await.unwrap();
2308+
2309+
service.exec_query(
2310+
"INSERT INTO foo.a (a, b, c) VALUES (1, 1, 1)"
2311+
).await.unwrap();
2312+
service.exec_query(
2313+
"INSERT INTO foo.b (a, b, c) VALUES (2, 2, 1)"
2314+
).await.unwrap();
2315+
service.exec_query(
2316+
"INSERT INTO foo.a1 (a, b, c) VALUES (1, 1, 2)"
2317+
).await.unwrap();
2318+
service.exec_query(
2319+
"INSERT INTO foo.b1 (a, b, c) VALUES (2, 2, 2)"
2320+
).await.unwrap();
2321+
2322+
// The first subquery unions foo.a and foo.b; the second unions foo.a1,
// foo.b1 and foo.b again, so foo.b is scanned twice.
let result = service.exec_query("EXPLAIN SELECT a, b, sum(c) from ( \
2323+
select * from ( \
2324+
select * from foo.a \
2325+
union all \
2326+
select * from foo.b \
2327+
) \
2328+
union all
2329+
select * from
2330+
( \
2331+
select * from foo.a1 \
2332+
union all \
2333+
select * from foo.b1 \
2334+
union all \
2335+
select * from foo.b \
2336+
) \
2337+
) group by 1, 2").await.unwrap();
2338+
// The expected plan has exactly one Union node listing all five scans.
match &result.get_rows()[0].values()[0] {
2339+
TableValue::String(s) => {
2340+
assert_eq!(s,
2341+
"Projection, [a, b, SUM(c)]\
2342+
\n Aggregate\
2343+
\n ClusterSend, indices: [[1, 2, 3, 4, 2]]\
2344+
\n Union\
2345+
\n Scan foo.a, source: CubeTable(index: default:1:[1]:sort_on[a, b]), fields: *\
2346+
\n Scan foo.b, source: CubeTable(index: default:2:[2]:sort_on[a, b]), fields: *\
2347+
\n Scan foo.a1, source: CubeTable(index: default:3:[3]:sort_on[a, b]), fields: *\
2348+
\n Scan foo.b1, source: CubeTable(index: default:4:[4]:sort_on[a, b]), fields: *\
2349+
\n Scan foo.b, source: CubeTable(index: default:2:[2]:sort_on[a, b]), fields: *"
2350+
2351+
);
2352+
}
2353+
// Any value other than a string in the first EXPLAIN cell fails the test.
_ => assert!(false),
2354+
};
2355+
}).await;
2356+
}
2357+
22962358
#[tokio::test]
22972359
async fn over_10k_join() {
22982360
Config::test("over_10k_join").update_config(|mut c| {

0 commit comments

Comments
 (0)