Skip to content

Commit f43a5a5

Browse files
authored
feat: automatically materialize CTEs (#18123)
* feat: automatically materialize CTEs * add setting * not materialize r_cte * fix ut * fix explain * make lint * fix ut * fix explain * fix #18143 * fix expr replacer * fix missing header * use visitor trait * fix
1 parent e3dc218 commit f43a5a5

File tree

17 files changed

+246
-511
lines changed

17 files changed

+246
-511
lines changed

src/query/ast/src/ast/expr.rs

Lines changed: 0 additions & 456 deletions
Large diffs are not rendered by default.

src/query/expression/src/types.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ use std::ops::Range;
4646

4747
use borsh::BorshDeserialize;
4848
use borsh::BorshSerialize;
49+
use databend_common_ast::ast::TypeName;
4950
pub use databend_common_io::deserialize_bitmap;
5051
use enum_as_inner::EnumAsInner;
5152
use serde::Deserialize;
@@ -373,6 +374,52 @@ impl DataType {
373374
}
374375
}
375376

377+
pub fn convert_to_type_name(ty: &DataType) -> TypeName {
378+
match ty {
379+
DataType::Boolean => TypeName::Boolean,
380+
DataType::Number(NumberDataType::UInt8) => TypeName::UInt8,
381+
DataType::Number(NumberDataType::UInt16) => TypeName::UInt16,
382+
DataType::Number(NumberDataType::UInt32) => TypeName::UInt32,
383+
DataType::Number(NumberDataType::UInt64) => TypeName::UInt64,
384+
DataType::Number(NumberDataType::Int8) => TypeName::Int8,
385+
DataType::Number(NumberDataType::Int16) => TypeName::Int16,
386+
DataType::Number(NumberDataType::Int32) => TypeName::Int32,
387+
DataType::Number(NumberDataType::Int64) => TypeName::Int64,
388+
DataType::Number(NumberDataType::Float32) => TypeName::Float32,
389+
DataType::Number(NumberDataType::Float64) => TypeName::Float64,
390+
DataType::Decimal(size) => TypeName::Decimal {
391+
precision: size.precision(),
392+
scale: size.scale(),
393+
},
394+
DataType::Date => TypeName::Date,
395+
DataType::Timestamp => TypeName::Timestamp,
396+
DataType::String => TypeName::String,
397+
DataType::Bitmap => TypeName::Bitmap,
398+
DataType::Variant => TypeName::Variant,
399+
DataType::Binary => TypeName::Binary,
400+
DataType::Geometry => TypeName::Geometry,
401+
DataType::Nullable(box inner_ty) => {
402+
TypeName::Nullable(Box::new(convert_to_type_name(inner_ty)))
403+
}
404+
DataType::Array(box inner_ty) => TypeName::Array(Box::new(convert_to_type_name(inner_ty))),
405+
DataType::Map(box inner_ty) => match inner_ty {
406+
DataType::Tuple(inner_tys) => TypeName::Map {
407+
key_type: Box::new(convert_to_type_name(&inner_tys[0])),
408+
val_type: Box::new(convert_to_type_name(&inner_tys[1])),
409+
},
410+
_ => unreachable!(),
411+
},
412+
DataType::Tuple(inner_tys) => TypeName::Tuple {
413+
fields_name: None,
414+
fields_type: inner_tys
415+
.iter()
416+
.map(convert_to_type_name)
417+
.collect::<Vec<_>>(),
418+
},
419+
_ => TypeName::String,
420+
}
421+
}
422+
376423
/// [AccessType] defines a series of access methods for a data type
377424
pub trait AccessType: Debug + Clone + PartialEq + Sized + 'static {
378425
type Scalar: Debug + Clone + PartialEq;

src/query/service/tests/it/sql/planner/optimizer/optimizer_test.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,7 @@ async fn test_optimizer() -> Result<()> {
425425
let suite = TestSuite::new(base_path.clone(), subdir);
426426
let fixture = TestFixture::setup().await?;
427427
let ctx = fixture.new_query_ctx().await?;
428+
ctx.get_settings().set_enable_auto_materialize_cte(0)?;
428429

429430
suite.setup_tables(&ctx).await?;
430431

src/query/settings/src/settings_default.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1333,6 +1333,13 @@ impl DefaultSettings {
13331333
scope: SettingScope::Both,
13341334
range: Some(SettingRange::Numeric(0..=1)),
13351335
}),
1336+
("enable_auto_materialize_cte", DefaultSettingValue {
1337+
value: UserSettingValue::UInt64(1),
1338+
desc: "Enables auto materialize CTE, 0 for disable, 1 for enable",
1339+
mode: SettingMode::Both,
1340+
scope: SettingScope::Both,
1341+
range: Some(SettingRange::Numeric(0..=1)),
1342+
}),
13361343
]);
13371344

