diff --git a/rust/cubesql/cubesql/src/compile/query_engine.rs b/rust/cubesql/cubesql/src/compile/query_engine.rs index c3c9fbf6b9723..4de9b441e9a58 100644 --- a/rust/cubesql/cubesql/src/compile/query_engine.rs +++ b/rust/cubesql/cubesql/src/compile/query_engine.rs @@ -10,7 +10,7 @@ use crate::{ df::{ optimizers::{FilterPushDown, FilterSplitMeta, LimitPushDown, SortPushDown}, scan::CubeScanNode, - wrapper::CubeScanWrapperNode, + wrapper::{CubeScanWrappedSqlNode, CubeScanWrapperNode}, }, udf::*, CubeContext, VariablesProvider, @@ -394,15 +394,20 @@ impl QueryEngine for SqlQueryEngine { state.get_load_request_meta("sql"), self.config_ref().clone(), )); - let mut ctx = DFSessionContext::with_state( - default_session_builder( - DFSessionConfig::new() - .create_default_catalog_and_schema(false) - .with_information_schema(false) - .with_default_catalog_and_schema("db", "public"), - ) - .with_query_planner(query_planner), - ); + let mut df_state = default_session_builder( + DFSessionConfig::new() + .create_default_catalog_and_schema(false) + .with_information_schema(false) + .with_default_catalog_and_schema("db", "public"), + ) + .with_query_planner(query_planner); + df_state + .optimizer + .rules + // projection_push_down is broken even for non-OLAP queries + // TODO enable it back + .retain(|r| r.name() != "projection_push_down"); + let mut ctx = DFSessionContext::with_state(df_state); if state.protocol == DatabaseProtocol::MySQL { let system_variable_provider = @@ -580,7 +585,11 @@ fn is_olap_query(parent: &LogicalPlan) -> Result { fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { if let LogicalPlan::Extension(ext) = plan { - if let Some(_) = ext.node.as_any().downcast_ref::() { + let node = ext.node.as_any(); + if node.is::() + || node.is::() + || node.is::() + { self.0 = true; return Ok(false); diff --git a/rust/cubesql/cubesql/src/compile/snapshots/cubesql__compile__tests__explain-2.snap b/rust/cubesql/cubesql/src/compile/snapshots/cubesql__compile__tests__explain-2.snap index fc8f7b4f6b450..03435d6ccd48b 100644 --- a/rust/cubesql/cubesql/src/compile/snapshots/cubesql__compile__tests__explain-2.snap +++ b/rust/cubesql/cubesql/src/compile/snapshots/cubesql__compile__tests__explain-2.snap @@ -12,7 +12,6 @@ expression: "execute_query(\"EXPLAIN VERBOSE SELECT 1+1;\".to_string(),\n | logical_plan after common_sub_expression_eliminate | SAME TEXT AS ABOVE | | logical_plan after eliminate_limit | SAME TEXT AS ABOVE | | logical_plan after projection_drop_out | SAME TEXT AS ABOVE | -| logical_plan after projection_push_down | SAME TEXT AS ABOVE | | logical_plan after filter_push_down | SAME TEXT AS ABOVE | | logical_plan after limit_push_down | SAME TEXT AS ABOVE | | logical_plan after SingleDistinctAggregationToGroupBy | SAME TEXT AS ABOVE | diff --git a/rust/cubesql/cubesql/src/compile/test/snapshots/cubesql__compile__test__test_df_execution__union_all_alias_mismatch.snap b/rust/cubesql/cubesql/src/compile/test/snapshots/cubesql__compile__test__test_df_execution__union_all_alias_mismatch.snap new file mode 100644 index 0000000000000..466fd312237b1 --- /dev/null +++ b/rust/cubesql/cubesql/src/compile/test/snapshots/cubesql__compile__test__test_df_execution__union_all_alias_mismatch.snap @@ -0,0 +1,9 @@ +--- +source: cubesql/src/compile/test/test_df_execution.rs +expression: "execute_query(query.to_string(), DatabaseProtocol::PostgreSQL,).await.unwrap()" +--- ++-----+-----+ +| foo | bar | ++-----+-----+ +| foo | bar | ++-----+-----+ diff --git a/rust/cubesql/cubesql/src/compile/test/test_df_execution.rs b/rust/cubesql/cubesql/src/compile/test/test_df_execution.rs index 2558517b6efe6..726466f3621da 100644 --- a/rust/cubesql/cubesql/src/compile/test/test_df_execution.rs +++ b/rust/cubesql/cubesql/src/compile/test/test_df_execution.rs @@ -61,3 +61,33 @@ async fn test_triple_join_with_coercion() { .await .unwrap()); } + +#[tokio::test] +async fn union_all_alias_mismatch() { + init_testing_logger(); + + // language=PostgreSQL + let query = r#" +SELECT + foo, + bar +FROM ( + SELECT + 'foo' as foo, + 'bar' as bar + UNION ALL + SELECT + 'foo' as foo, + 'bar' as qux +) t +GROUP BY + foo, bar +; + "#; + + insta::assert_snapshot!( + execute_query(query.to_string(), DatabaseProtocol::PostgreSQL,) + .await + .unwrap() + ); +}