13381345
Ok(Arc::new(DefaultSettings {

src/query/settings/src/settings_getter_setter.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -980,6 +980,14 @@ impl Settings {
980980
Ok(self.try_get_u64("enable_experimental_virtual_column")? == 1)
981981
}
982982

983+
pub fn get_enable_auto_materialize_cte(&self) -> Result<bool> {
984+
Ok(self.try_get_u64("enable_auto_materialize_cte")? == 1)
985+
}
986+
987+
pub fn set_enable_auto_materialize_cte(&self, val: u64) -> Result<()> {
988+
self.try_set_u64("enable_auto_materialize_cte", val)
989+
}
990+
983991
pub fn get_max_aggregate_restore_worker(&self) -> Result<u64> {
984992
self.try_get_u64("max_aggregate_restore_worker")
985993
}

src/query/sql/src/planner/binder/bind_query/bind.rs

Lines changed: 152 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,31 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
use std::collections::HashMap;
1516
use std::sync::Arc;
1617

18+
use databend_common_ast::ast::ColumnDefinition;
1719
use databend_common_ast::ast::CreateOption;
20+
use databend_common_ast::ast::CreateTableSource;
1821
use databend_common_ast::ast::CreateTableStmt;
1922
use databend_common_ast::ast::Engine;
2023
use databend_common_ast::ast::Expr;
21-
use databend_common_ast::ast::ExprReplacer;
2224
use databend_common_ast::ast::Identifier;
2325
use databend_common_ast::ast::Query;
2426
use databend_common_ast::ast::SetExpr;
27+
use databend_common_ast::ast::TableReference;
2528
use databend_common_ast::ast::TableType;
2629
use databend_common_ast::ast::With;
2730
use databend_common_ast::ast::CTE;
2831
use databend_common_ast::Span;
2932
use databend_common_catalog::catalog::CATALOG_DEFAULT;
3033
use databend_common_exception::ErrorCode;
3134
use databend_common_exception::Result;
35+
use databend_common_expression::types::convert_to_type_name;
36+
use derive_visitor::Drive;
37+
use derive_visitor::DriveMut;
38+
use derive_visitor::Visitor;
39+
use derive_visitor::VisitorMut;
3240

3341
use crate::binder::CteInfo;
3442
use crate::normalize_identifier;
@@ -40,15 +48,42 @@ use crate::plans::BoundColumnRef;
4048
use crate::plans::ScalarExpr;
4149
use crate::plans::Sort;
4250
use crate::plans::SortItem;
51+
use crate::NameResolutionContext;
52+
53+
#[derive(Debug, Default, Visitor)]
54+
#[visitor(TableReference(enter))]
55+
struct CTERefCounter {
56+
cte_ref_count: HashMap<String, usize>,
57+
name_resolution_ctx: NameResolutionContext,
58+
}
59+
60+
impl CTERefCounter {
61+
fn enter_table_reference(&mut self, table_ref: &TableReference) {
62+
if let TableReference::Table { table, .. } = table_ref {
63+
let table_name = normalize_identifier(table, &self.name_resolution_ctx).name;
64+
if let Some(count) = self.cte_ref_count.get_mut(&table_name) {
65+
*count += 1;
66+
}
67+
}
68+
}
69+
}
4370

4471
impl Binder {
4572
pub(crate) fn bind_query(
4673
&mut self,
4774
bind_context: &mut BindContext,
4875
query: &Query,
4976
) -> Result<(SExpr, BindContext)> {
77+
let mut with = query.with.clone();
78+
if self.ctx.get_settings().get_enable_auto_materialize_cte()? {
79+
if let Some(with) = &mut with {
80+
if !with.recursive {
81+
self.auto_materialize_cte(with, query)?;
82+
}
83+
}
84+
}
5085
// Initialize cte map.
51-
self.init_cte(bind_context, &query.with)?;
86+
self.init_cte(bind_context, &with)?;
5287

5388
// Extract limit and offset from query.
5489
let (limit, offset) = self.extract_limit_and_offset(query)?;
@@ -66,6 +101,35 @@ impl Binder {
66101
Ok((s_expr, bind_context))
67102
}
68103

104+
fn auto_materialize_cte(&mut self, with: &mut With, query: &Query) -> Result<()> {
105+
// Initialize the count of each CTE to 0
106+
let mut cte_ref_count: HashMap<String, usize> = HashMap::new();
107+
for cte in with.ctes.iter() {
108+
let table_name = self.normalize_identifier(&cte.alias.name).name;
109+
cte_ref_count.insert(table_name, 0);
110+
}
111+
112+
// Count the number of times each CTE is referenced in the query
113+
let mut visitor = CTERefCounter {
114+
cte_ref_count,
115+
name_resolution_ctx: self.name_resolution_ctx.clone(),
116+
};
117+
query.drive(&mut visitor);
118+
cte_ref_count = visitor.cte_ref_count;
119+
120+
// Update materialization based on reference count
121+
for cte in with.ctes.iter_mut() {
122+
let table_name = self.normalize_identifier(&cte.alias.name).name;
123+
if let Some(count) = cte_ref_count.get(&table_name) {
124+
log::info!("[CTE]cte_ref_count: {table_name} {count}");
125+
// Materialize if referenced more than once
126+
cte.materialized |= *count > 1;
127+
}
128+
}
129+
130+
Ok(())
131+
}
132+
69133
// Initialize cte map.
70134
pub(crate) fn init_cte(
71135
&mut self,
@@ -205,7 +269,11 @@ impl Binder {
205269
)));
206270
}
207271

208-
let expr_replacer = ExprReplacer::new(database.clone(), self.m_cte_table_name.clone());
272+
let mut expr_replacer = TableNameReplacer::new(
273+
database.clone(),
274+
self.m_cte_table_name.clone(),
275+
self.name_resolution_ctx.clone(),
276+
);
209277
let mut as_query = cte.query.clone();
210278
with.ctes.truncate(cte_index);
211279
with.ctes.retain(|cte| !cte.materialized);
@@ -214,15 +282,42 @@ impl Binder {
214282
} else {
215283
None
216284
};
217-
expr_replacer.replace_query(&mut as_query);
285+
as_query.drive_mut(&mut expr_replacer);
286+
287+
let source = if cte.alias.columns.is_empty() {
288+
None
289+
} else {
290+
let mut bind_context = BindContext::new();
291+
let (_, bind_context) = self.bind_query(&mut bind_context, &as_query)?;
292+
let columns = &bind_context.columns;
293+
if columns.len() != cte.alias.columns.len() {
294+
return Err(ErrorCode::Internal("Number of columns does not match"));
295+
}
296+
Some(CreateTableSource::Columns(
297+
columns
298+
.iter()
299+
.zip(cte.alias.columns.iter())
300+
.map(|(column, ident)| {
301+
let data_type = convert_to_type_name(&column.data_type);
302+
ColumnDefinition {
303+
name: ident.clone(),
304+
data_type,
305+
expr: None,
306+
comment: None,
307+
}
308+
})
309+
.collect(),
310+
None,
311+
))
312+
};
218313

219314
let catalog = self.ctx.get_current_catalog();
220315
let create_table_stmt = CreateTableStmt {
221316
create_option: CreateOption::Create,
222317
catalog: Some(Identifier::from_name(Span::None, catalog.clone())),
223318
database: Some(Identifier::from_name(Span::None, database.clone())),
224319
table: table_identifier,
225-
source: None,
320+
source,
226321
engine: Some(engine),
227322
uri_location: None,
228323
cluster_by: None,
@@ -234,6 +329,7 @@ impl Binder {
234329
};
235330

236331
let create_table_sql = create_table_stmt.to_string();
332+
log::info!("[CTE]create_table_sql: {create_table_sql}");
237333
if let Some(subquery_executor) = &self.subquery_executor {
238334
let _ = databend_common_base::runtime::block_on(async move {
239335
subquery_executor
@@ -250,3 +346,54 @@ impl Binder {
250346
.evict_table_from_cache(&catalog, &database, &table_name)
251347
}
252348
}
349+
350+
#[derive(VisitorMut)]
351+
#[visitor(TableReference(enter), Expr(enter))]
352+
pub struct TableNameReplacer {
353+
database: String,
354+
new_name: HashMap<String, String>,
355+
name_resolution_ctx: NameResolutionContext,
356+
}
357+
358+
impl TableNameReplacer {
359+
pub fn new(
360+
database: String,
361+
new_name: HashMap<String, String>,
362+
name_resolution_ctx: NameResolutionContext,
363+
) -> Self {
364+
Self {
365+
database,
366+
new_name,
367+
name_resolution_ctx,
368+
}
369+
}
370+
371+
fn replace_identifier(&mut self, identifier: &mut Identifier) {
372+
let name = normalize_identifier(identifier, &self.name_resolution_ctx).name;
373+
if let Some(new_name) = self.new_name.get(&name) {
374+
identifier.name = new_name.clone();
375+
}
376+
}
377+
378+
fn enter_table_reference(&mut self, table_reference: &mut TableReference) {
379+
if let TableReference::Table {
380+
database, table, ..
381+
} = table_reference
382+
{
383+
if database.is_none() || database.as_ref().unwrap().name == self.database {
384+
self.replace_identifier(table);
385+
}
386+
}
387+
}
388+
389+
fn enter_expr(&mut self, expr: &mut Expr) {
390+
if let Expr::ColumnRef { column, .. } = expr {
391+
if column.database.is_none() || column.database.as_ref().unwrap().name == self.database
392+
{
393+
if let Some(table_identifier) = &mut column.table {
394+
self.replace_identifier(table_identifier);
395+
}
396+
}
397+
}
398+
}
399+
}

src/tests/sqlsmith/src/sql_gen/expr.rs

Lines changed: 1 addition & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ use databend_common_ast::ast::Literal;
2424
use databend_common_ast::ast::MapAccessor;
2525
use databend_common_ast::ast::SubqueryModifier;
2626
use databend_common_ast::ast::TrimWhere;
27-
use databend_common_ast::ast::TypeName;
2827
use databend_common_ast::ast::UnaryOperator;
28+
use databend_common_expression::types::convert_to_type_name;
2929
use databend_common_expression::types::DataType;
3030
use databend_common_expression::types::NumberDataType;
3131
use ethnum::I256;
@@ -881,49 +881,3 @@ impl<R: Rng> SqlGenerator<'_, R> {
881881
}
882882
}
883883
}
884-
885-
fn convert_to_type_name(ty: &DataType) -> TypeName {
886-
match ty {
887-
DataType::Boolean => TypeName::Boolean,
888-
DataType::Number(NumberDataType::UInt8) => TypeName::UInt8,
889-
DataType::Number(NumberDataType::UInt16) => TypeName::UInt16,
890-
DataType::Number(NumberDataType::UInt32) => TypeName::UInt32,
891-
DataType::Number(NumberDataType::UInt64) => TypeName::UInt64,
892-
DataType::Number(NumberDataType::Int8) => TypeName::Int8,
893-
DataType::Number(NumberDataType::Int16) => TypeName::Int16,
894-
DataType::Number(NumberDataType::Int32) => TypeName::Int32,
895-
DataType::Number(NumberDataType::Int64) => TypeName::Int64,
896-
DataType::Number(NumberDataType::Float32) => TypeName::Float32,
897-
DataType::Number(NumberDataType::Float64) => TypeName::Float64,
898-
DataType::Decimal(size) => TypeName::Decimal {
899-
precision: size.precision(),
900-
scale: size.scale(),
901-
},
902-
DataType::Date => TypeName::Date,
903-
DataType::Timestamp => TypeName::Timestamp,
904-
DataType::String => TypeName::String,
905-
DataType::Bitmap => TypeName::Bitmap,
906-
DataType::Variant => TypeName::Variant,
907-
DataType::Binary => TypeName::Binary,
908-
DataType::Geometry => TypeName::Geometry,
909-
DataType::Nullable(box inner_ty) => {
910-
TypeName::Nullable(Box::new(convert_to_type_name(inner_ty)))
911-
}
912-
DataType::Array(box inner_ty) => TypeName::Array(Box::new(convert_to_type_name(inner_ty))),
913-
DataType::Map(box inner_ty) => match inner_ty {
914-
DataType::Tuple(inner_tys) => TypeName::Map {
915-
key_type: Box::new(convert_to_type_name(&inner_tys[0])),
916-
val_type: Box::new(convert_to_type_name(&inner_tys[1])),
917-
},
918-
_ => unreachable!(),
919-
},
920-
DataType::Tuple(inner_tys) => TypeName::Tuple {
921-
fields_name: None,
922-
fields_type: inner_tys
923-
.iter()
924-
.map(convert_to_type_name)
925-
.collect::<Vec<_>>(),
926-
},
927-
_ => TypeName::String,
928-
}
929-
}

tests/sqllogictests/suites/mode/cluster/exchange.test

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
statement ok
2+
set enable_auto_materialize_cte = 0;
3+
14
query T
25
explain select * from numbers(1) t, numbers(2) t1 where t.number = t1.number
36
----

0 commit comments

Comments
 (0)