diff --git a/src/query/service/src/physical_plans/format/format_table_scan.rs b/src/query/service/src/physical_plans/format/format_table_scan.rs index 96c0a0bbcc567..c02ae39a42fea 100644 --- a/src/query/service/src/physical_plans/format/format_table_scan.rs +++ b/src/query/service/src/physical_plans/format/format_table_scan.rs @@ -137,16 +137,21 @@ impl<'a> PhysicalFormat for TableScanFormatter<'a> { // Aggregating index if let Some(agg_index) = agg_index { - let (_, agg_index_sql, _) = ctx + let table_index = self + .inner + .table_index + .expect("agg index should only exist for bound table scans"); + let index = ctx .metadata - .get_agg_indices(&table_name) + .get_agg_indices(table_index) .unwrap() .iter() - .find(|(index, _, _)| *index == agg_index.index_id) + .find(|index| index.index_id == agg_index.index_id) .unwrap(); children.push(FormatTreeNode::new(format!( - "aggregating index: [{agg_index_sql}]" + "aggregating index: [{}]", + index.sql ))); let agg_sel = agg_index diff --git a/src/query/service/tests/it/sql/planner/optimizer/optimizers/rule/agg_rules/agg_index_query_rewrite.rs b/src/query/service/tests/it/sql/planner/optimizer/optimizers/rule/agg_rules/agg_index_query_rewrite.rs index a6b63721ca7e9..34b613f944123 100644 --- a/src/query/service/tests/it/sql/planner/optimizer/optimizers/rule/agg_rules/agg_index_query_rewrite.rs +++ b/src/query/service/tests/it/sql/planner/optimizer/optimizers/rule/agg_rules/agg_index_query_rewrite.rs @@ -24,6 +24,7 @@ use databend_common_meta_app::schema::CreateOption; use databend_common_sql::BindContext; use databend_common_sql::MetadataRef; use databend_common_sql::optimizer::OptimizerContext; +use databend_common_sql::optimizer::build_agg_index_plan_for_table; use databend_common_sql::optimizer::ir::SExpr; use databend_common_sql::optimizer::optimizers::recursive::RecursiveRuleOptimizer; use databend_common_sql::optimizer::optimizers::rule::DEFAULT_REWRITE_RULES; @@ -227,6 +228,13 @@ fn get_test_suites() -> Vec { 
index_selection: vec!["index_col_0 (#0)", "index_col_1 (#1)"], rewritten_predicates: vec![], }, + TestSuite { + query: "select avg(a) from t", + index: "select count(), count(a), sum(a), sum(b) from t", + is_matched: true, + index_selection: vec!["index_col_2 (#2)", "index_col_1 (#1)"], + rewritten_predicates: vec![], + }, // query: eval-agg-eval-filter-scan, index: eval-agg-eval-scan TestSuite { query: "select sum(a) + 1 from t where b > 1 group by b", @@ -378,18 +386,32 @@ async fn test_query_rewrite_impl(format: &str) -> Result<()> { let test_suites = get_test_suites(); for suite in test_suites.into_iter() { let (index, _, _) = plan_sql(ctx.clone(), suite.index, false).await?; - let (mut query, _, metadata) = plan_sql(ctx.clone(), suite.query, true).await?; + let (mut query, _, metadata_ref) = plan_sql(ctx.clone(), suite.query, true).await?; + let table_index = { + let metadata = metadata_ref.read(); + metadata + .tables() + .iter() + .find(|table| table.name() == "t") + .map(|table| table.index()) + .expect("query metadata should contain table t") + }; + let agg_index_plan = build_agg_index_plan_for_table( + ctx.clone(), + None, + metadata_ref.clone(), + table_index, + 0, + suite.index.to_string(), + index, + )?; { - let mut metadata = metadata.write(); - metadata.add_agg_indices("default.default.t".to_string(), vec![( - 0, - suite.index.to_string(), - index, - )]); + let mut metadata = metadata_ref.write(); + metadata.add_agg_indices(table_index, vec![agg_index_plan]); } query.clear_applied_rules(); - let opt_ctx = OptimizerContext::new(ctx.clone(), metadata.clone()); + let opt_ctx = OptimizerContext::new(ctx.clone(), metadata_ref.clone()); let result = RecursiveRuleOptimizer::new(opt_ctx.clone(), &[RuleID::TryApplyAggIndex]) .optimize_sync(&query)?; let agg_index = find_push_down_index_info(&result)?; diff --git a/src/query/sql/src/planner/binder/bind_table_reference/bind_table.rs b/src/query/sql/src/planner/binder/bind_table_reference/bind_table.rs index 
07c63d8abba69..26e6a780f1edb 100644 --- a/src/query/sql/src/planner/binder/bind_table_reference/bind_table.rs +++ b/src/query/sql/src/planner/binder/bind_table_reference/bind_table.rs @@ -136,6 +136,38 @@ impl Binder { let navigation = self.resolve_temporal_clause(bind_context, temporal)?; + if bind_context.planning_agg_index { + let source_table_index = { + let metadata = self.metadata.read(); + metadata + .tables() + .iter() + .find(|table| { + table.is_source_of_index() + && table.catalog() == catalog + && table.database() == database + && table.name() == table_name + }) + .map(|table| table.index()) + }; + if let Some(table_index) = source_table_index { + let (s_expr, mut bind_context) = self.bind_base_table( + bind_context, + database.as_str(), + table_index, + None, + sample, + true, + true, + )?; + + if let Some(alias) = alias { + bind_context.apply_table_alias(alias, &self.name_resolution_ctx)?; + } + return Ok((s_expr, bind_context)); + } + } + // Resolve table with catalog let table_meta = { let table_name = if let Some(cte_suffix_name) = cte_suffix_name.as_ref() { @@ -206,6 +238,7 @@ impl Binder { change_type, sample, true, + false, )?; if let Some(alias) = alias { @@ -327,6 +360,7 @@ impl Binder { None, sample, true, + false, )?; if let Some(alias) = alias { bind_context.apply_table_alias(alias, &self.name_resolution_ctx)?; diff --git a/src/query/sql/src/planner/binder/bind_table_reference/bind_table_function.rs b/src/query/sql/src/planner/binder/bind_table_reference/bind_table_function.rs index 4cac5721092ef..7cbf793c0a7c7 100644 --- a/src/query/sql/src/planner/binder/bind_table_reference/bind_table_function.rs +++ b/src/query/sql/src/planner/binder/bind_table_reference/bind_table_function.rs @@ -309,8 +309,15 @@ impl Binder { false, None, ); - let (s_expr, mut bind_context) = - self.bind_base_table(bind_context, "system", table_index, None, sample, true)?; + let (s_expr, mut bind_context) = self.bind_base_table( + bind_context, + "system", + 
table_index, + None, + sample, + true, + false, + )?; if let Some(alias) = alias { bind_context.apply_table_alias(alias, &self.name_resolution_ctx)?; } @@ -371,8 +378,15 @@ impl Binder { None, ); - let (s_expr, mut bind_context) = - self.bind_base_table(bind_context, "system", table_index, None, &None, true)?; + let (s_expr, mut bind_context) = self.bind_base_table( + bind_context, + "system", + table_index, + None, + &None, + true, + false, + )?; if let Some(alias) = alias { bind_context.apply_table_alias(alias, &self.name_resolution_ctx)?; } diff --git a/src/query/sql/src/planner/binder/ddl/index.rs b/src/query/sql/src/planner/binder/ddl/index.rs index f5386e8ecd57e..69f453c234eb7 100644 --- a/src/query/sql/src/planner/binder/ddl/index.rs +++ b/src/query/sql/src/planner/binder/ddl/index.rs @@ -14,7 +14,9 @@ use std::collections::BTreeMap; use std::collections::BTreeSet; +use std::collections::HashMap; use std::collections::HashSet; +use std::sync::Arc; use std::sync::LazyLock; use databend_common_ast::ast::CreateIndexStmt; @@ -38,6 +40,7 @@ use databend_common_exception::Result; use databend_common_expression::ColumnId; use databend_common_expression::TableDataType; use databend_common_expression::TableSchemaRef; +use databend_common_expression::types::DataType; use databend_common_meta_app::schema::GetIndexReq; use databend_common_meta_app::schema::IndexMeta; use databend_common_meta_app::schema::IndexNameIdent; @@ -46,15 +49,27 @@ use databend_storages_common_table_meta::meta::Location; use derive_visitor::Drive; use derive_visitor::DriveMut; use itertools::Itertools; +use parking_lot::RwLock; use crate::AggregatingIndexChecker; use crate::AggregatingIndexRewriter; use crate::BindContext; +use crate::ColumnEntry; +use crate::Metadata; use crate::MetadataRef; use crate::RefreshAggregatingIndexRewriter; use crate::SUPPORTED_AGGREGATING_INDEX_FUNCTIONS; +use crate::Symbol; +use crate::TableEntry; +use crate::Visibility; use crate::binder::Binder; +use 
crate::binder::ColumnBinding; +use crate::binder::ColumnBindingBuilder; use crate::optimizer::OptimizerContext; +use crate::optimizer::build_agg_index_plan_for_table; +use crate::optimizer::ir::SExpr; +use crate::optimizer::ir::SExprVisitor; +use crate::optimizer::ir::VisitAction; use crate::optimizer::optimize; use crate::plans::CreateIndexPlan; use crate::plans::CreateTableIndexPlan; @@ -63,6 +78,10 @@ use crate::plans::DropTableIndexPlan; use crate::plans::Plan; use crate::plans::RefreshIndexPlan; use crate::plans::RefreshTableIndexPlan; +use crate::plans::RelOperator; +use crate::plans::ScalarExpr; +use crate::plans::ScalarItem; +use crate::plans::Scan; const MAXIMUM_BLOOM_SIZE: u64 = 10 * 1024 * 1024; const MINIMUM_BLOOM_SIZE: u64 = 512; @@ -151,7 +170,6 @@ impl Binder { metadata: &MetadataRef, ) -> Result<()> { let catalog = self.ctx.get_current_catalog(); - let database = self.ctx.get_current_database(); let tables = metadata.read().tables().to_vec(); for table_entry in tables { @@ -171,33 +189,89 @@ impl Binder { .resolve_table_indexes(&self.ctx.get_tenant(), catalog.as_str(), table.get_id()) .await?; + let child_metadata = Arc::new(RwLock::new(metadata.read().clone())); + child_metadata + .write() + .set_table_source_of_index(table_entry.index(), true); let mut s_exprs = Vec::with_capacity(indexes.len()); for (index_id, _, index_meta) in indexes { - let tokens = tokenize_sql(&index_meta.query)?; - let (stmt, _) = parse_sql(&tokens, self.dialect)?; - let mut new_bind_context = BindContext::with_parent(bind_context.clone())?; - new_bind_context.planning_agg_index = true; - if let Statement::Query(query) = &stmt { - let (s_expr, _) = self.bind_query(&mut new_bind_context, query)?; - s_exprs.push((index_id, index_meta.query.clone(), s_expr)); + if let Some(agg_index_plan) = self.bind_agg_index_query_locally( + &index_meta, + &table_entry, + index_id, + &child_metadata, + )? 
{ + s_exprs.push(agg_index_plan); } } agg_indexes.extend(s_exprs); } if !agg_indexes.is_empty() { - // Should use bound table id. - let table_name = table.name(); - let full_table_name = format!("{catalog}.{database}.{table_name}"); metadata .write() - .add_agg_indices(full_table_name, agg_indexes); + .add_agg_indices(table_entry.index(), agg_indexes); } } Ok(()) } + fn bind_agg_index_query_locally( + &self, + index_meta: &IndexMeta, + table_entry: &TableEntry, + index_id: u64, + child_metadata: &MetadataRef, + ) -> Result> { + let tokens = tokenize_sql(&index_meta.query)?; + let (stmt, _) = parse_sql(&tokens, self.dialect)?; + let Statement::Query(query) = &stmt else { + return Ok(None); + }; + + let index_metadata = Arc::new(RwLock::new(child_metadata.read().clone())); + let initial_table_count = index_metadata.read().tables().len(); + + let mut index_binder = Binder::new( + self.ctx.clone(), + self.catalogs.clone(), + self.name_resolution_ctx.clone(), + index_metadata, + ) + .with_subquery_executor(self.subquery_executor.clone()); + let mut index_bind_context = BindContext::new(); + index_bind_context.planning_agg_index = true; + let (s_expr, _) = index_binder.bind_query(&mut index_bind_context, query)?; + if index_binder.metadata.read().tables().len() != initial_table_count { + return Err(ErrorCode::Internal(format!( + "aggregating index query introduced unexpected tables while binding `{}`", + index_meta.query + ))); + } + let s_expr = normalize_agg_index_s_expr( + &index_binder.metadata, + table_entry.index(), + table_entry.database(), + table_entry.name(), + &s_expr, + )?; + + let mut agg_index_plan = build_agg_index_plan_for_table( + self.ctx.clone(), + self.subquery_executor.clone(), + index_binder.metadata.clone(), + table_entry.index(), + index_id, + index_meta.query.clone(), + s_expr, + )?; + // Keep per-index binding isolated, but expose one canonical child metadata per table. 
+ agg_index_plan.metadata = child_metadata.clone(); + + Ok(Some(agg_index_plan)) + } + #[async_backtrace::framed] pub(in crate::planner::binder) async fn bind_create_index( &mut self, @@ -371,18 +445,29 @@ impl Binder { bind_context.planning_agg_index = true; let plan = if let Statement::Query(_) = &stmt { - let select_plan = self.bind_statement(bind_context, &stmt).await?; - let opt_ctx = OptimizerContext::new(self.ctx.clone(), self.metadata.clone()) + let refresh_metadata = Arc::new(RwLock::new(Metadata::default())); + let mut refresh_binder = Binder::new( + self.ctx.clone(), + self.catalogs.clone(), + self.name_resolution_ctx.clone(), + refresh_metadata, + ) + .with_subquery_executor(self.subquery_executor.clone()); + let select_plan = refresh_binder.bind_statement(bind_context, &stmt).await?; + let opt_ctx = OptimizerContext::new(self.ctx.clone(), refresh_binder.metadata.clone()) .set_planning_agg_index(true) .clone(); - Ok(optimize(opt_ctx, select_plan).await?) + Ok(( + optimize(opt_ctx, select_plan).await?, + refresh_binder.metadata.clone(), + )) } else { Err(ErrorCode::UnsupportedIndex("statement is not query")) }; - let plan = plan?; + let (plan, refresh_metadata) = plan?; bind_context.planning_agg_index = false; - let tables = self.metadata.read().tables().to_vec(); + let tables = refresh_metadata.read().tables().to_vec(); if tables.len() != 1 { return Err(ErrorCode::UnsupportedIndex( @@ -961,3 +1046,228 @@ impl Binder { Ok(Plan::RefreshTableIndex(Box::new(plan))) } } + +fn normalize_agg_index_s_expr( + metadata: &MetadataRef, + canonical_table_index: usize, + canonical_database_name: &str, + canonical_table_name: &str, + s_expr: &SExpr, +) -> Result { + let actual_table_index = find_scan_table_index(s_expr)?; + let replacements = { + let metadata = metadata.read(); + build_agg_index_column_replacements( + &metadata.columns_by_table_index(canonical_table_index), + &metadata.columns_by_table_index(actual_table_index), + canonical_table_index, + 
canonical_database_name, + canonical_table_name, + )? + }; + + let mut visitor = AggIndexSExprNormalizer { + actual_table_index, + canonical_table_index, + replacements: &replacements, + }; + Ok(s_expr + .accept(&mut visitor)? + .unwrap_or_else(|| s_expr.clone())) +} + +fn find_scan_table_index(s_expr: &SExpr) -> Result { + match s_expr.plan() { + RelOperator::Scan(scan) => Ok(scan.table_index), + _ => find_scan_table_index(s_expr.child(0)?), + } +} + +fn build_agg_index_column_replacements( + canonical_columns: &[ColumnEntry], + actual_columns: &[ColumnEntry], + canonical_table_index: usize, + canonical_database_name: &str, + canonical_table_name: &str, +) -> Result> { + let canonical_base_columns = canonical_columns + .iter() + .filter_map(|column| match column { + ColumnEntry::BaseTableColumn(base_column) => { + Some((base_column.column_name.clone(), base_column)) + } + _ => None, + }) + .collect::>(); + + let mut replacements = HashMap::new(); + for actual_column in actual_columns { + let ColumnEntry::BaseTableColumn(actual_base_column) = actual_column else { + continue; + }; + let Some(canonical_base_column) = + canonical_base_columns.get(&actual_base_column.column_name) + else { + return Err(ErrorCode::Internal(format!( + "missing canonical column mapping for aggregating index column {}", + actual_base_column.column_name + ))); + }; + + let column_binding = ColumnBindingBuilder::new( + canonical_base_column.column_name.clone(), + canonical_base_column.column_index, + Box::new(DataType::from(&canonical_base_column.data_type)), + Visibility::Visible, + ) + .database_name(Some(canonical_database_name.to_string())) + .table_name(Some(canonical_table_name.to_string())) + .table_index(Some(canonical_table_index)) + .column_position(canonical_base_column.column_position) + .virtual_expr(canonical_base_column.virtual_expr.clone()) + .build(); + + replacements.insert(actual_base_column.column_index, column_binding); + } + + Ok(replacements) +} + +struct 
AggIndexSExprNormalizer<'a> { + actual_table_index: usize, + canonical_table_index: usize, + replacements: &'a HashMap, +} + +impl SExprVisitor for AggIndexSExprNormalizer<'_> { + fn visit(&mut self, expr: &SExpr) -> Result { + if let RelOperator::Scan(scan) = expr.plan() { + if scan.table_index != self.actual_table_index { + return Err(ErrorCode::Internal(format!( + "unexpected aggregating index scan table index {}, expected {}", + scan.table_index, self.actual_table_index + ))); + } + return Ok(VisitAction::SkipChildren); + } + Ok(VisitAction::Continue) + } + + fn post_visit(&mut self, expr: &SExpr) -> Result { + let plan = match expr.plan().clone() { + RelOperator::EvalScalar(mut eval) => { + normalize_scalar_items(&mut eval.items, self.replacements)?; + RelOperator::EvalScalar(eval) + } + RelOperator::Aggregate(mut agg) => { + normalize_scalar_items(&mut agg.group_items, self.replacements)?; + normalize_scalar_items(&mut agg.aggregate_functions, self.replacements)?; + if let Some((items, _)) = &mut agg.rank_limit { + for item in items { + replace_sort_item(item, self.replacements); + } + } + RelOperator::Aggregate(agg) + } + RelOperator::Filter(mut filter) => { + normalize_scalars(&mut filter.predicates, self.replacements)?; + RelOperator::Filter(filter) + } + RelOperator::Sort(mut sort) => { + for (old, new_column) in self.replacements { + sort.replace_column(*old, new_column.index); + } + RelOperator::Sort(sort) + } + RelOperator::Scan(mut scan) => { + normalize_scan(&mut scan, self.canonical_table_index, self.replacements)?; + RelOperator::Scan(scan) + } + _ => return Ok(VisitAction::Continue), + }; + + Ok(VisitAction::Replace(expr.replace_plan(plan))) + } +} + +fn normalize_scalar_items( + items: &mut [ScalarItem], + replacements: &HashMap, +) -> Result<()> { + for item in items { + normalize_scalar(&mut item.scalar, replacements)?; + if let Some(new_column) = replacements.get(&item.index) { + item.index = new_column.index; + } + } + Ok(()) +} + +fn 
normalize_scalars( + scalars: &mut [ScalarExpr], + replacements: &HashMap, +) -> Result<()> { + for scalar in scalars { + normalize_scalar(scalar, replacements)?; + } + Ok(()) +} + +fn normalize_scalar( + scalar: &mut ScalarExpr, + replacements: &HashMap, +) -> Result<()> { + for (old, new_column) in replacements { + scalar.replace_column_binding(*old, new_column)?; + } + Ok(()) +} + +fn normalize_scan( + scan: &mut Scan, + canonical_table_index: usize, + replacements: &HashMap, +) -> Result<()> { + scan.table_index = canonical_table_index; + scan.columns = replace_column_set(&scan.columns, replacements); + scan.statistics = Default::default(); + + if let Some(predicates) = &mut scan.push_down_predicates { + normalize_scalars(predicates, replacements)?; + } + if let Some(order_by) = &mut scan.order_by { + for item in order_by { + replace_sort_item(item, replacements); + } + } + if let Some(prewhere) = &mut scan.prewhere { + prewhere.output_columns = replace_column_set(&prewhere.output_columns, replacements); + prewhere.prewhere_columns = replace_column_set(&prewhere.prewhere_columns, replacements); + normalize_scalars(&mut prewhere.predicates, replacements)?; + } + Ok(()) +} + +fn replace_sort_item( + item: &mut crate::plans::SortItem, + replacements: &HashMap, +) { + if let Some(new_column) = replacements.get(&item.index) { + item.index = new_column.index; + } +} + +fn replace_column_set( + columns: &crate::ColumnSet, + replacements: &HashMap, +) -> crate::ColumnSet { + columns + .iter() + .map(|column| { + replacements + .get(column) + .map(|new_column| new_column.index) + .unwrap_or(*column) + }) + .collect() +} diff --git a/src/query/sql/src/planner/binder/table.rs b/src/query/sql/src/planner/binder/table.rs index c8138584028d5..abc360313afce 100644 --- a/src/query/sql/src/planner/binder/table.rs +++ b/src/query/sql/src/planner/binder/table.rs @@ -13,6 +13,7 @@ // limitations under the License. 
use std::collections::BTreeMap; +use std::collections::BTreeSet; use std::collections::HashMap; use std::default::Default; use std::sync::Arc; @@ -161,8 +162,15 @@ impl Binder { None, ); - let (s_expr, mut bind_context) = - self.bind_base_table(bind_context, "system", table_index, None, &None, false)?; + let (s_expr, mut bind_context) = self.bind_base_table( + bind_context, + "system", + table_index, + None, + &None, + false, + false, + )?; if let Some(alias) = alias { bind_context.apply_table_alias(alias, &self.name_resolution_ctx)?; } @@ -369,6 +377,7 @@ impl Binder { change_type: Option, sample: &Option, case_sensitive: bool, + skip_internal_columns: bool, ) -> Result<(SExpr, BindContext)> { let mut bind_context = BindContext::with_parent(bind_context.clone())?; @@ -382,6 +391,7 @@ impl Binder { table ); let mut base_column_scan_id = HashMap::new(); + let mut scan_columns = BTreeSet::new(); for column in columns.iter() { match column { ColumnEntry::BaseTableColumn(BaseTableColumn { @@ -413,6 +423,7 @@ impl Binder { .build(); bind_context.add_column_binding(column_binding); base_column_scan_id.insert(*column_index, scan_id); + scan_columns.insert(*column_index); } ColumnEntry::VirtualColumn(VirtualColumn { table_index, @@ -433,7 +444,9 @@ impl Binder { .build(); bind_context.add_column_binding(column_binding); base_column_scan_id.insert(*column_index, scan_id); + scan_columns.insert(*column_index); } + ColumnEntry::InternalColumn(_) if skip_internal_columns => {} other => { return Err(ErrorCode::Internal(format!( "Invalid column entry '{:?}' encountered while binding the base table '{}'. 
Ensure that the table definition and column references are correct.", @@ -450,7 +463,7 @@ impl Binder { let scan_s_expr = SExpr::create_leaf(Arc::new( Scan { table_index, - columns: columns.into_iter().map(|col| col.index()).collect(), + columns: scan_columns, change_type, sample: sample.clone(), scan_id, diff --git a/src/query/sql/src/planner/dataframe.rs b/src/query/sql/src/planner/dataframe.rs index 5d20b49d085aa..826098336e33f 100644 --- a/src/query/sql/src/planner/dataframe.rs +++ b/src/query/sql/src/planner/dataframe.rs @@ -103,7 +103,15 @@ impl Dataframe { None, ); - binder.bind_base_table(&bind_context, database, table_index, None, &None, true) + binder.bind_base_table( + &bind_context, + database, + table_index, + None, + &None, + true, + false, + ) } else { binder.bind_table_reference(&mut bind_context, &table) }?; diff --git a/src/query/sql/src/planner/metadata/metadata.rs b/src/query/sql/src/planner/metadata/metadata.rs index 5567421d87520..ff28e1ca05717 100644 --- a/src/query/sql/src/planner/metadata/metadata.rs +++ b/src/query/sql/src/planner/metadata/metadata.rs @@ -36,6 +36,7 @@ use databend_common_expression::types::DataType; use jsonb::keypath::OwnedKeyPaths; use parking_lot::RwLock; +use crate::optimizer::AggIndexViewInfo; use crate::optimizer::ir::SExpr; /// Planner use [`usize`] as it's index type. @@ -75,7 +76,7 @@ pub struct Metadata { non_lazy_columns: ColumnSet, /// Mappings from table index to _row_id column index. table_row_id_index: HashMap, - agg_indices: HashMap>, + agg_indices: HashMap>, max_column_position: usize, // for CSV /// Scan id of each scan operator. 
@@ -86,6 +87,15 @@ pub struct Metadata { next_logical_recursive_cte_id: u32, } +#[derive(Clone, Debug)] +pub struct AggIndexPlan { + pub index_id: u64, + pub sql: String, + pub metadata: MetadataRef, + pub s_expr: SExpr, + pub prepared: Arc, +} + impl Metadata { fn next_column_index(&self) -> Symbol { Symbol::new(self.columns.len()) @@ -95,6 +105,12 @@ impl Metadata { self.tables.get(index).expect("metadata must contain table") } + pub fn set_table_source_of_index(&mut self, index: IndexType, source_of_index: bool) { + if let Some(table) = self.tables.get_mut(index) { + table.source_of_index = source_of_index; + } + } + pub fn tables(&self) -> &[TableEntry] { self.tables.as_slice() } @@ -333,8 +349,8 @@ impl Metadata { column_index } - pub fn add_agg_indices(&mut self, table: String, agg_indices: Vec<(u64, String, SExpr)>) { - match self.agg_indices.entry(table) { + pub fn add_agg_indices(&mut self, table_index: IndexType, agg_indices: Vec) { + match self.agg_indices.entry(table_index) { Entry::Occupied(occupied) => occupied.into_mut().extend(agg_indices), Entry::Vacant(vacant) => { vacant.insert(agg_indices); @@ -342,16 +358,16 @@ impl Metadata { } } - pub fn agg_indices(&self) -> &HashMap> { + pub fn agg_indices(&self) -> &HashMap> { &self.agg_indices } - pub fn replace_agg_indices(&mut self, agg_indices: HashMap>) { + pub fn replace_agg_indices(&mut self, agg_indices: HashMap>) { self.agg_indices = agg_indices } - pub fn get_agg_indices(&self, table: &str) -> Option<&[(u64, String, SExpr)]> { - self.agg_indices.get(table).map(|v| v.as_slice()) + pub fn get_agg_indices(&self, table_index: IndexType) -> Option<&[AggIndexPlan]> { + self.agg_indices.get(&table_index).map(|v| v.as_slice()) } pub fn has_agg_indices(&self) -> bool { diff --git a/src/query/sql/src/planner/optimizer/mod.rs b/src/query/sql/src/planner/optimizer/mod.rs index ace3f3e86a213..f55005ccbfe40 100644 --- a/src/query/sql/src/planner/optimizer/mod.rs +++ b/src/query/sql/src/planner/optimizer/mod.rs 
@@ -28,3 +28,5 @@ pub use optimizer::optimize; pub use optimizer::optimize_query; pub use optimizer_api::Optimizer; pub use optimizer_context::OptimizerContext; +pub use optimizers::rule::AggIndexViewInfo; +pub use optimizers::rule::build_agg_index_plan_for_table; diff --git a/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/agg_index/mod.rs b/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/agg_index/mod.rs index a8db02e8548d9..7e2aeeef85273 100644 --- a/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/agg_index/mod.rs +++ b/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/agg_index/mod.rs @@ -12,6 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +mod prepare; mod query_rewrite; +mod rewrite; -pub use query_rewrite::*; +pub use prepare::AggIndexViewInfo; +pub use query_rewrite::build_agg_index_plan_for_table; +pub use query_rewrite::try_rewrite; diff --git a/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/agg_index/prepare.rs b/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/agg_index/prepare.rs new file mode 100644 index 0000000000000..c4c05ecd87065 --- /dev/null +++ b/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/agg_index/prepare.rs @@ -0,0 +1,1156 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::cmp::Ordering; +use std::collections::HashMap; +use std::collections::HashSet; + +use databend_common_exception::ErrorCode; +use databend_common_exception::Result; +use databend_common_expression::FieldIndex; +use databend_common_expression::Scalar; +use databend_common_expression::TableField; +use databend_common_expression::infer_schema_type; +use databend_common_expression::types::DataType; +use databend_common_functions::aggregates::AggregateFunctionFactory; + +use super::rewrite::AggIndexMatcher; +use crate::ColumnEntry; +use crate::IndexType; +use crate::ScalarExpr; +use crate::Symbol; +use crate::Visibility; +use crate::binder::ColumnBindingBuilder; +use crate::optimizer::ir::SExpr; +use crate::plans::Aggregate; +use crate::plans::BoundColumnRef; +use crate::plans::ComparisonOp; +use crate::plans::ConstantExpr; +use crate::plans::RelOperator; +use crate::plans::ScalarItem; +use crate::plans::SortItem; + +pub(super) struct PreparedAggIndexQuery { + info: QueryInfo, +} + +impl PreparedAggIndexQuery { + pub(super) fn prepare( + table_index: IndexType, + table_name: &str, + base_columns: &[ColumnEntry], + s_expr: &SExpr, + ) -> Result { + Ok(Self { + info: QueryInfo::new(table_index, table_name, base_columns, s_expr)?, + }) + } + + pub(super) fn matcher(&self) -> AggIndexMatcher<'_> { + AggIndexMatcher { + query_info: &self.info, + } + } +} + +#[derive(Debug)] +pub(super) struct QueryInfo { + pub(super) equi_classes: EquivalenceClasses, + pub(super) range_classes: RangeClasses, + pub(super) residual_classes: ResidualClasses, + pub(super) sort_items: Option>, + pub(super) aggregate: Option, + pub(super) column_map: HashMap, + pub(super) column_exprs: Vec, + pub(super) output_cols: Vec, +} + +#[derive(Debug)] +pub(super) struct ColumnExprEntry { + pub(super) expr: ScalarExpr, + pub(super) index: Symbol, + pub(super) source: ColumnExprSource, +} + +struct ColumnExprSourceEntry { + index: Symbol, + source: ColumnExprSource, +} + +#[derive(Clone, Copy, 
Debug, Eq, PartialEq)] +pub(super) struct ColumnExprSource { + order: usize, +} + +impl ColumnExprSource { + fn new(order: usize) -> Self { + Self { order } + } + + fn prefer_over(self, other: Self, index: Symbol, other_index: Symbol) -> bool { + match self.order.cmp(&other.order) { + Ordering::Less => false, + Ordering::Greater => true, + Ordering::Equal => index < other_index, + } + } +} + +#[derive(Debug)] +pub(super) struct IndexOutputColumn { + pub(super) expr: ScalarExpr, + pub(super) index_scalar: ScalarExpr, + pub(super) is_agg: bool, +} + +#[derive(Debug)] +struct EqualityPredicate { + left: ScalarExpr, + right: ScalarExpr, +} + +#[derive(Debug)] +struct RangePredicate { + op: ComparisonOp, + column: ScalarExpr, + value: ConstantExpr, +} + +#[derive(Debug)] +struct RangeClassEntry { + column: ScalarExpr, + values: RangeValues, +} + +#[derive(Debug)] +pub(super) struct CompensatingRange { + pub(super) column: ScalarExpr, + pub(super) lower_bound: BoundValue, + pub(super) upper_bound: BoundValue, +} + +impl QueryInfo { + fn new( + table_index: IndexType, + table_name: &str, + base_columns: &[ColumnEntry], + s_expr: &SExpr, + ) -> Result { + let RelOperator::EvalScalar(selection) = s_expr.plan() else { + return Err(ErrorCode::Internal("Unsupported plan")); + }; + + let mut predicates: Option<&[ScalarExpr]> = None; + let mut sort_items = None; + let mut aggregate = None; + let mut column_map = HashMap::new(); + let mut column_expr_sources = Vec::new(); + let mut next_source_order = 0usize; + + let mut record_column_source = |index: Symbol, scalar: ScalarExpr| { + column_map.insert(index, scalar); + column_expr_sources.push(ColumnExprSourceEntry { + index, + source: ColumnExprSource::new(next_source_order), + }); + next_source_order += 1; + }; + + for item in &selection.items { + record_column_source(item.index, item.scalar.clone()); + } + + let mut s_expr = s_expr.unary_child(); + loop { + match s_expr.plan() { + RelOperator::EvalScalar(eval) => { + for item 
in &eval.items { + record_column_source(item.index, item.scalar.clone()); + } + } + RelOperator::Aggregate(agg) => { + if agg.grouping_sets.is_some() { + return Err(ErrorCode::Internal("Grouping sets is not supported")); + } + + aggregate = Some(agg.clone()); + for item in &agg.aggregate_functions { + record_column_source(item.index, item.scalar.clone()); + } + for item in &agg.group_items { + record_column_source(item.index, item.scalar.clone()); + } + let child = s_expr.unary_child(); + if let RelOperator::EvalScalar(eval) = child.plan() { + for item in &eval.items { + record_column_source(item.index, item.scalar.clone()); + } + s_expr = child.unary_child(); + continue; + } + } + RelOperator::Sort(sort) => { + sort_items = Some(sort.items.clone()); + } + RelOperator::Filter(filter) => { + predicates = Some(filter.predicates.as_ref()); + } + RelOperator::Scan(scan) => { + if let Some(prewhere) = &scan.prewhere { + debug_assert!(predicates.is_none()); + predicates = Some(prewhere.predicates.as_ref()); + } + break; + } + _ => { + return Err(ErrorCode::Internal("Unsupported plan")); + } + } + s_expr = s_expr.unary_child(); + } + + for base_column in base_columns { + let column_binding = ColumnBindingBuilder::new( + base_column.name(), + base_column.index(), + Box::new(base_column.data_type()), + Visibility::Visible, + ) + .table_name(Some(table_name.to_string())) + .table_index(Some(table_index)) + .build(); + + let column = ScalarExpr::BoundColumnRef(BoundColumnRef { + span: None, + column: column_binding, + }); + record_column_source(base_column.index(), column); + } + + let mut column_exprs = Vec::new(); + let column_matcher = ScalarExprMatcher::same(&column_map); + for (scalar, source) in column_expr_sources + .iter() + .filter_map(|source| column_map.get(&source.index).map(|scalar| (scalar, source))) + { + register_column_expr( + &mut column_exprs, + scalar.clone(), + source.index, + source.source, + &column_matcher, + ); + } + + let mut equi_classes = 
EquivalenceClasses::default(); + let mut range_classes = RangeClasses::default(); + let mut residual_classes = ResidualClasses::default(); + + if let Some(predicates) = predicates { + let mut preds_splitter = PredicatesSplitter::default(); + for pred in predicates { + preds_splitter.split(pred, &column_map); + } + + for equi_pred in &preds_splitter.equi_columns_preds { + equi_classes.add_equivalence_class( + &equi_pred.left, + &equi_pred.right, + &column_matcher, + ); + } + for range_pred in &preds_splitter.range_preds { + range_classes.add_range_class( + range_pred.op, + &range_pred.column, + &range_pred.value, + &column_matcher, + ); + } + for residual_pred in &preds_splitter.residual_preds { + residual_classes.add_residual_pred(residual_pred.clone(), &column_matcher); + } + } + + let output_cols = selection + .items + .iter() + .map(|item| { + let actual_scalar = actual_column_ref(&item.scalar, &column_map); + ScalarItem { + index: item.index, + scalar: actual_scalar.clone(), + } + }) + .collect(); + + Ok(Self { + equi_classes, + range_classes, + residual_classes, + aggregate, + sort_items, + column_map, + column_exprs, + output_cols, + }) + } + + fn rewrite_output_args( + &self, + args: &[ScalarExpr], + index_output_cols: &[IndexOutputColumn], + index_column_map: &HashMap, + new_selection_set: &mut HashSet, + ) -> Result>> { + let mut rewritten = Vec::with_capacity(args.len()); + let mut all_rewritten = true; + for arg in args { + let new_arg = self.check_output_cols( + arg, + index_output_cols, + index_column_map, + new_selection_set, + )?; + if let Some(new_arg) = new_arg { + rewritten.push(new_arg); + } else { + all_rewritten = false; + } + } + if all_rewritten { + Ok(Some(rewritten)) + } else { + Ok(None) + } + } + + pub(super) fn check_output_cols( + &self, + scalar: &ScalarExpr, + index_output_cols: &[IndexOutputColumn], + index_column_map: &HashMap, + new_selection_set: &mut HashSet, + ) -> Result> { + let output_matcher = 
ScalarExprMatcher::new(index_column_map, &self.column_map); + let query_matcher = ScalarExprMatcher::same(&self.column_map); + + if let Some((new_scalar, is_agg)) = + output_matcher.find_index_output_col(index_output_cols, scalar) + { + if let Some(index) = query_matcher.find_column_index(&self.column_exprs, scalar) { + let new_item = ScalarItem { + index, + scalar: new_scalar.clone(), + }; + new_selection_set.insert(new_item); + } + return if is_agg { + Ok(None) + } else { + Ok(Some(new_scalar.clone())) + }; + } + + let new_scalar = match scalar { + ScalarExpr::BoundColumnRef(_) => { + if let Some(actual_column) = query_matcher.resolve_right(scalar) { + return self.check_output_cols( + actual_column, + index_output_cols, + index_column_map, + new_selection_set, + ); + } + return Err(ErrorCode::Internal("Can't found column from index")); + } + ScalarExpr::ConstantExpr(_) => scalar.clone(), + ScalarExpr::FunctionCall(func) => { + let Some(new_args) = self.rewrite_output_args( + &func.arguments, + index_output_cols, + index_column_map, + new_selection_set, + )? + else { + return Ok(None); + }; + let mut new_func = func.clone(); + new_func.arguments = new_args; + ScalarExpr::FunctionCall(new_func) + } + ScalarExpr::CastExpr(cast) => { + if let Some(new_arg) = self.check_output_cols( + &cast.argument, + index_output_cols, + index_column_map, + new_selection_set, + )? 
{ + let mut new_cast = cast.clone(); + new_cast.argument = Box::new(new_arg); + ScalarExpr::CastExpr(new_cast) + } else { + return Ok(None); + } + } + ScalarExpr::AggregateFunction(func) => { + for expr in func.exprs() { + self.check_output_cols( + expr, + index_output_cols, + index_column_map, + new_selection_set, + )?; + } + return Ok(None); + } + ScalarExpr::UDAFCall(udaf) => { + for arg in &udaf.arguments { + self.check_output_cols( + arg, + index_output_cols, + index_column_map, + new_selection_set, + )?; + } + return Ok(None); + } + ScalarExpr::UDFCall(udf) => { + let Some(new_args) = self.rewrite_output_args( + &udf.arguments, + index_output_cols, + index_column_map, + new_selection_set, + )? + else { + return Ok(None); + }; + let mut new_udf = udf.clone(); + new_udf.arguments = new_args; + ScalarExpr::UDFCall(new_udf) + } + _ => unreachable!(), + }; + + if let Some(index) = query_matcher.find_column_index(&self.column_exprs, scalar) { + let new_item = ScalarItem { + index, + scalar: new_scalar.clone(), + }; + new_selection_set.insert(new_item); + return Ok(None); + } + + Ok(Some(new_scalar)) + } +} + +#[derive(Debug)] +pub struct AggIndexViewInfo { + pub(super) query_info: QueryInfo, + pub(super) index_fields: Vec, + pub(super) index_output_cols: Vec, +} + +impl AggIndexViewInfo { + pub(super) fn new( + table_index: IndexType, + table_name: &str, + base_columns: &[ColumnEntry], + s_expr: &SExpr, + ) -> Result { + let query_info = QueryInfo::new(table_index, table_name, base_columns, s_expr)?; + + let mut index_fields = Vec::with_capacity(query_info.output_cols.len()); + let mut index_output_cols = Vec::with_capacity(query_info.output_cols.len()); + let factory = AggregateFunctionFactory::instance(); + for (index, item) in query_info.output_cols.iter().enumerate() { + let aggr_scalar_item = query_info.aggregate.as_ref().and_then(|aggregate| { + aggregate + .aggregate_functions + .iter() + .find(|agg_func| agg_func.index == item.index) + }); + + let 
(data_type, is_agg) = match aggr_scalar_item { + Some(item) => { + let func = match &item.scalar { + ScalarExpr::AggregateFunction(func) => func, + _ => unreachable!(), + }; + let func = factory.get( + &func.func_name, + func.params.clone(), + func.args + .iter() + .map(|arg| arg.data_type()) + .collect::>()?, + func.sort_descs + .iter() + .map(|desc| desc.try_into()) + .collect::>()?, + )?; + (func.serialize_data_type(), true) + } + None => (item.scalar.data_type().unwrap(), false), + }; + + let name = index.to_string(); + let table_ty = infer_schema_type(&data_type)?; + let index_field = TableField::new(&name, table_ty); + index_fields.push(index_field); + + let index_scalar = to_index_scalar(index, &data_type); + index_output_cols.push(IndexOutputColumn { + expr: item.scalar.clone(), + index_scalar, + is_agg, + }); + } + + Ok(Self { + query_info, + index_fields, + index_output_cols, + }) + } +} + +#[derive(Default)] +struct PredicatesSplitter { + equi_columns_preds: Vec, + range_preds: Vec, + residual_preds: Vec, +} + +impl PredicatesSplitter { + fn split(&mut self, pred: &ScalarExpr, column_map: &HashMap) { + let ScalarExpr::FunctionCall(func) = pred else { + self.residual_preds.push(pred.clone()); + return; + }; + + match func.func_name.as_str() { + "and" | "and_filters" => { + for arg in &func.arguments { + self.split(arg, column_map); + } + } + "eq" if matches!(func.arguments[0], ScalarExpr::BoundColumnRef(_)) + && matches!(func.arguments[1], ScalarExpr::BoundColumnRef(_)) => + { + let arg0 = actual_column_ref(&func.arguments[0], column_map); + let arg1 = actual_column_ref(&func.arguments[1], column_map); + self.equi_columns_preds.push(EqualityPredicate { + left: arg0.clone(), + right: arg1.clone(), + }); + } + "eq" | "lt" | "lte" | "gt" | "gte" + if matches!(func.arguments[0], ScalarExpr::BoundColumnRef(_)) + && matches!(func.arguments[1], ScalarExpr::ConstantExpr(_)) => + { + let op = ComparisonOp::try_from_func_name(func.func_name.as_str()).unwrap(); + 
let column = actual_column_ref(&func.arguments[0], column_map).clone(); + let value = ConstantExpr::try_from(func.arguments[1].clone()).unwrap(); + self.range_preds.push(RangePredicate { op, column, value }); + } + "eq" | "lt" | "lte" | "gt" | "gte" + if matches!(func.arguments[0], ScalarExpr::ConstantExpr(_)) + && matches!(func.arguments[1], ScalarExpr::BoundColumnRef(_)) => + { + let op = ComparisonOp::try_from_func_name(func.func_name.as_str()) + .unwrap() + .reverse(); + let value = ConstantExpr::try_from(func.arguments[0].clone()).unwrap(); + let column = actual_column_ref(&func.arguments[1], column_map).clone(); + self.range_preds.push(RangePredicate { op, column, value }); + } + _ => { + self.residual_preds.push(pred.clone()); + } + } + } +} + +#[derive(Debug, Default)] +pub(super) struct EquivalenceClasses { + classes: Vec>, +} + +impl EquivalenceClasses { + fn add_equivalence_class( + &mut self, + col1: &ScalarExpr, + col2: &ScalarExpr, + matcher: &ScalarExprMatcher<'_, '_>, + ) { + let mut merged = vec![col1.clone(), col2.clone()]; + let mut idx = 0; + while idx < self.classes.len() { + if matcher.list_contains(&self.classes[idx], col1) + || matcher.list_contains(&self.classes[idx], col2) + { + let class = self.classes.remove(idx); + for scalar in class { + matcher.push_unique_scalar(&mut merged, scalar); + } + } else { + idx += 1; + } + } + self.classes.push(merged); + } + + pub(super) fn check( + &self, + view_equi_classes: &EquivalenceClasses, + query_column_map: &HashMap, + view_column_map: &HashMap, + ) -> bool { + let matcher = ScalarExprMatcher::new(query_column_map, view_column_map); + for view_class in &view_equi_classes.classes { + if self.classes.iter().any(|query_class| { + view_class + .iter() + .all(|view_scalar| matcher.list_contains(query_class, view_scalar)) + }) { + continue; + } + return false; + } + true + } +} + +#[derive(Eq, Clone, Debug)] +pub(super) enum BoundValue { + Closed(Scalar), + Open(Scalar), + NegativeInfinite, + 
PositiveInfinite, +} + +#[allow(clippy::non_canonical_partial_ord_impl)] +impl PartialOrd for BoundValue { + fn partial_cmp(&self, other: &Self) -> Option { + fn cmp_scalar(s1: &Scalar, s2: &Scalar) -> Ordering { + match (s1, s2) { + (Scalar::Number(n1), Scalar::Number(n2)) => { + if n1.is_integer() && n2.is_integer() { + let v1 = n1.integer_to_i128().unwrap(); + let v2 = n2.integer_to_i128().unwrap(); + v1.cmp(&v2) + } else { + let v1 = n1.to_f64(); + let v2 = n2.to_f64(); + v1.cmp(&v2) + } + } + (_, _) => s1.cmp(s2), + } + } + + match (self, other) { + (BoundValue::NegativeInfinite, BoundValue::NegativeInfinite) => Some(Ordering::Equal), + (BoundValue::PositiveInfinite, BoundValue::PositiveInfinite) => Some(Ordering::Equal), + (BoundValue::NegativeInfinite, _) => Some(Ordering::Less), + (_, BoundValue::NegativeInfinite) => Some(Ordering::Greater), + (BoundValue::PositiveInfinite, _) => Some(Ordering::Greater), + (_, BoundValue::PositiveInfinite) => Some(Ordering::Less), + (BoundValue::Open(v1), BoundValue::Open(v2)) => Some(cmp_scalar(v1, v2)), + (BoundValue::Closed(v1), BoundValue::Closed(v2)) => Some(cmp_scalar(v1, v2)), + (BoundValue::Open(v1), BoundValue::Closed(v2)) => { + let res = cmp_scalar(v1, v2); + if res == Ordering::Equal { + Some(Ordering::Less) + } else { + Some(res) + } + } + (BoundValue::Closed(v1), BoundValue::Open(v2)) => { + let res = cmp_scalar(v1, v2); + if res == Ordering::Equal { + Some(Ordering::Greater) + } else { + Some(res) + } + } + } + } +} + +impl Ord for BoundValue { + fn cmp(&self, other: &Self) -> Ordering { + self.partial_cmp(other).unwrap_or(Ordering::Equal) + } +} + +impl PartialEq for BoundValue { + fn eq(&self, other: &Self) -> bool { + self.partial_cmp(other) == Some(Ordering::Equal) + } +} + +impl BoundValue { + pub(super) fn comparison_bound( + &self, + closed_name: &'static str, + open_name: &'static str, + ) -> Option<(Scalar, &'static str)> { + match self { + BoundValue::Closed(val) => Some((val.clone(), closed_name)), 
+ BoundValue::Open(val) => Some((val.clone(), open_name)), + BoundValue::NegativeInfinite | BoundValue::PositiveInfinite => None, + } + } +} + +#[derive(Debug)] +struct RangeValues { + bounds: Option<(BoundValue, BoundValue)>, +} + +impl RangeValues { + fn new() -> Self { + Self { + bounds: Some((BoundValue::NegativeInfinite, BoundValue::PositiveInfinite)), + } + } + + fn insert(&mut self, lower_bound: BoundValue, upper_bound: BoundValue) { + if let Some((orig_lower_bound, orig_upper_bound)) = &self.bounds { + if upper_bound.cmp(orig_lower_bound) == Ordering::Less + || lower_bound.cmp(orig_upper_bound) == Ordering::Greater + { + self.bounds = None; + return; + } + match ( + lower_bound.cmp(orig_lower_bound), + upper_bound.cmp(orig_upper_bound), + ) { + (Ordering::Greater | Ordering::Equal, Ordering::Less | Ordering::Equal) => { + self.bounds = Some((lower_bound, upper_bound)) + } + (Ordering::Less, Ordering::Greater) => {} + (Ordering::Less, Ordering::Less | Ordering::Equal) => { + if matches!( + upper_bound.cmp(orig_lower_bound), + Ordering::Greater | Ordering::Equal + ) { + self.bounds = Some((orig_lower_bound.clone(), upper_bound)); + } + } + (Ordering::Greater | Ordering::Equal, Ordering::Greater) => { + if matches!( + lower_bound.cmp(orig_upper_bound), + Ordering::Less | Ordering::Equal + ) { + self.bounds = Some((lower_bound, orig_upper_bound.clone())); + } + } + } + } + } +} + +#[derive(Debug, Default)] +pub(super) struct RangeClasses { + column_to_range_class: Vec, +} + +impl RangeClasses { + fn add_range_class( + &mut self, + op: ComparisonOp, + col: &ScalarExpr, + val: &ConstantExpr, + matcher: &ScalarExprMatcher<'_, '_>, + ) { + let (lower_bound, upper_bound) = match op { + ComparisonOp::Equal => ( + BoundValue::Closed(val.value.clone()), + BoundValue::Closed(val.value.clone()), + ), + ComparisonOp::LT => ( + BoundValue::NegativeInfinite, + BoundValue::Open(val.value.clone()), + ), + ComparisonOp::LTE => ( + BoundValue::NegativeInfinite, + 
BoundValue::Closed(val.value.clone()), + ), + ComparisonOp::GT => ( + BoundValue::Open(val.value.clone()), + BoundValue::PositiveInfinite, + ), + ComparisonOp::GTE => ( + BoundValue::Closed(val.value.clone()), + BoundValue::PositiveInfinite, + ), + _ => unreachable!(), + }; + if let Some(range_entry) = self + .column_to_range_class + .iter_mut() + .find(|entry| matcher.expr_equal(&entry.column, col)) + { + range_entry.values.insert(lower_bound, upper_bound); + return; + } + let mut range_values = RangeValues::new(); + range_values.insert(lower_bound, upper_bound); + self.column_to_range_class.push(RangeClassEntry { + column: col.clone(), + values: range_values, + }); + } + + pub(super) fn check( + &self, + view_range_classes: &RangeClasses, + query_column_map: &HashMap, + view_column_map: &HashMap, + ) -> (bool, Option>) { + let mut extra_ranges = Vec::new(); + let matcher = ScalarExprMatcher::new(query_column_map, view_column_map); + let reverse_matcher = ScalarExprMatcher::new(view_column_map, query_column_map); + for range_entry in self.column_to_range_class.iter() { + if !view_range_classes + .column_to_range_class + .iter() + .any(|entry| reverse_matcher.expr_equal(&entry.column, &range_entry.column)) + { + if let Some((query_lower_bound, query_upper_bound)) = &range_entry.values.bounds { + extra_ranges.push(CompensatingRange { + column: range_entry.column.clone(), + lower_bound: query_lower_bound.clone(), + upper_bound: query_upper_bound.clone(), + }); + } + } + } + + for view_range_entry in view_range_classes.column_to_range_class.iter() { + if let Some(query_range_entry) = self + .column_to_range_class + .iter() + .find(|entry| matcher.expr_equal(&entry.column, &view_range_entry.column)) + { + match ( + &query_range_entry.values.bounds, + &view_range_entry.values.bounds, + ) { + ( + Some((query_lower_bound, query_upper_bound)), + Some((view_lower_bound, view_upper_bound)), + ) => { + let lower_res = view_lower_bound.cmp(query_lower_bound); + let upper_res = 
view_upper_bound.cmp(query_upper_bound); + + match (lower_res, upper_res) { + (Ordering::Equal, Ordering::Equal) => continue, + (Ordering::Equal, Ordering::Greater) => { + extra_ranges.push(CompensatingRange { + column: view_range_entry.column.clone(), + lower_bound: BoundValue::NegativeInfinite, + upper_bound: query_upper_bound.clone(), + }); + } + (Ordering::Less, Ordering::Equal) => { + extra_ranges.push(CompensatingRange { + column: view_range_entry.column.clone(), + lower_bound: query_lower_bound.clone(), + upper_bound: BoundValue::PositiveInfinite, + }); + } + (Ordering::Less, Ordering::Greater) => { + extra_ranges.push(CompensatingRange { + column: view_range_entry.column.clone(), + lower_bound: query_lower_bound.clone(), + upper_bound: query_upper_bound.clone(), + }); + } + (_, _) => return (false, None), + } + } + (Some((query_lower_bound, query_upper_bound)), None) => { + extra_ranges.push(CompensatingRange { + column: view_range_entry.column.clone(), + lower_bound: query_lower_bound.clone(), + upper_bound: query_upper_bound.clone(), + }); + } + (_, _) => return (false, None), + } + } else { + return (false, None); + } + } + if extra_ranges.is_empty() { + (true, None) + } else { + (true, Some(extra_ranges)) + } + } +} + +#[derive(Debug, Default)] +pub(super) struct ResidualClasses { + residual_preds: Vec, +} + +impl ResidualClasses { + fn add_residual_pred(&mut self, pred: ScalarExpr, matcher: &ScalarExprMatcher<'_, '_>) { + if !matcher.list_contains(&self.residual_preds, &pred) { + self.residual_preds.push(pred); + } + } + + pub(super) fn check( + &self, + view_residual_classes: &ResidualClasses, + query_column_map: &HashMap, + view_column_map: &HashMap, + ) -> (bool, Option>) { + let matcher = ScalarExprMatcher::new(query_column_map, view_column_map); + let reverse_matcher = ScalarExprMatcher::new(view_column_map, query_column_map); + let mut extra_residual_preds = Vec::new(); + for view_residual_pred in &view_residual_classes.residual_preds { + if 
!matcher.list_contains(&self.residual_preds, view_residual_pred) { + return (false, None); + } + } + for query_residual_pred in &self.residual_preds { + if !reverse_matcher + .list_contains(&view_residual_classes.residual_preds, query_residual_pred) + { + extra_residual_preds.push(query_residual_pred.clone()); + } + } + if extra_residual_preds.is_empty() { + (true, None) + } else { + (true, Some(extra_residual_preds)) + } + } +} + +pub(super) fn to_index_scalar(index: FieldIndex, data_type: &DataType) -> ScalarExpr { + let col = BoundColumnRef { + span: None, + column: ColumnBindingBuilder::new( + format!("index_col_{index}"), + Symbol::from_field_index(index), + Box::new(data_type.clone()), + Visibility::Visible, + ) + .build(), + }; + ScalarExpr::BoundColumnRef(col) +} + +fn actual_column_ref<'a>( + col: &'a ScalarExpr, + column_map: &'a HashMap, +) -> &'a ScalarExpr { + if let ScalarExpr::BoundColumnRef(col) = col { + if let Some(arg) = column_map.get(&col.column.index) { + return arg; + } + } + col +} + +fn register_column_expr( + column_exprs: &mut Vec, + expr: ScalarExpr, + index: Symbol, + source: ColumnExprSource, + matcher: &ScalarExprMatcher<'_, '_>, +) { + if let Some(existing_entry) = column_exprs + .iter_mut() + .find(|entry| matcher.expr_equal(&entry.expr, &expr)) + { + if source.prefer_over(existing_entry.source, index, existing_entry.index) { + existing_entry.index = index; + existing_entry.source = source; + } + return; + } + + column_exprs.push(ColumnExprEntry { + expr, + index, + source, + }); +} + +pub(super) struct ScalarExprMatcher<'a, 'b> { + left_column_map: &'a HashMap, + right_column_map: &'b HashMap, +} + +impl<'a, 'b> ScalarExprMatcher<'a, 'b> { + pub(super) fn new( + left_column_map: &'a HashMap, + right_column_map: &'b HashMap, + ) -> Self { + Self { + left_column_map, + right_column_map, + } + } + + pub(super) fn same(column_map: &'a HashMap) -> ScalarExprMatcher<'a, 'a> { + ScalarExprMatcher::new(column_map, column_map) + } + + fn 
resolve_left(&self, scalar: &ScalarExpr) -> Option<&'a ScalarExpr> { + Self::resolve_scalar(scalar, self.left_column_map) + } + + pub(super) fn resolve_right(&self, scalar: &ScalarExpr) -> Option<&'b ScalarExpr> { + Self::resolve_scalar(scalar, self.right_column_map) + } + + fn resolve_scalar<'map>( + scalar: &ScalarExpr, + column_map: &'map HashMap, + ) -> Option<&'map ScalarExpr> { + let ScalarExpr::BoundColumnRef(col) = scalar else { + return None; + }; + + let mapped = column_map.get(&col.column.index)?; + match mapped { + ScalarExpr::BoundColumnRef(mapped_col) + if mapped_col.column.index == col.column.index + && mapped_col.column.table_index == col.column.table_index => + { + None + } + _ => Some(mapped), + } + } + + fn expr_equal(&self, left: &ScalarExpr, right: &ScalarExpr) -> bool { + if let Some(mapped_left) = self.resolve_left(left) { + return self.expr_equal(mapped_left, right); + } + if let Some(mapped_right) = self.resolve_right(right) { + return self.expr_equal(left, mapped_right); + } + + match (left, right) { + (ScalarExpr::BoundColumnRef(l), ScalarExpr::BoundColumnRef(r)) => { + l.column.table_index == r.column.table_index + && l.column.column_name == r.column.column_name + } + (ScalarExpr::ConstantExpr(l), ScalarExpr::ConstantExpr(r)) + | (ScalarExpr::ConstantExpr(l), ScalarExpr::TypedConstantExpr(r, _)) + | (ScalarExpr::TypedConstantExpr(l, _), ScalarExpr::ConstantExpr(r)) + | (ScalarExpr::TypedConstantExpr(l, _), ScalarExpr::TypedConstantExpr(r, _)) => { + l.value == r.value + } + (ScalarExpr::FunctionCall(l), ScalarExpr::FunctionCall(r)) => { + l.func_name == r.func_name + && l.params == r.params + && l.arguments.len() == r.arguments.len() + && l.arguments + .iter() + .zip(r.arguments.iter()) + .all(|(l, r)| self.expr_equal(l, r)) + } + (ScalarExpr::CastExpr(l), ScalarExpr::CastExpr(r)) => { + l.is_try == r.is_try + && l.target_type == r.target_type + && self.expr_equal(&l.argument, &r.argument) + } + (ScalarExpr::AggregateFunction(l), 
ScalarExpr::AggregateFunction(r)) => { + l.func_name == r.func_name + && l.distinct == r.distinct + && l.params == r.params + && l.args.len() == r.args.len() + && l.sort_descs.len() == r.sort_descs.len() + && l.args + .iter() + .zip(r.args.iter()) + .all(|(l, r)| self.expr_equal(l, r)) + && l.sort_descs.iter().zip(r.sort_descs.iter()).all(|(l, r)| { + l.nulls_first == r.nulls_first + && l.asc == r.asc + && self.expr_equal(&l.expr, &r.expr) + }) + } + (ScalarExpr::UDAFCall(l), ScalarExpr::UDAFCall(r)) => { + l.name == r.name + && l.arguments.len() == r.arguments.len() + && l.arguments + .iter() + .zip(r.arguments.iter()) + .all(|(l, r)| self.expr_equal(l, r)) + } + (ScalarExpr::UDFCall(l), ScalarExpr::UDFCall(r)) => { + l.handler == r.handler + && l.arguments.len() == r.arguments.len() + && l.arguments + .iter() + .zip(r.arguments.iter()) + .all(|(l, r)| self.expr_equal(l, r)) + } + _ => false, + } + } + + pub(super) fn list_contains(&self, list: &[ScalarExpr], target: &ScalarExpr) -> bool { + list.iter().any(|expr| self.expr_equal(expr, target)) + } + + pub(super) fn push_unique_scalar(&self, list: &mut Vec, scalar: ScalarExpr) { + if !self.list_contains(list, &scalar) { + list.push(scalar); + } + } + + pub(super) fn find_column_index( + &self, + column_exprs: &[ColumnExprEntry], + target: &ScalarExpr, + ) -> Option { + column_exprs.iter().find_map(|entry| { + if self.expr_equal(&entry.expr, target) { + Some(entry.index) + } else { + None + } + }) + } + + pub(super) fn find_index_output_col( + &self, + index_output_cols: &[IndexOutputColumn], + target: &ScalarExpr, + ) -> Option<(ScalarExpr, bool)> { + index_output_cols.iter().find_map(|entry| { + if self.expr_equal(&entry.expr, target) { + Some((entry.index_scalar.clone(), entry.is_agg)) + } else { + None + } + }) + } +} diff --git a/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/agg_index/query_rewrite.rs b/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/agg_index/query_rewrite.rs 
index 722cb24d5bcdc..408136ad2f4a8 100644 --- a/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/agg_index/query_rewrite.rs +++ b/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/agg_index/query_rewrite.rs @@ -12,61 +12,45 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::cmp::Ordering; -use std::collections::BTreeMap; -use std::collections::HashMap; -use std::collections::HashSet; use std::sync::Arc; -use databend_common_exception::ErrorCode; +use databend_common_catalog::table_context::TableContext; use databend_common_exception::Result; -use databend_common_expression::FieldIndex; -use databend_common_expression::Scalar; -use databend_common_expression::TableField; -use databend_common_expression::TableSchemaRefExt; -use databend_common_expression::infer_schema_type; -use databend_common_expression::types::DataType; -use databend_common_functions::aggregates::AggregateFunctionFactory; -use itertools::Itertools; -use log::info; -use crate::ColumnBinding; +use super::prepare::AggIndexViewInfo; +use super::prepare::PreparedAggIndexQuery; +use crate::AggIndexPlan; use crate::ColumnEntry; use crate::IndexType; -use crate::ScalarExpr; -use crate::Symbol; -use crate::Visibility; -use crate::binder::ColumnBindingBuilder; +use crate::MetadataRef; +use crate::optimizer::OptimizerContext; use crate::optimizer::ir::SExpr; -use crate::plans::AggIndexInfo; -use crate::plans::Aggregate; -use crate::plans::AggregateFunctionScalarSortDesc; -use crate::plans::BoundColumnRef; -use crate::plans::ConstantExpr; -use crate::plans::FunctionCall; -use crate::plans::RelOperator; -use crate::plans::ScalarItem; -use crate::plans::SortItem; +use crate::optimizer::optimizers::recursive::RecursiveRuleOptimizer; +use crate::optimizer::optimizers::rule::RuleID; +use crate::planner::QueryExecutor; pub fn try_rewrite( table_index: IndexType, table_name: &str, base_columns: &[ColumnEntry], s_expr: &SExpr, 
- index_plans: &[(u64, String, SExpr)], + index_plans: &[AggIndexPlan], ) -> Result> { if index_plans.is_empty() { return Ok(None); } - let query_info = QueryInfo::new(table_index, table_name, base_columns, s_expr)?; - let agg_index_rewriter = AggIndexRewriter::new(query_info); + let prepared_query = + PreparedAggIndexQuery::prepare(table_index, table_name, base_columns, s_expr)?; + let matcher = prepared_query.matcher(); - for (index_id, sql, view_s_expr) in index_plans.iter() { - let view_info = ViewInfo::new(table_index, table_name, base_columns, view_s_expr)?; - if let Some(result) = - agg_index_rewriter.try_rewrite_index(s_expr, *index_id, sql, &view_info)? - { + for index_plan in index_plans { + if let Some(result) = matcher.try_rewrite_index( + s_expr, + index_plan.index_id, + &index_plan.sql, + index_plan.prepared.as_ref(), + )? { return Ok(Some(result)); } } @@ -74,1314 +58,50 @@ pub fn try_rewrite( Ok(None) } -struct QueryInfo { - equi_classes: EquivalenceClasses, - range_classes: RangeClasses, - residual_classes: ResidualClasses, - sort_items: Option>, - aggregate: Option, - column_map: HashMap, - column_display_map: HashMap, - output_cols: Vec, -} - -impl QueryInfo { - fn new( - table_index: IndexType, - table_name: &str, - base_columns: &[ColumnEntry], - s_expr: &SExpr, - ) -> Result { - if let RelOperator::EvalScalar(eval) = s_expr.plan() { - let selection = eval; - - let mut predicates: std::option::Option<&[ScalarExpr]> = None; - let mut sort_items = None; - let mut aggregate = None; - let mut column_map = HashMap::new(); - - for item in &eval.items { - column_map.insert(item.index, item.scalar.clone()); - } - - // collect query info from the plan - let mut s_expr = s_expr.child(0)?; - loop { - match s_expr.plan() { - RelOperator::EvalScalar(eval) => { - for item in &eval.items { - column_map.insert(item.index, item.scalar.clone()); - } - } - RelOperator::Aggregate(agg) => { - if agg.grouping_sets.is_some() { - return 
Err(ErrorCode::Internal("Grouping sets is not supported")); - } - - aggregate = Some(agg.clone()); - for item in &agg.aggregate_functions { - column_map.insert(item.index, item.scalar.clone()); - } - for item in &agg.group_items { - column_map.insert(item.index, item.scalar.clone()); - } - let child = s_expr.child(0)?; - if let RelOperator::EvalScalar(eval) = child.plan() { - for item in &eval.items { - column_map.insert(item.index, item.scalar.clone()); - } - s_expr = child.child(0)?; - continue; - } - } - RelOperator::Sort(sort) => { - sort_items = Some(sort.items.clone()); - } - RelOperator::Filter(filter) => { - predicates = Some(filter.predicates.as_ref()); - } - RelOperator::Scan(scan) => { - if let Some(prewhere) = &scan.prewhere { - debug_assert!(predicates.is_none()); - predicates = Some(prewhere.predicates.as_ref()); - } - // Finish the recursion. - break; - } - _ => { - return Err(ErrorCode::Internal("Unsupported plan")); - } - } - s_expr = s_expr.child(0)?; - } - - for base_column in base_columns { - let column_binding = ColumnBindingBuilder::new( - base_column.name(), - base_column.index(), - Box::new(base_column.data_type()), - Visibility::Visible, - ) - .table_name(Some(table_name.to_string())) - .table_index(Some(table_index)) - .build(); - - let column = ScalarExpr::BoundColumnRef(BoundColumnRef { - span: None, - column: column_binding, - }); - column_map.insert(base_column.index(), column); - } - - let mut column_display_map = HashMap::new(); - for (index, scalar) in column_map.iter() { - let display_name = format_scalar(scalar, &column_map); - if let Some(old_index) = column_display_map.get(&display_name) { - // use index from low level plan first. 
- if old_index < index { - continue; - } - } - column_display_map.insert(display_name, *index); - } - - let mut preds_splitter = PredicatesSplitter::new(); - let mut equi_classes = EquivalenceClasses::new(); - let mut range_classes = RangeClasses::new(); - let mut residual_classes = ResidualClasses::new(); - - // split predicates as equal predicate, range predicate, and residual predicate. - if let Some(predicates) = predicates { - for pred in predicates { - preds_splitter.split(pred, &column_map); - } - - for equi_pred in &preds_splitter.equi_columns_preds { - equi_classes.add_equivalence_class(&equi_pred.0, &equi_pred.1); - } - for range_pred in &preds_splitter.range_preds { - range_classes.add_range_class(&range_pred.0, &range_pred.1, &range_pred.2); - } - for residual_pred in &preds_splitter.residual_preds { - let display_name = format_scalar(residual_pred, &column_map); - residual_classes.add_residual_pred(display_name, residual_pred); - } - } - - let mut output_cols = Vec::with_capacity(selection.items.len()); - for item in &selection.items { - let actual_scalar = actual_column_ref(&item.scalar, &column_map); - let actual_item = ScalarItem { - index: item.index, - scalar: actual_scalar.clone(), - }; - output_cols.push(actual_item); - } - - Ok(Self { - equi_classes, - range_classes, - residual_classes, - aggregate, - sort_items, - column_map, - column_display_map, - output_cols, - }) - } else { - Err(ErrorCode::Internal("Unsupported plan")) - } - } - - // check whether the scalar can be computed from index output columns. - // if not, the aggregating index can't be used. 
- fn check_output_cols( - &self, - scalar: &ScalarExpr, - index_output_cols: &HashMap, - new_selection_set: &mut HashSet, - ) -> Result> { - let display_name = format_scalar(scalar, &self.column_map); - - if let Some((new_scalar, is_agg)) = index_output_cols.get(&display_name) { - if let Some(index) = self.column_display_map.get(&display_name) { - let new_item = ScalarItem { - index: *index, - scalar: new_scalar.clone(), - }; - new_selection_set.insert(new_item); - } - // agg function can't used - if *is_agg { - return Ok(None); - } else { - return Ok(Some(new_scalar.clone())); - } - } - - let new_scalar = match scalar { - ScalarExpr::BoundColumnRef(_) => { - let actual_column = actual_column_ref(scalar, &self.column_map); - if !matches!(actual_column, ScalarExpr::BoundColumnRef(_)) { - return self.check_output_cols( - actual_column, - index_output_cols, - new_selection_set, - ); - } - return Err(ErrorCode::Internal("Can't found column from index")); - } - ScalarExpr::ConstantExpr(_) => scalar.clone(), - ScalarExpr::FunctionCall(func) => { - let mut valid = true; - let mut new_args = Vec::with_capacity(func.arguments.len()); - for arg in &func.arguments { - if let Some(new_arg) = - self.check_output_cols(arg, index_output_cols, new_selection_set)? - { - new_args.push(new_arg); - } else { - valid = false; - } - } - if !valid { - return Ok(None); - } - let mut new_func = func.clone(); - new_func.arguments = new_args; - ScalarExpr::FunctionCall(new_func) - } - ScalarExpr::CastExpr(cast) => { - if let Some(new_arg) = - self.check_output_cols(&cast.argument, index_output_cols, new_selection_set)? 
- { - let mut new_cast = cast.clone(); - new_cast.argument = Box::new(new_arg); - ScalarExpr::CastExpr(new_cast) - } else { - return Ok(None); - } - } - ScalarExpr::AggregateFunction(func) => { - // agg function can't push down - for expr in func.exprs() { - self.check_output_cols(expr, index_output_cols, new_selection_set)?; - } - return Ok(None); - } - ScalarExpr::UDAFCall(udaf) => { - for arg in &udaf.arguments { - self.check_output_cols(arg, index_output_cols, new_selection_set)?; - } - return Ok(None); - } - ScalarExpr::UDFCall(udf) => { - let mut valid = true; - let mut new_args = Vec::with_capacity(udf.arguments.len()); - for arg in &udf.arguments { - if let Some(new_arg) = - self.check_output_cols(arg, index_output_cols, new_selection_set)? - { - new_args.push(new_arg); - } else { - valid = false; - } - } - if !valid { - return Ok(None); - } - let mut new_udf = udf.clone(); - new_udf.arguments = new_args; - ScalarExpr::UDFCall(new_udf) - } - _ => unreachable!(), // Window function and subquery will not appear in index. - }; - - if let Some(index) = self.column_display_map.get(&display_name) { - let new_item = ScalarItem { - index: *index, - scalar: new_scalar.clone(), - }; - new_selection_set.insert(new_item); - return Ok(None); - } - - Ok(Some(new_scalar)) - } -} - -// Record information of aggregating index plan. -struct ViewInfo { - query_info: QueryInfo, - index_fields: Vec, - index_output_cols: HashMap, -} - -impl ViewInfo { - fn new( - table_index: IndexType, - table_name: &str, - base_columns: &[ColumnEntry], - s_expr: &SExpr, - ) -> Result { - let query_info = QueryInfo::new(table_index, table_name, base_columns, s_expr)?; - - // collect the output columns of aggregating index, - // query can use those columns to compute expressions. 
- let mut index_fields = Vec::with_capacity(query_info.output_cols.len()); - let mut index_output_cols = HashMap::with_capacity(query_info.output_cols.len()); - let factory = AggregateFunctionFactory::instance(); - for (index, item) in query_info.output_cols.iter().enumerate() { - let display_name = format_scalar(&item.scalar, &query_info.column_map); - - let aggr_scalar_item = query_info.aggregate.as_ref().and_then(|aggregate| { - aggregate - .aggregate_functions - .iter() - .find(|agg_func| agg_func.index == item.index) - }); - - let (data_type, is_agg) = match aggr_scalar_item { - Some(item) => { - let func = match &item.scalar { - ScalarExpr::AggregateFunction(func) => func, - _ => unreachable!(), - }; - let func = factory.get( - &func.func_name, - func.params.clone(), - func.args - .iter() - .map(|arg| arg.data_type()) - .collect::>()?, - func.sort_descs - .iter() - .map(|desc| desc.try_into()) - .collect::>()?, - )?; - (func.serialize_data_type(), true) - } - None => (item.scalar.data_type().unwrap(), false), - }; - - let name = index.to_string(); - let table_ty = infer_schema_type(&data_type)?; - let index_field = TableField::new(&name, table_ty); - index_fields.push(index_field); - - let index_scalar = to_index_scalar(index, &data_type); - index_output_cols.insert(display_name, (index_scalar, is_agg)); - } - - Ok(Self { - query_info, - index_fields, - index_output_cols, - }) - } -} - -struct PredicatesSplitter { - equi_columns_preds: Vec<(BoundColumnRef, BoundColumnRef)>, - range_preds: Vec<(String, BoundColumnRef, ConstantExpr)>, - residual_preds: Vec, -} - -impl PredicatesSplitter { - fn new() -> Self { - Self { - equi_columns_preds: vec![], - range_preds: vec![], - residual_preds: vec![], - } - } - - fn split(&mut self, pred: &ScalarExpr, column_map: &HashMap) { - if let ScalarExpr::FunctionCall(func) = pred { - match func.func_name.as_str() { - "and" | "and_filters" => { - for arg in &func.arguments { - self.split(arg, column_map); - } - } - "eq" if 
matches!(func.arguments[0], ScalarExpr::BoundColumnRef(_)) - && matches!(func.arguments[1], ScalarExpr::BoundColumnRef(_)) => - { - let arg0 = actual_column_ref(&func.arguments[0], column_map); - let arg1 = actual_column_ref(&func.arguments[1], column_map); - - let col0 = BoundColumnRef::try_from(arg0.clone()).unwrap(); - let col1 = BoundColumnRef::try_from(arg1.clone()).unwrap(); - self.equi_columns_preds.push((col0, col1)); - } - "eq" | "lt" | "lte" | "gt" | "gte" - if matches!(func.arguments[0], ScalarExpr::BoundColumnRef(_)) - && matches!(func.arguments[1], ScalarExpr::ConstantExpr(_)) => - { - let func_name = func.func_name.clone(); - let arg0 = actual_column_ref(&func.arguments[0], column_map); - let col = BoundColumnRef::try_from(arg0.clone()).unwrap(); - let val = ConstantExpr::try_from(func.arguments[1].clone()).unwrap(); - self.range_preds.push((func_name, col, val)); - } - "eq" | "lt" | "lte" | "gt" | "gte" - if matches!(func.arguments[0], ScalarExpr::ConstantExpr(_)) - && matches!(func.arguments[1], ScalarExpr::BoundColumnRef(_)) => - { - let func_name = reverse_op(func.func_name.as_str()); - - let val = ConstantExpr::try_from(func.arguments[0].clone()).unwrap(); - let arg1 = actual_column_ref(&func.arguments[1], column_map); - let col = BoundColumnRef::try_from(arg1.clone()).unwrap(); - self.range_preds.push((func_name, col, val)); - } - _ => { - self.residual_preds.push(pred.clone()); - } - } - } else { - self.residual_preds.push(pred.clone()); - } - } -} - -struct EquivalenceClasses { - column_to_equivalence_class: HashMap>, -} - -impl EquivalenceClasses { - fn new() -> Self { - Self { - column_to_equivalence_class: HashMap::new(), - } - } - - fn add_equivalence_class(&mut self, col1: &BoundColumnRef, col2: &BoundColumnRef) { - let mut equivalence_columns = HashSet::new(); - - let col1_name = format_col(&col1.column); - let col2_name = format_col(&col2.column); - - equivalence_columns.insert(col1_name.clone()); - 
equivalence_columns.insert(col2_name.clone()); - - if let Some(c1) = self.column_to_equivalence_class.get(&col1_name) { - for c in c1 { - equivalence_columns.insert(c.clone()); - } - } - if let Some(c2) = self.column_to_equivalence_class.get(&col2_name) { - for c in c2 { - equivalence_columns.insert(c.clone()); - } - } - - for column in &equivalence_columns { - if let Some(orig_columns) = self.column_to_equivalence_class.get_mut(column) { - for equi_column in &equivalence_columns { - if equi_column == column { - continue; - } - orig_columns.insert(equi_column.clone()); - } - } else { - let mut equi_cols = equivalence_columns.clone(); - equi_cols.remove(column); - self.column_to_equivalence_class - .insert(column.clone(), equi_cols); - } - } - } - - // Equijoin subsumption test. - fn check(&self, view_equi_classes: &EquivalenceClasses) -> bool { - for (col, view_equi_cols) in view_equi_classes.column_to_equivalence_class.iter() { - if let Some(query_equi_cols) = self.column_to_equivalence_class.get(col) { - // checking whether every non-trivial view equivalence class - // is a subset of some query equivalence class - if view_equi_cols.is_subset(query_equi_cols) { - continue; - } - } - return false; - } - true - } -} - -#[derive(Eq, Clone, Debug)] -enum BoundValue { - // column >= scalar value or column <= scalar value - Closed(Scalar), - // column > scalar value or column < scalar value - Open(Scalar), - // -∞ - NegativeInfinite, - // +∞ - PositiveInfinite, -} - -#[allow(clippy::non_canonical_partial_ord_impl)] -impl PartialOrd for BoundValue { - fn partial_cmp(&self, other: &Self) -> Option { - fn cmp_scalar(s1: &Scalar, s2: &Scalar) -> Ordering { - match (s1, s2) { - (Scalar::Number(n1), Scalar::Number(n2)) => { - if n1.is_integer() && n2.is_integer() { - let v1 = n1.integer_to_i128().unwrap(); - let v2 = n2.integer_to_i128().unwrap(); - v1.cmp(&v2) - } else { - let v1 = n1.to_f64(); - let v2 = n2.to_f64(); - v1.cmp(&v2) - } - } - (_, _) => s1.cmp(s2), - } - } - - 
match (self, other) { - (BoundValue::NegativeInfinite, BoundValue::NegativeInfinite) => Some(Ordering::Equal), - (BoundValue::PositiveInfinite, BoundValue::PositiveInfinite) => Some(Ordering::Equal), - (BoundValue::NegativeInfinite, _) => Some(Ordering::Less), - (_, BoundValue::NegativeInfinite) => Some(Ordering::Greater), - (BoundValue::PositiveInfinite, _) => Some(Ordering::Greater), - (_, BoundValue::PositiveInfinite) => Some(Ordering::Less), - (BoundValue::Open(v1), BoundValue::Open(v2)) => Some(cmp_scalar(v1, v2)), - (BoundValue::Closed(v1), BoundValue::Closed(v2)) => Some(cmp_scalar(v1, v2)), - (BoundValue::Open(v1), BoundValue::Closed(v2)) => { - let res = cmp_scalar(v1, v2); - if res == Ordering::Equal { - Some(Ordering::Less) - } else { - Some(res) - } - } - (BoundValue::Closed(v1), BoundValue::Open(v2)) => { - let res = cmp_scalar(v1, v2); - if res == Ordering::Equal { - Some(Ordering::Greater) - } else { - Some(res) - } - } - } - } -} - -impl Ord for BoundValue { - fn cmp(&self, other: &Self) -> Ordering { - self.partial_cmp(other).unwrap_or(Ordering::Equal) - } -} - -impl PartialEq for BoundValue { - fn eq(&self, other: &Self) -> bool { - self.partial_cmp(other) == Some(Ordering::Equal) - } -} - -#[derive(Debug)] -struct RangeValues { - bounds: Option<(BoundValue, BoundValue)>, -} - -// if a column have more than one predicates, we can merge them together to simplify the process. 
-// -// +------+ -// | orig | -// +------+ -// +-----+ -// case 1 | new | -// +-----+ -// upper < orig_lower -// +-----+ -// case 2 | new | -// +-----+ -// lower > orig_upper -// +-----+ -// case 3 | new | -// +-----+ -// lower >= orig_lower && upper <= orig_upper -// +-----------+ -// case 4 | new | -// +-----------+ -// lower < orig_lower upper > orig_upper -// +-----+ -// case 5 | new | -// +-----+ -// upper >= orig_lower && upper <= orig_upper && lower <= orig_lower -// +-----+ -// case 6 | new | -// +-----+ -// lower >= orig_lower && lower <= orig_upper && upper >= orig_upper -impl RangeValues { - fn new() -> Self { - Self { - bounds: Some((BoundValue::NegativeInfinite, BoundValue::PositiveInfinite)), - } - } - - fn insert(&mut self, lower_bound: BoundValue, upper_bound: BoundValue) { - if let Some((orig_lower_bound, orig_upper_bound)) = &self.bounds { - // case 1 and case 2 - if upper_bound.cmp(orig_lower_bound) == Ordering::Less - || lower_bound.cmp(orig_upper_bound) == Ordering::Greater - { - self.bounds = None; - return; - } - match ( - lower_bound.cmp(orig_lower_bound), - upper_bound.cmp(orig_upper_bound), - ) { - // case 3 - (Ordering::Greater | Ordering::Equal, Ordering::Less | Ordering::Equal) => { - self.bounds = Some((lower_bound, upper_bound)) - } - // case 4 - (Ordering::Less, Ordering::Greater) => {} - (Ordering::Less, Ordering::Less | Ordering::Equal) => { - // case 5 - if matches!( - upper_bound.cmp(orig_lower_bound), - Ordering::Greater | Ordering::Equal - ) { - self.bounds = Some((orig_lower_bound.clone(), upper_bound)); - } - } - (Ordering::Greater | Ordering::Equal, Ordering::Greater) => { - // case 6 - if matches!( - lower_bound.cmp(orig_upper_bound), - Ordering::Less | Ordering::Equal - ) { - self.bounds = Some((lower_bound, orig_upper_bound.clone())); - } - } - } - } - } -} - -struct RangeClasses { - column_to_range_class: BTreeMap, -} - -impl RangeClasses { - fn new() -> Self { - Self { - column_to_range_class: BTreeMap::new(), - } - } - 
- fn add_range_class(&mut self, func_name: &str, col: &BoundColumnRef, val: &ConstantExpr) { - let col_name = format_col(&col.column); - - let (lower_bound, upper_bound) = match func_name { - "eq" => ( - BoundValue::Closed(val.value.clone()), - BoundValue::Closed(val.value.clone()), - ), - "lt" => ( - BoundValue::NegativeInfinite, - BoundValue::Open(val.value.clone()), - ), - "lte" => ( - BoundValue::NegativeInfinite, - BoundValue::Closed(val.value.clone()), - ), - "gt" => ( - BoundValue::Open(val.value.clone()), - BoundValue::PositiveInfinite, - ), - "gte" => ( - BoundValue::Closed(val.value.clone()), - BoundValue::PositiveInfinite, - ), - _ => unreachable!(), - }; - if !self.column_to_range_class.contains_key(&col_name) { - self.column_to_range_class - .insert(col_name.clone(), RangeValues::new()); - } - if let Some(range_values) = self.column_to_range_class.get_mut(&col_name) { - range_values.insert(lower_bound, upper_bound); - } - } - - // Range subsumption test. - #[allow(clippy::type_complexity)] - fn check( - &self, - view_range_classes: &RangeClasses, - ) -> (bool, Option>) { - // if the range predicate in aggregating index and the query have three cases. - // 1. the range of aggregating index and query are same, don't need extra filter ranges. - // 2. the range of aggregating index filter less values than the query, - // we can add extra range predicate to implement the filter. - // for example: aggregating index: a > 10 and query: a > 15 - // we can add extra range as a > 15 - // 3. the range of aggregating index filter more values than the query, - // this aggregating index don't match the query. 
- // for example: aggregating index: a > 10 and query: a > 5 - let mut extra_ranges = BTreeMap::new(); - for (col, query_range_values) in self.column_to_range_class.iter() { - if !view_range_classes.column_to_range_class.contains_key(col) { - if let Some((query_lower_bound, query_upper_bound)) = &query_range_values.bounds { - extra_ranges.insert( - col.clone(), - (query_lower_bound.clone(), query_upper_bound.clone()), - ); - } - } - } - - for (col, view_range_values) in view_range_classes.column_to_range_class.iter() { - if let Some(query_range_values) = self.column_to_range_class.get(col) { - match (&query_range_values.bounds, &view_range_values.bounds) { - ( - Some((query_lower_bound, query_upper_bound)), - Some((view_lower_bound, view_upper_bound)), - ) => { - let lower_res = view_lower_bound.cmp(query_lower_bound); - let upper_res = view_upper_bound.cmp(query_upper_bound); - - match (lower_res, upper_res) { - (Ordering::Equal, Ordering::Equal) => { - continue; - } - (Ordering::Equal, Ordering::Greater) => { - extra_ranges.insert( - col.clone(), - (BoundValue::NegativeInfinite, query_upper_bound.clone()), - ); - } - (Ordering::Less, Ordering::Equal) => { - extra_ranges.insert( - col.clone(), - (query_lower_bound.clone(), BoundValue::PositiveInfinite), - ); - } - (Ordering::Less, Ordering::Greater) => { - extra_ranges.insert( - col.clone(), - (query_lower_bound.clone(), query_upper_bound.clone()), - ); - } - (_, _) => { - return (false, None); - } - } - } - (Some((query_lower_bound, query_upper_bound)), None) => { - extra_ranges.insert( - col.clone(), - (query_lower_bound.clone(), query_upper_bound.clone()), - ); - } - (_, _) => { - return (false, None); - } - } - } else { - return (false, None); - } - } - if extra_ranges.is_empty() { - (true, None) - } else { - (true, Some(extra_ranges)) - } - } -} - -#[derive(Debug)] -struct ResidualClasses { - residual_preds: BTreeMap, -} - -impl ResidualClasses { - fn new() -> Self { - Self { - residual_preds: 
BTreeMap::new(), - } - } - - fn add_residual_pred(&mut self, pred_display: String, pred: &ScalarExpr) { - self.residual_preds.insert(pred_display, pred.clone()); - } - - // Residual subsumption test. - fn check(&self, view_residual_classes: &ResidualClasses) -> (bool, Option>) { - let mut extra_residual_preds = Vec::new(); - for (view_residual_key, _) in view_residual_classes.residual_preds.iter() { - if !self.residual_preds.contains_key(view_residual_key) { - return (false, None); - } - } - // TODO: continue split residual predicates and check - for (query_residual_key, query_residual_pred) in self.residual_preds.iter() { - if !view_residual_classes - .residual_preds - .contains_key(query_residual_key) - { - extra_residual_preds.push(query_residual_pred.clone()); - } - } - if extra_residual_preds.is_empty() { - (true, None) - } else { - (true, Some(extra_residual_preds)) - } - } -} - -// Aggregating index rewriting logic is based on "Optimizing Queries Using Materialized Views: -// A Practical, Scalable Solution" by Goldstein and Larson." 
-struct AggIndexRewriter { - query_info: QueryInfo, -} - -impl AggIndexRewriter { - fn new(query_info: QueryInfo) -> Self { - Self { query_info } - } - - fn try_rewrite_index( - &self, - s_expr: &SExpr, - index_id: u64, - sql: &str, - view_info: &ViewInfo, - ) -> Result> { - let mut new_predicates = Vec::new(); - let mut new_selection_set = HashSet::new(); - - if !self.check_predicates(view_info, &mut new_predicates, &mut new_selection_set) { - return Ok(None); - } - - if !self.check_output_expressions(view_info, &mut new_selection_set) { - return Ok(None); - } - - if !self.check_aggregation(view_info, &mut new_selection_set) { - return Ok(None); - } - - if !self.check_sort_items(view_info, &mut new_selection_set) { - return Ok(None); - } - - let mut new_selection: Vec<_> = new_selection_set.into_iter().collect(); - new_selection.sort_by_key(|i| i.index); - - let is_agg = self.query_info.aggregate.is_some(); - let num_agg_funcs = self - .query_info - .aggregate - .as_ref() - .map(|agg| agg.aggregate_functions.len()) - .unwrap_or_default(); - - let result = push_down_index_scan(s_expr, AggIndexInfo { - index_id, - selection: new_selection, - predicates: new_predicates, - schema: TableSchemaRefExt::create(view_info.index_fields.clone()), - is_agg, - num_agg_funcs, - })?; - - info!("Use aggregating index: {sql}"); - - Ok(Some(result)) - } - - fn check_predicates( - &self, - view_info: &ViewInfo, - new_predicates: &mut Vec, - new_selection_set: &mut HashSet, - ) -> bool { - // 3.1.2 Do all required rows exist in the view? - // 1. Compute equivalence classes for the query and the view. - // 2. Check that every view equivalence class is a subset of a - // query equivalence class. If not, reject the view - // 3. Compute range intervals for the query and the view. - // 4. Check that every view range contains the corresponding - // query range. If not, reject the view. - // 5. 
Check that every conjunct in the residual predicate of the - // view matches a conjunct in the residual predicate of the query. - // If not, reject the view. - if !self - .query_info - .equi_classes - .check(&view_info.query_info.equi_classes) - { - return false; - } - let (range_res, extra_ranges) = self - .query_info - .range_classes - .check(&view_info.query_info.range_classes); - if !range_res { - return false; - } - - let (residual_res, extra_residual_preds) = self - .query_info - .residual_classes - .check(&view_info.query_info.residual_classes); - if !residual_res { - return false; - } - - // 3.1.3 Can the required rows be selected? - // 1. Construct compensating column equality predicates - // while comparing view equivalence classes against query equivalence classes as described in the previous section. - // Try to map every column reference to an output column (using the view equivalence classes). - // If this is not possible, reject the view. - // 2. Construct compensating range predicates by comparing column ranges as described in the previous section. - // Try to map every column reference to an output column (using the query equivalence classes). - // If this is not possible, reject the view. - // 3. Find the residual predicates of the query that are missing in the view. - // Try to map every column reference to an output column (using the query equivalence classes). - // If this is not possible, reject the view. 
- - if let Some(extra_ranges) = extra_ranges { - for (col, (lower_bound, upper_bound)) in extra_ranges.iter() { - // materialized view output must contains the column - if let Some((new_scalar, _)) = view_info.index_output_cols.get(col) { - let lower = match lower_bound { - BoundValue::Closed(val) => Some((val.clone(), "gte")), - BoundValue::Open(val) => Some((val.clone(), "gt")), - BoundValue::NegativeInfinite => None, - _ => unreachable!(), - }; - let upper = match upper_bound { - BoundValue::Closed(val) => Some((val.clone(), "lte")), - BoundValue::Open(val) => Some((val.clone(), "lt")), - BoundValue::PositiveInfinite => None, - _ => unreachable!(), - }; - - if let (Some((lower_val, "gte")), Some((upper_val, "lte"))) = (&lower, &upper) { - // if lower and upper value equal, convert to equal function - if lower_val.eq(upper_val) { - let lower_val_scalar = ScalarExpr::ConstantExpr(ConstantExpr { - span: None, - value: lower_val.clone(), - }); - let pred = ScalarExpr::FunctionCall(FunctionCall { - span: None, - func_name: "eq".to_string(), - params: vec![], - arguments: vec![new_scalar.clone(), lower_val_scalar], - }); - new_predicates.push(pred); - continue; - } - } - - if let Some((lower_val, func_name)) = lower { - let lower_val_scalar = ScalarExpr::ConstantExpr(ConstantExpr { - span: None, - value: lower_val, - }); - let pred = ScalarExpr::FunctionCall(FunctionCall { - span: None, - func_name: func_name.to_string(), - params: vec![], - arguments: vec![new_scalar.clone(), lower_val_scalar], - }); - new_predicates.push(pred); - } - if let Some((upper_val, func_name)) = upper { - let upper_val_scalar = ScalarExpr::ConstantExpr(ConstantExpr { - span: None, - value: upper_val, - }); - let pred = ScalarExpr::FunctionCall(FunctionCall { - span: None, - func_name: func_name.to_string(), - params: vec![], - arguments: vec![new_scalar.clone(), upper_val_scalar], - }); - new_predicates.push(pred); - } - } else { - return false; - } - } - } - - if let 
Some(extra_residual_preds) = extra_residual_preds { - for extra_residual_pred in extra_residual_preds { - match self.query_info.check_output_cols( - &extra_residual_pred, - &view_info.index_output_cols, - new_selection_set, - ) { - Ok(Some(new_residual_pred)) => { - new_predicates.push(new_residual_pred); - } - Ok(None) => {} - Err(_) => { - return false; - } - } - } - } - - true - } - - fn check_output_expressions( - &self, - view_info: &ViewInfo, - new_selection_set: &mut HashSet, - ) -> bool { - // 3.1.4 Can output expressions be computed? - // Checking whether all output expressions of the query can be computed from the view - // is similar to checking whether the additional predicates can be computed correctly. - for output_item in self.query_info.output_cols.iter() { - if self - .query_info - .check_output_cols( - &output_item.scalar, - &view_info.index_output_cols, - new_selection_set, - ) - .is_err() - { - return false; - } - } - true - } - - fn check_aggregation( - &self, - view_info: &ViewInfo, - new_selection_set: &mut HashSet, - ) -> bool { - // 3.3 Aggregation queries and views - // 1. The SPJ part of the view produces all rows needed by - // the SPJ part of the query and with the right duplication factor. - // 2. All columns required by compensating predicates (if any) are available in the view output. - // 3. The view contains no aggregation or is less aggregated than the query, - // i.e, the groups formed by the query can be computed by further aggregation of groups output by the view. - // 4. All columns required to perform further grouping (if necessary) are available in the view output. - // 5. All columns required to compute output expressions are available in the view output. 
- - let query_group_items = self - .query_info - .aggregate - .clone() - .map(|agg| agg.group_items) - .unwrap_or_default(); - - let view_group_items = view_info - .query_info - .aggregate - .clone() - .clone() - .map(|agg| agg.group_items) - .unwrap_or_default(); - - match (query_group_items.is_empty(), view_group_items.is_empty()) { - // both query and view have group, check for same group items. - (false, false) => { - // TODO: query can support continue group - if query_group_items.len() != view_group_items.len() { - return false; - } - let mut query_group_names = HashSet::with_capacity(query_group_items.len()); - for item in &query_group_items { - let query_group_name = format_scalar(&item.scalar, &self.query_info.column_map); - query_group_names.insert(query_group_name); - } - let mut view_group_names = HashSet::with_capacity(view_group_items.len()); - for item in &view_group_items { - let view_group_name = - format_scalar(&item.scalar, &view_info.query_info.column_map); - view_group_names.insert(view_group_name); - } - for query_group_name in query_group_names { - if !view_group_names.contains(&query_group_name) { - return false; - } - } - - for item in query_group_items { - if self - .query_info - .check_output_cols( - &item.scalar, - &view_info.index_output_cols, - new_selection_set, - ) - .is_err() - { - return false; - } - } - } - // query have group, but view don't have group, - // check group items in output rows. - (false, true) => { - for item in query_group_items { - if self - .query_info - .check_output_cols( - &item.scalar, - &view_info.index_output_cols, - new_selection_set, - ) - .is_err() - { - return false; - } - } - } - // both query and view don't have group, don't need check. - (true, true) => {} - // query don't have group, but view have group, impossible to match. 
- (true, false) => { - return false; - } - } - - true - } - - fn check_sort_items( - &self, - view_info: &ViewInfo, - new_selection_set: &mut HashSet, - ) -> bool { - if let Some(sort_items) = &self.query_info.sort_items { - for item in sort_items { - if let Some(scalar) = self.query_info.column_map.get(&item.index) { - if self - .query_info - .check_output_cols(scalar, &view_info.index_output_cols, new_selection_set) - .is_err() - { - return false; - } - } else { - return false; - } - } - } - - true - } -} - -fn to_index_scalar(index: FieldIndex, data_type: &DataType) -> ScalarExpr { - let col = BoundColumnRef { - span: None, - column: ColumnBindingBuilder::new( - format!("index_col_{index}"), - Symbol::from_field_index(index), - Box::new(data_type.clone()), - Visibility::Visible, - ) - .build(), - }; - ScalarExpr::BoundColumnRef(col) -} - -#[inline(always)] -fn format_col(column: &ColumnBinding) -> String { - match &column.table_name { - Some(table_name) => { - format!("{}.{}", table_name, column.column_name) - } - None => column.column_name.clone(), - } -} - -#[inline(always)] -fn reverse_op(op: &str) -> String { - match op { - "gt" => "lt".to_string(), - "gte" => "lte".to_string(), - "lt" => "gt".to_string(), - "lte" => "gte".to_string(), - "eq" => "eq".to_string(), - _ => unreachable!(), - } -} - -// replace derived column with actual ScalarExpr. 
-fn actual_column_ref<'a>( - col: &'a ScalarExpr, - column_map: &'a HashMap, -) -> &'a ScalarExpr { - if let ScalarExpr::BoundColumnRef(col) = col { - if let Some(arg) = column_map.get(&col.column.index) { - return arg; - } - } - col -} - -fn format_scalar(scalar: &ScalarExpr, column_map: &HashMap) -> String { - match scalar { - ScalarExpr::BoundColumnRef(_) => match actual_column_ref(scalar, column_map) { - ScalarExpr::BoundColumnRef(col) => format_col(&col.column), - s => format_scalar(s, column_map), - }, - ScalarExpr::ConstantExpr(val) => format!("{}", val.value), - ScalarExpr::FunctionCall(func) => format!( - "{}({})", - &func.func_name, - func.arguments - .iter() - .map(|arg| { format_scalar(arg, column_map) }) - .collect::>() - .join(", ") - ), - ScalarExpr::CastExpr(cast) => { - let func_name = if cast.is_try { "try_cast" } else { "cast" }; - format!( - "{}({} as {})", - func_name, - format_scalar(&cast.argument, column_map), - cast.target_type - ) - } - ScalarExpr::AggregateFunction(agg) => { - let params = agg - .params - .iter() - .map(|i| i.to_string()) - .collect::>() - .join(", "); - let args = agg - .args - .iter() - .map(|arg| format_scalar(arg, column_map)) - .collect::>() - .join(", "); - let mut scalar = if !params.is_empty() { - format!("{}<{}>({})", &agg.func_name, params, args) - } else { - format!("{}({})", &agg.func_name, args) - }; - if !agg.sort_descs.is_empty() { - let sort_descs = agg - .sort_descs - .iter() - .map(|desc| format_sort_desc(desc, column_map)) - .join(", "); - scalar = format!("{} within group (order by {})", scalar, sort_descs); - } - scalar - } - ScalarExpr::UDAFCall(udaf) => { - let args = udaf - .arguments - .iter() - .map(|arg| format_scalar(arg, column_map)) - .collect::>() - .join(", "); - format!("{}({})", &udaf.name, args) - } - ScalarExpr::UDFCall(udf) => format!( - "{}({})", - &udf.handler, - udf.arguments - .iter() - .map(|arg| { format_scalar(arg, column_map) }) - .collect::>() - .join(", ") - ), - - _ => 
unreachable!(), // Window function and subquery will not appear in index. - } -} - -fn format_sort_desc( - AggregateFunctionScalarSortDesc { - expr, - nulls_first, - asc, - .. - }: &AggregateFunctionScalarSortDesc, - column_map: &HashMap, -) -> String { - let mut expr = format_scalar(expr, column_map); - - if *asc { - expr.push_str(" asc"); - } else { - expr.push_str(" desc"); - } - if *nulls_first { - expr.push_str(" nulls first"); - } else { - expr.push_str(" nulls last"); - } - expr +pub fn build_agg_index_plan_for_table( + table_ctx: Arc, + query_executor: Option>, + metadata: MetadataRef, + table_index: IndexType, + index_id: u64, + sql: String, + s_expr: SExpr, +) -> Result { + let s_expr = optimize_agg_index_s_expr(table_ctx, query_executor, metadata.clone(), &s_expr)?; + let metadata_ref = metadata.read(); + let table_name = metadata_ref.table(table_index).name(); + let base_columns = metadata_ref.columns_by_table_index(table_index); + + Ok(AggIndexPlan { + index_id, + sql, + metadata: metadata.clone(), + prepared: Arc::new(AggIndexViewInfo::new( + table_index, + table_name, + &base_columns, + &s_expr, + )?), + s_expr, + }) } -fn push_down_index_scan(s_expr: &SExpr, agg_info: AggIndexInfo) -> Result { - Ok(match s_expr.plan() { - RelOperator::Scan(scan) => { - let mut new_scan = scan.clone(); - new_scan.agg_index = Some(agg_info); - s_expr.replace_plan(Arc::new(new_scan.into())) - } - _ => { - let child = push_down_index_scan(s_expr.child(0)?, agg_info)?; - s_expr.replace_children(vec![Arc::new(child)]) - } - }) +fn optimize_agg_index_s_expr( + table_ctx: Arc, + query_executor: Option>, + metadata: MetadataRef, + s_expr: &SExpr, +) -> Result { + let settings = table_ctx.get_settings(); + let opt_ctx = OptimizerContext::new(table_ctx, metadata) + .with_settings(&settings)? 
+ .set_sample_executor(query_executor) + .clone(); + let optimizer = RecursiveRuleOptimizer::new(opt_ctx, &[ + RuleID::NormalizeScalarFilter, + RuleID::FilterNulls, + RuleID::EliminateFilter, + RuleID::MergeFilter, + ]); + optimizer.optimize_sync(s_expr) +} diff --git a/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/agg_index/rewrite.rs b/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/agg_index/rewrite.rs new file mode 100644 index 0000000000000..509bc6f190f6e --- /dev/null +++ b/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/agg_index/rewrite.rs @@ -0,0 +1,395 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashSet; +use std::sync::Arc; + +use databend_common_exception::Result; +use databend_common_expression::Scalar; +use databend_common_expression::TableSchemaRefExt; +use log::info; + +use super::prepare::AggIndexViewInfo; +use super::prepare::CompensatingRange; +use super::prepare::QueryInfo; +use super::prepare::ScalarExprMatcher; +use crate::ScalarExpr; +use crate::optimizer::ir::SExpr; +use crate::plans::AggIndexInfo; +use crate::plans::ConstantExpr; +use crate::plans::FunctionCall; +use crate::plans::RelOperator; +use crate::plans::ScalarItem; + +// Aggregating index rewriting logic is based on "Optimizing Queries Using Materialized Views: +// A Practical, Scalable Solution" by Goldstein and Larson.
+pub(super) struct AggIndexMatcher<'a> { + pub(super) query_info: &'a QueryInfo, +} + +impl AggIndexMatcher<'_> { + pub(super) fn try_rewrite_index( + &self, + s_expr: &SExpr, + index_id: u64, + sql: &str, + view_info: &AggIndexViewInfo, + ) -> Result> { + let mut new_predicates = Vec::new(); + let mut new_selection_set = HashSet::new(); + + if !self.check_predicates(view_info, &mut new_predicates, &mut new_selection_set) { + return Ok(None); + } + + if !self.check_output_expressions(view_info, &mut new_selection_set) { + return Ok(None); + } + + if !self.check_aggregation(view_info, &mut new_selection_set) { + return Ok(None); + } + + if !self.check_sort_items(view_info, &mut new_selection_set) { + return Ok(None); + } + + let mut new_selection: Vec<_> = new_selection_set.into_iter().collect(); + new_selection.sort_by_key(|i| i.index); + + let is_agg = self.query_info.aggregate.is_some(); + let num_agg_funcs = self + .query_info + .aggregate + .as_ref() + .map(|agg| agg.aggregate_functions.len()) + .unwrap_or_default(); + + let result = push_down_index_scan(s_expr, AggIndexInfo { + index_id, + selection: new_selection, + predicates: new_predicates, + schema: TableSchemaRefExt::create(view_info.index_fields.clone()), + is_agg, + num_agg_funcs, + })?; + + info!("Use aggregating index: {sql}"); + + Ok(Some(result)) + } + + fn add_compensating_range_predicates( + &self, + view_info: &AggIndexViewInfo, + extra_ranges: &[CompensatingRange], + new_predicates: &mut Vec, + ) -> bool { + let index_output_matcher = ScalarExprMatcher::new( + &view_info.query_info.column_map, + &self.query_info.column_map, + ); + + for extra_range in extra_ranges { + let Some((new_scalar, _)) = index_output_matcher + .find_index_output_col(&view_info.index_output_cols, &extra_range.column) + else { + return false; + }; + + let lower = extra_range.lower_bound.comparison_bound("gte", "gt"); + let upper = extra_range.upper_bound.comparison_bound("lte", "lt"); + + if let (Some((lower_val, "gte")), 
Some((upper_val, "lte"))) = (&lower, &upper) { + if lower_val == upper_val { + new_predicates.push(comparison_predicate( + "eq", + new_scalar.clone(), + lower_val.clone(), + )); + continue; + } + } + + push_optional_comparison_predicate(new_predicates, &new_scalar, lower); + push_optional_comparison_predicate(new_predicates, &new_scalar, upper); + } + + true + } + + fn check_output_scalars<'a>( + &self, + scalars: impl IntoIterator, + view_info: &AggIndexViewInfo, + new_selection_set: &mut HashSet, + ) -> bool { + for scalar in scalars { + if self + .query_info + .check_output_cols( + scalar, + &view_info.index_output_cols, + &view_info.query_info.column_map, + new_selection_set, + ) + .is_err() + { + return false; + } + } + + true + } + + fn add_missing_residual_predicates( + &self, + view_info: &AggIndexViewInfo, + extra_residual_preds: &[ScalarExpr], + new_predicates: &mut Vec, + new_selection_set: &mut HashSet, + ) -> bool { + for extra_residual_pred in extra_residual_preds { + match self.query_info.check_output_cols( + extra_residual_pred, + &view_info.index_output_cols, + &view_info.query_info.column_map, + new_selection_set, + ) { + Ok(Some(new_residual_pred)) => new_predicates.push(new_residual_pred), + Ok(None) => {} + Err(_) => return false, + } + } + + true + } + + fn query_group_items(&self) -> &[ScalarItem] { + self.query_info + .aggregate + .as_ref() + .map(|agg| agg.group_items.as_slice()) + .unwrap_or_default() + } + + fn view_group_items<'a>(&self, view_info: &'a AggIndexViewInfo) -> &'a [ScalarItem] { + view_info + .query_info + .aggregate + .as_ref() + .map(|agg| agg.group_items.as_slice()) + .unwrap_or_default() + } + + fn group_items_match(&self, view_info: &AggIndexViewInfo) -> bool { + let query_group_items = self.query_group_items(); + let view_group_items = self.view_group_items(view_info); + if query_group_items.len() != view_group_items.len() { + return false; + } + + let query_group_matcher = 
ScalarExprMatcher::same(&self.query_info.column_map); + let mut query_group_names = Vec::with_capacity(query_group_items.len()); + for item in query_group_items { + query_group_matcher.push_unique_scalar(&mut query_group_names, item.scalar.clone()); + } + + let view_group_matcher = ScalarExprMatcher::same(&view_info.query_info.column_map); + let mut view_group_names = Vec::with_capacity(view_group_items.len()); + for item in view_group_items { + view_group_matcher.push_unique_scalar(&mut view_group_names, item.scalar.clone()); + } + + let group_matcher = ScalarExprMatcher::new( + &view_info.query_info.column_map, + &self.query_info.column_map, + ); + query_group_names.into_iter().all(|query_group_name| { + group_matcher.list_contains(&view_group_names, &query_group_name) + }) + } + + fn check_group_items_output_cols( + &self, + group_items: &[ScalarItem], + view_info: &AggIndexViewInfo, + new_selection_set: &mut HashSet, + ) -> bool { + self.check_output_scalars( + group_items.iter().map(|item| &item.scalar), + view_info, + new_selection_set, + ) + } + + fn check_predicates( + &self, + view_info: &AggIndexViewInfo, + new_predicates: &mut Vec, + new_selection_set: &mut HashSet, + ) -> bool { + if !self.query_info.equi_classes.check( + &view_info.query_info.equi_classes, + &self.query_info.column_map, + &view_info.query_info.column_map, + ) { + return false; + } + let (range_res, extra_ranges) = self.query_info.range_classes.check( + &view_info.query_info.range_classes, + &self.query_info.column_map, + &view_info.query_info.column_map, + ); + if !range_res { + return false; + } + + let (residual_res, extra_residual_preds) = self.query_info.residual_classes.check( + &view_info.query_info.residual_classes, + &self.query_info.column_map, + &view_info.query_info.column_map, + ); + if !residual_res { + return false; + } + + if let Some(extra_ranges) = extra_ranges + && !self.add_compensating_range_predicates(view_info, &extra_ranges, new_predicates) + { + return false; + 
} + + if let Some(extra_residual_preds) = extra_residual_preds + && !self.add_missing_residual_predicates( + view_info, + &extra_residual_preds, + new_predicates, + new_selection_set, + ) + { + return false; + } + + true + } + + fn check_output_expressions( + &self, + view_info: &AggIndexViewInfo, + new_selection_set: &mut HashSet, + ) -> bool { + self.check_output_scalars( + self.query_info.output_cols.iter().map(|item| &item.scalar), + view_info, + new_selection_set, + ) + } + + fn check_aggregation( + &self, + view_info: &AggIndexViewInfo, + new_selection_set: &mut HashSet, + ) -> bool { + let query_group_items = self.query_group_items(); + let view_group_items = self.view_group_items(view_info); + + match (query_group_items.is_empty(), view_group_items.is_empty()) { + (false, false) => { + if !self.group_items_match(view_info) { + return false; + } + if !self.check_group_items_output_cols( + query_group_items, + view_info, + new_selection_set, + ) { + return false; + } + } + (false, true) => { + if !self.check_group_items_output_cols( + query_group_items, + view_info, + new_selection_set, + ) { + return false; + } + } + (true, true) => {} + (true, false) => return false, + } + + true + } + + fn check_sort_items( + &self, + view_info: &AggIndexViewInfo, + new_selection_set: &mut HashSet, + ) -> bool { + let Some(sort_items) = &self.query_info.sort_items else { + return true; + }; + + for item in sort_items { + let Some(scalar) = self.query_info.column_map.get(&item.index) else { + return false; + }; + + if !self.check_output_scalars([scalar], view_info, new_selection_set) { + return false; + } + } + + true + } +} + +fn comparison_predicate(func_name: &str, column: ScalarExpr, value: Scalar) -> ScalarExpr { + ScalarExpr::FunctionCall(FunctionCall { + span: None, + func_name: func_name.to_string(), + params: vec![], + arguments: vec![ + column, + ScalarExpr::ConstantExpr(ConstantExpr { span: None, value }), + ], + }) +} + +fn push_optional_comparison_predicate( + 
predicates: &mut Vec, + column: &ScalarExpr, + bound: Option<(Scalar, &'static str)>, +) { + let Some((value, func_name)) = bound else { + return; + }; + predicates.push(comparison_predicate(func_name, column.clone(), value)); +} + +fn push_down_index_scan(s_expr: &SExpr, agg_info: AggIndexInfo) -> Result { + Ok(match s_expr.plan() { + RelOperator::Scan(scan) => { + let mut new_scan = scan.clone(); + new_scan.agg_index = Some(agg_info); + s_expr.replace_plan(Arc::new(new_scan.into())) + } + _ => { + let child = push_down_index_scan(s_expr.child(0)?, agg_info)?; + s_expr.replace_children(vec![Arc::new(child)]) + } + }) +} diff --git a/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/mod.rs b/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/mod.rs index 532277a6069e2..0f65235a4e3df 100644 --- a/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/mod.rs +++ b/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/mod.rs @@ -22,6 +22,7 @@ mod rule_push_down_limit_aggregate; mod rule_split_aggregate; mod rule_try_apply_agg_index; +pub use agg_index::*; pub use rule_eager_aggregation::RuleEagerAggregation; pub use rule_fold_count_aggregate::RuleFoldCountAggregate; pub use rule_grouping_sets_to_union::RuleGroupingSetsToUnion; diff --git a/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/rule_try_apply_agg_index.rs b/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/rule_try_apply_agg_index.rs index f89ef2e8c9f38..f89d1ed7c0d40 100644 --- a/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/rule_try_apply_agg_index.rs +++ b/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/rule_try_apply_agg_index.rs @@ -227,9 +227,9 @@ impl Rule for RuleTryApplyAggIndex { s_expr: &SExpr, state: &mut crate::optimizer::optimizers::rule::TransformResult, ) -> Result<()> { - let (table_index, table_name) = self.get_table(s_expr); + let table_index = self.get_table(s_expr); let metadata = 
self.metadata.read(); - let index_plans = metadata.get_agg_indices(&table_name); + let index_plans = metadata.get_agg_indices(table_index); let Some(index_plans) = index_plans else { // No enterprise license or no index. return Ok(()); @@ -258,16 +258,9 @@ impl Rule for RuleTryApplyAggIndex { } impl RuleTryApplyAggIndex { - fn get_table(&self, s_expr: &SExpr) -> (IndexType, String) { + fn get_table(&self, s_expr: &SExpr) -> IndexType { match s_expr.plan() { - RelOperator::Scan(scan) => { - let metadata = self.metadata.read(); - let table = metadata.table(scan.table_index); - ( - scan.table_index, - format!("{}.{}.{}", table.catalog(), table.database(), table.name()), - ) - } + RelOperator::Scan(scan) => scan.table_index, _ => self.get_table(s_expr.child(0).unwrap()), } } diff --git a/src/query/sql/src/planner/planner.rs b/src/query/sql/src/planner/planner.rs index c7a067901522c..a00d200b06811 100644 --- a/src/query/sql/src/planner/planner.rs +++ b/src/query/sql/src/planner/planner.rs @@ -51,8 +51,6 @@ use crate::NameResolutionContext; use crate::VariableNormalizer; use crate::optimizer::OptimizerContext; use crate::optimizer::optimize; -use crate::optimizer::optimizers::recursive::RecursiveRuleOptimizer; -use crate::optimizer::optimizers::rule::RuleID; use crate::planner::QueryExecutor; use crate::plans::Plan; @@ -300,22 +298,6 @@ impl Planner { .set_sample_executor(self.query_executor.clone()) .clone(); - { - let mut agg_indices = metadata.read().agg_indices().clone(); - let optimizer = RecursiveRuleOptimizer::new(opt_ctx.clone(), &[ - RuleID::NormalizeScalarFilter, - RuleID::FilterNulls, - RuleID::EliminateFilter, - RuleID::MergeFilter, - ]); - for indices in &mut agg_indices.values_mut() { - for (_, _, s_expr) in indices { - *s_expr = optimizer.optimize_sync(s_expr)?; - } - } - metadata.write().replace_agg_indices(agg_indices); - } - let optimized_plan = optimize(opt_ctx, plan).await?; if enable_planner_cache { diff --git 
a/src/query/sql/tests/it/framework/lite_context.rs b/src/query/sql/tests/it/framework/lite_context.rs index 4e380869c95c3..c6454e9db0477 100644 --- a/src/query/sql/tests/it/framework/lite_context.rs +++ b/src/query/sql/tests/it/framework/lite_context.rs @@ -125,6 +125,9 @@ static TEST_BUILD_INFO: BuildInfo = BuildInfo { commit_detail: String::new(), embedded_license: String::new(), }; +static NEXT_LITE_CONTEXT_ID: AtomicU64 = AtomicU64::new(1); +static NEXT_FAKE_TABLE_ID: AtomicU64 = AtomicU64::new(1); +static NEXT_FAKE_INDEX_ID: AtomicU64 = AtomicU64::new(1); thread_local! { static INIT_TESTING_GLOBALS: std::sync::Once = const { std::sync::Once::new() }; @@ -166,14 +169,16 @@ fn unsupported(name: &str) -> Result { ))) } -type TableKey = (String, String); +type TableKey = (String, String, String); type TableMap = HashMap>; type ColumnStatsMap = HashMap; +type IndexMap = HashMap>; #[derive(Clone)] struct DummyCatalog { info: Arc, tables: Arc>, + indexes: Arc>, } impl std::fmt::Debug for DummyCatalog { @@ -195,19 +200,57 @@ impl Default for DummyCatalog { ..Default::default() }), tables: Arc::new(RwLock::new(HashMap::new())), + indexes: Arc::new(RwLock::new(HashMap::new())), } } } impl DummyCatalog { - fn insert_table(&self, database: &str, table: Arc) { + fn insert_table(&self, tenant: &Tenant, database: &str, table: Arc) { + self.tables.write().insert( + ( + tenant.tenant_name().to_string(), + database.to_string(), + table.name().to_string(), + ), + table, + ); + } + + fn clear_tenant(&self, tenant: &Tenant) { self.tables .write() - .insert((database.to_string(), table.name().to_string()), table); + .retain(|(table_tenant, _, _), _| table_tenant != tenant.tenant_name()); + } + + fn get_registered_table( + &self, + tenant: &Tenant, + database: &str, + table_name: &str, + ) -> Result> { + self.tables + .read() + .get(&( + tenant.tenant_name().to_string(), + database.to_string(), + table_name.to_string(), + )) + .cloned() + .ok_or_else(|| 
ErrorCode::UnknownTable(format!("{}.{}", database, table_name))) } - fn clear_tables(&self) { - self.tables.write().clear(); + fn insert_index(&self, table_id: MetaId, index_id: u64, name: &str, query: &str) { + self.indexes.write().entry(table_id).or_default().push(( + index_id, + name.to_string(), + IndexMeta { + table_id, + original_query: query.to_string(), + query: query.to_string(), + ..Default::default() + }, + )); } } @@ -297,6 +340,10 @@ impl Table for FakeTable { fn support_prewhere(&self) -> bool { true } + + fn support_index(&self) -> bool { + true + } } #[async_trait::async_trait] @@ -357,6 +404,36 @@ impl Catalog for DummyCatalog { unsupported("catalog::update_index") } + async fn list_indexes(&self, req: ListIndexesReq) -> Result> { + if let Some(table_id) = req.table_id { + Ok(self + .indexes + .read() + .get(&table_id) + .cloned() + .unwrap_or_default()) + } else { + Ok(self + .indexes + .read() + .values() + .flat_map(|indexes| indexes.iter().cloned()) + .collect()) + } + } + + async fn list_indexes_by_table_id( + &self, + req: ListIndexesByIdReq, + ) -> Result> { + Ok(self + .indexes + .read() + .get(&req.table_id) + .cloned() + .unwrap_or_default()) + } + async fn rename_database(&self, _req: RenameDatabaseReq) -> Result { unsupported("catalog::rename_database") } @@ -400,15 +477,11 @@ impl Catalog for DummyCatalog { async fn get_table( &self, - _tenant: &Tenant, + tenant: &Tenant, db_name: &str, table_name: &str, ) -> Result> { - self.tables - .read() - .get(&(db_name.to_string(), table_name.to_string())) - .cloned() - .ok_or_else(|| ErrorCode::UnknownTable(format!("{}.{}", db_name, table_name))) + self.get_registered_table(tenant, db_name, table_name) } async fn mget_tables( @@ -652,8 +725,8 @@ pub struct LiteTableContext { merge_into_join: RwLock, variables: RwLock>, runtime_filter_ready: RwLock>>>, + can_scan_from_agg_index: RwLock, queued_duration: RwLock, - next_table_id: AtomicU64, } impl LiteTableContext { @@ -711,7 +784,7 @@ impl 
LiteTableContext { .collect::>>()?; let warehouse_distribution = *self.warehouse_distribution.read(); - let table_id = self.next_table_id.fetch_add(1, Ordering::Relaxed); + let table_id = NEXT_FAKE_TABLE_ID.fetch_add(1, Ordering::Relaxed); Ok(Arc::new(FakeTable { table_info: TableInfo { @@ -734,7 +807,10 @@ impl LiteTableContext { pub async fn create() -> Result> { init_testing_globals(); - let tenant = Tenant::new_literal("default"); + let tenant = Tenant::new_literal(&format!( + "default_{}", + NEXT_LITE_CONTEXT_ID.fetch_add(1, Ordering::Relaxed) + )); let settings = Settings::create(tenant.clone()); let shared_settings = Settings::create(tenant.clone()); Self::init_user_api_provider(&tenant).await?; @@ -774,7 +850,7 @@ impl LiteTableContext { .ok_or_else(|| ErrorCode::Internal("unexpected default catalog type"))? .clone() .into(); - default_catalog.clear_tables(); + default_catalog.clear_tenant(&tenant); let ctx = Arc::new(Self { catalog_manager, @@ -808,13 +884,17 @@ impl LiteTableContext { merge_into_join: RwLock::new(MergeIntoJoin::default()), variables: RwLock::new(HashMap::new()), runtime_filter_ready: RwLock::new(HashMap::new()), + can_scan_from_agg_index: RwLock::new(false), queued_duration: RwLock::new(Duration::default()), - next_table_id: AtomicU64::new(1), }); ctx.reset_user_api_state().await?; Ok(ctx) } + pub async fn create_isolated() -> Result> { + Self::create().await + } + pub fn set_table_warehouse_distribution(&self, enabled: bool) { *self.warehouse_distribution.write() = enabled; } @@ -864,7 +944,8 @@ impl LiteTableContext { ) -> Result<()> { let table = self.build_fake_table(database, table_name, fields, table_stats, column_stats)?; - self.default_catalog.insert_table(database, table); + self.default_catalog + .insert_table(&self.tenant, database, table); Ok(()) } @@ -1020,6 +1101,22 @@ impl LiteTableContext { } } + pub fn register_agg_index( + &self, + database: &str, + table_name: &str, + name: &str, + sql: &str, + ) -> Result { + let table 
= + self.default_catalog + .get_registered_table(&self.tenant, database, table_name)?; + let index_id = NEXT_FAKE_INDEX_ID.fetch_add(1, Ordering::Relaxed); + self.default_catalog + .insert_index(table.get_id(), index_id, name, sql); + Ok(index_id) + } + pub async fn bind_sql(self: &Arc, sql: &str) -> Result { let planner = Planner::new(self.clone()); let extras = planner.parse_sql(sql)?; @@ -1136,9 +1233,11 @@ impl TableContext for LiteTableContext { } fn set_cacheable(&self, _cacheable: bool) {} fn get_can_scan_from_agg_index(&self) -> bool { - false + *self.can_scan_from_agg_index.read() + } + fn set_can_scan_from_agg_index(&self, enable: bool) { + *self.can_scan_from_agg_index.write() = enable; } - fn set_can_scan_from_agg_index(&self, _enable: bool) {} fn get_enable_sort_spill(&self) -> bool { false } @@ -1603,7 +1702,7 @@ $$ } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] - async fn test_create_clears_catalog_tables() -> Result<()> { + async fn test_create_keeps_catalog_tables_isolated() -> Result<()> { let ctx1 = LiteTableContext::create().await?; ctx1.register_table_with_stats("default", "t1", test_fields(), None, HashMap::new())?; ctx1.default_catalog @@ -1617,6 +1716,9 @@ $$ .await .is_err() ); + ctx1.default_catalog + .get_table(&ctx1.tenant, "default", "t1") + .await?; Ok(()) } diff --git a/src/query/sql/tests/it/optimizer/agg_index.rs b/src/query/sql/tests/it/optimizer/agg_index.rs new file mode 100644 index 0000000000000..106f6c52da56e --- /dev/null +++ b/src/query/sql/tests/it/optimizer/agg_index.rs @@ -0,0 +1,649 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashMap; +use std::sync::Arc; + +use databend_common_catalog::table_context::TableContext; +use databend_common_exception::Result; +use databend_common_sql::optimizer::build_agg_index_plan_for_table; +use databend_common_sql::optimizer::ir::SExpr; +use databend_common_sql::optimizer::ir::SExprVisitor; +use databend_common_sql::optimizer::ir::VisitAction; +use databend_common_sql::plans::AggIndexInfo; +use databend_common_sql::plans::BoundColumnRef; +use databend_common_sql::plans::Operator; +use databend_common_sql::plans::Plan; +use databend_common_sql::plans::RelOperator; +use databend_common_sql::plans::Visitor; +use databend_common_sql_test_support::TestCase; +use databend_common_sql_test_support::TestCaseRunner; +use parking_lot::RwLock; + +use crate::framework::LiteTableContext; + +struct AggIndexLiteRunner { + ctx: Arc, + agg_index_table: &'static str, + agg_index_sqls: &'static [&'static str], +} + +struct AggIndexLiteCase { + case: TestCase, + agg_index_table: &'static str, + agg_index_sqls: &'static [&'static str], + is_matched: bool, + index_selection: &'static [&'static str], + rewritten_predicates: &'static [&'static str], +} + +impl TestCaseRunner for AggIndexLiteRunner { + async fn bind_sql(&self, sql: &str) -> Result { + self.ctx.bind_sql(sql).await + } + + async fn optimize_plan(&self, plan: Plan) -> Result { + if let Plan::Query { metadata, .. 
} = &plan { + let mut agg_index_plans = Vec::with_capacity(self.agg_index_sqls.len()); + let child_metadata = Arc::new(RwLock::new(metadata.read().clone())); + let table_index = { + let metadata_ref = metadata.read(); + metadata_ref + .tables() + .iter() + .find(|table| table.name() == self.agg_index_table) + .map(|table| table.index()) + .expect("agg index table should exist in query metadata") + }; + for (index_id, sql) in self.agg_index_sqls.iter().enumerate() { + let index_plan = self.ctx.bind_sql(sql).await?; + let Plan::Query { s_expr, .. } = index_plan else { + unreachable!("agg index sql must bind to a query plan"); + }; + agg_index_plans.push(build_agg_index_plan_for_table( + self.ctx.clone(), + None, + child_metadata.clone(), + table_index, + index_id as u64, + sql.to_string(), + *s_expr, + )?); + } + metadata + .write() + .add_agg_indices(table_index, agg_index_plans); + } + + self.ctx.optimize_plan(plan).await + } +} + +fn agg_index_test_case(name: &'static str, sql: &'static str) -> TestCase { + TestCase { + name: name.to_string(), + sql: sql.to_string(), + table_stats: HashMap::new(), + column_stats: HashMap::new(), + auto_stats: false, + stem: name.to_string(), + subdir: None, + node_num: None, + tables: HashMap::from([( + "t".to_string(), + "create table t(a int, b int, c int)".to_string(), + )]), + } +} + +fn agg_index_test_cases() -> Vec { + vec![ + AggIndexLiteCase { + case: agg_index_test_case( + "agg_index_expression_exact_match", + "select a + 1 from t where b > 1", + ), + agg_index_table: "t", + agg_index_sqls: &["select a + 1, b from t where b > 0"], + is_matched: true, + index_selection: &["index_col_0 (#0)", "index_col_1 (#1)"], + rewritten_predicates: &["gt(index_col_0 (#0), 1)"], + }, + AggIndexLiteCase { + case: agg_index_test_case( + "agg_index_expression_structure_mismatch", + "select a + 1 from t where b > 1", + ), + agg_index_table: "t", + agg_index_sqls: &["select a + 2, b from t where b > 0"], + is_matched: false, + 
index_selection: &[], + rewritten_predicates: &[], + }, + AggIndexLiteCase { + case: agg_index_test_case( + "agg_index_expression_alias_output_match", + "select x from (select a + 1 as x from t) s", + ), + agg_index_table: "t", + agg_index_sqls: &["select a + 1 from t"], + is_matched: true, + index_selection: &["index_col_0 (#0)"], + rewritten_predicates: &[], + }, + AggIndexLiteCase { + case: agg_index_test_case( + "agg_index_aggregate_group_mismatch", + "select sum(a) + 1 from t group by b", + ), + agg_index_table: "t", + agg_index_sqls: &["select sum(a), c from t group by c"], + is_matched: false, + index_selection: &[], + rewritten_predicates: &[], + }, + AggIndexLiteCase { + case: agg_index_test_case( + "agg_index_aggregate_expression_output_match", + "select sum(a) + 1 from t group by b", + ), + agg_index_table: "t", + agg_index_sqls: &["select sum(a), b from t group by b"], + is_matched: true, + index_selection: &["index_col_0 (#0)", "index_col_1 (#1)"], + rewritten_predicates: &[], + }, + AggIndexLiteCase { + case: agg_index_test_case( + "agg_index_aggregate_distinct_mismatch", + "select count(distinct a) from t", + ), + agg_index_table: "t", + agg_index_sqls: &["select count(a) from t"], + is_matched: false, + index_selection: &[], + rewritten_predicates: &[], + }, + ] +} + +fn find_push_down_index_info_from_plan(plan: &Plan) -> Result> { + let Plan::Query { s_expr, .. 
} = plan else { + return Ok(None); + }; + find_push_down_index_info(s_expr) +} + +fn find_push_down_index_info(s_expr: &SExpr) -> Result> { + match s_expr.plan() { + RelOperator::Scan(scan) => Ok(scan.agg_index.as_ref()), + _ => find_push_down_index_info(s_expr.child(0)?), + } +} + +fn format_selection(info: &AggIndexInfo) -> Vec { + let mut selection: Vec<_> = info + .selection + .iter() + .map(|sel| databend_common_sql::format_scalar(&sel.scalar)) + .collect(); + selection.sort(); + selection +} + +fn format_filter(info: &AggIndexInfo) -> Vec { + let mut predicates: Vec<_> = info + .predicates + .iter() + .map(databend_common_sql::format_scalar) + .collect(); + predicates.sort(); + predicates +} + +fn find_scan(s_expr: &SExpr) -> Result> { + match s_expr.plan() { + RelOperator::Scan(scan) => Ok(Some(scan)), + _ => find_scan(s_expr.child(0)?), + } +} + +fn collect_scan_table_indexes(s_expr: &SExpr, table_indexes: &mut Vec) { + if let RelOperator::Scan(scan) = s_expr.plan() { + table_indexes.push(scan.table_index); + } + + for child in s_expr.children() { + collect_scan_table_indexes(child, table_indexes); + } +} + +fn describe_table_columns( + metadata: &databend_common_sql::Metadata, + table_index: usize, +) -> Vec { + metadata + .columns_by_table_index(table_index) + .into_iter() + .map(|column| match column { + databend_common_sql::ColumnEntry::BaseTableColumn(col) => { + format!( + "{} idx={} table_index={}", + col.column_name, col.column_index, col.table_index + ) + } + databend_common_sql::ColumnEntry::InternalColumn(col) => { + format!( + "{} idx={} table_index={}", + col.internal_column.column_name(), + col.column_index, + col.table_index + ) + } + databend_common_sql::ColumnEntry::VirtualColumn(col) => { + format!( + "{} idx={} table_index={}", + col.column_name, col.column_index, col.table_index + ) + } + databend_common_sql::ColumnEntry::DerivedColumn(_) => unreachable!(), + }) + .collect() +} + +fn describe_bound_columns(s_expr: &SExpr) -> Result> { + 
struct BoundColumnCollector { + columns: Vec, + } + + impl SExprVisitor for BoundColumnCollector { + fn visit(&mut self, expr: &SExpr) -> Result { + for scalar in expr.plan().scalar_expr_iter() { + let mut scalar_collector = ScalarBoundColumnCollector { + columns: &mut self.columns, + }; + scalar_collector.visit(scalar)?; + } + Ok(VisitAction::Continue) + } + } + + struct ScalarBoundColumnCollector<'a> { + columns: &'a mut Vec, + } + + impl<'a, 'b> Visitor<'a> for ScalarBoundColumnCollector<'b> { + fn visit_bound_column_ref(&mut self, col: &'a BoundColumnRef) -> Result<()> { + if col.column.table_index.is_none() { + return Ok(()); + } + self.columns.push(format!( + "db={:?} table={:?} col={} idx={} table_index={:?}", + col.column.database_name, + col.column.table_name, + col.column.column_name, + col.column.index, + col.column.table_index + )); + Ok(()) + } + } + + let mut collector = BoundColumnCollector { + columns: Vec::new(), + }; + let _ = s_expr.accept(&mut collector)?; + collector.columns.sort(); + collector.columns.dedup(); + Ok(collector.columns) +} + +async fn setup_tables(ctx: &Arc, case: &TestCase) -> Result<()> { + for sql in case.tables.values() { + for statement in sql.split(';').filter(|s| !s.trim().is_empty()) { + ctx.register_table_sql(statement).await?; + } + } + Ok(()) +} + +async fn optimize_with_debug_agg_index( + ctx: &Arc, + query_sql: &str, + index_sql: &str, +) -> Result { + let plan = ctx.bind_sql(query_sql).await?; + let metadata = match &plan { + Plan::Query { metadata, .. } => metadata.clone(), + _ => unreachable!("query sql must bind to query plan"), + }; + + let index_plan = ctx.bind_sql(index_sql).await?; + let Plan::Query { s_expr, .. 
} = index_plan else { + unreachable!("index sql must bind to query plan"); + }; + let table_index = { + let metadata_guard = metadata.read(); + metadata_guard + .tables() + .iter() + .find(|table| table.name() == "t") + .map(|table| table.index()) + .expect("query metadata should contain table t") + }; + let agg_index_plan = build_agg_index_plan_for_table( + ctx.clone(), + None, + Arc::new(RwLock::new(metadata.read().clone())), + table_index, + 0, + index_sql.to_string(), + *s_expr, + )?; + metadata + .write() + .add_agg_indices(table_index, vec![agg_index_plan]); + + ctx.optimize_plan(plan).await +} + +async fn create_auto_bound_agg_index_ctx(index_sql: &str) -> Result> { + let ctx = LiteTableContext::create_isolated().await?; + ctx.configure_for_optimizer_case(false)?; + ctx.get_settings() + .set_setting("enable_aggregating_index_scan".to_string(), "1".to_string())?; + ctx.set_can_scan_from_agg_index(true); + ctx.register_table_sql("create table t(a int, b int, c int)") + .await?; + ctx.register_agg_index("default", "t", "idx1", index_sql)?; + Ok(ctx) +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_lite_agg_index_optimizer_cases() -> Result<()> { + for test in agg_index_test_cases() { + let ctx = LiteTableContext::create_isolated().await?; + ctx.configure_for_optimizer_case(test.case.auto_stats)?; + setup_tables(&ctx, &test.case).await?; + + let runner = AggIndexLiteRunner { + ctx: ctx.clone(), + agg_index_table: test.agg_index_table, + agg_index_sqls: test.agg_index_sqls, + }; + let plan = runner.bind_sql(&test.case.sql).await?; + let optimized_plan = runner.optimize_plan(plan).await?; + let agg_index = find_push_down_index_info_from_plan(&optimized_plan)?; + + assert_eq!( + test.is_matched, + agg_index.is_some(), + "case: {}, sql: {}, indexes: {:?}", + test.case.name, + test.case.sql, + test.agg_index_sqls + ); + + if let Some(agg_index) = agg_index { + let mut expected_selection: Vec<_> = test + .index_selection + .iter() + .map(|s| 
(*s).to_string()) + .collect(); + expected_selection.sort(); + assert_eq!( + expected_selection, + format_selection(agg_index), + "case: {} selection mismatch", + test.case.name + ); + + let mut expected_predicates: Vec<_> = test + .rewritten_predicates + .iter() + .map(|s| (*s).to_string()) + .collect(); + expected_predicates.sort(); + assert_eq!( + expected_predicates, + format_filter(agg_index), + "case: {} predicate mismatch", + test.case.name + ); + } + } + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_lite_agg_index_auto_bound_matches_manual_injection() -> Result<()> { + let index_sql = "select b, sum(a) from t where b > 1 group by b"; + for query_sql in [ + "select sum(a), b from t where b > 1 group by b", + "select b from t where b > 1 group by b", + "select sum(a) + 1 from t where b > 1 group by b", + ] { + let auto_ctx = create_auto_bound_agg_index_ctx(index_sql).await?; + let auto_plan = auto_ctx + .optimize_plan(auto_ctx.bind_sql(query_sql).await?) + .await?; + let auto_info = find_push_down_index_info_from_plan(&auto_plan)? + .expect("auto-bound agg index should match"); + + let manual_ctx = LiteTableContext::create_isolated().await?; + manual_ctx.configure_for_optimizer_case(false)?; + manual_ctx + .register_table_sql("create table t(a int, b int, c int)") + .await?; + let manual_plan = optimize_with_debug_agg_index(&manual_ctx, query_sql, index_sql).await?; + let manual_info = find_push_down_index_info_from_plan(&manual_plan)? 
+ .expect("manually injected agg index should match"); + + assert_eq!(format_selection(auto_info), format_selection(manual_info)); + assert_eq!(format_filter(auto_info), format_filter(manual_info)); + } + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_lite_agg_index_auto_bound_index_plan_is_normalized() -> Result<()> { + let index_sql = "select b, sum(a) from t where b > 1 group by b"; + let query_sql = "select sum(a), b from t where b > 1 group by b"; + let ctx = create_auto_bound_agg_index_ctx(index_sql).await?; + let plan = ctx.bind_sql(query_sql).await?; + + let (query_scan, metadata_ref) = match &plan { + Plan::Query { + s_expr, metadata, .. + } => ( + find_scan(s_expr)?.expect("query scan should exist"), + metadata.clone(), + ), + _ => unreachable!("debug test must bind to query plan"), + }; + let metadata = metadata_ref.read(); + + let agg_index = metadata + .get_agg_indices(query_scan.table_index) + .expect("agg index should be bound from catalog") + .first() + .expect("agg index should not be empty"); + let index_metadata = agg_index.metadata.read(); + let index_scan = find_scan(&agg_index.s_expr)?.expect("index scan should exist"); + let query_columns = describe_table_columns(&metadata, query_scan.table_index); + let index_columns = describe_table_columns(&index_metadata, index_scan.table_index); + let index_bound_columns = describe_bound_columns(&agg_index.s_expr)?; + + assert_eq!( + index_columns, query_columns, + "catalog-bound agg index table metadata should match query table metadata" + ); + assert!( + !Arc::ptr_eq(&agg_index.metadata, &metadata_ref), + "catalog-bound agg index should use metadata independent from the main query", + ); + assert!( + index_scan.statistics.table_stats.is_none(), + "normalized catalog-bound agg index scan should clear table statistics" + ); + assert!( + index_scan.statistics.column_stats.is_empty(), + "normalized catalog-bound agg index scan should clear column statistics" + ); + assert!( + 
index_scan.statistics.histograms.is_empty(), + "normalized catalog-bound agg index scan should clear histograms" + ); + assert_eq!( + index_bound_columns, + vec![ + format!( + "db=Some(\"default\") table=Some(\"t\") col=a idx=0 table_index=Some({})", + index_scan.table_index + ), + format!( + "db=Some(\"default\") table=Some(\"t\") col=b idx=1 table_index=Some({})", + index_scan.table_index + ), + ], + "catalog-bound agg index raw plan should keep canonical bound column names inside index metadata", + ); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_lite_agg_index_auto_bound_share_child_metadata() -> Result<()> { + let ctx = LiteTableContext::create_isolated().await?; + ctx.configure_for_optimizer_case(false)?; + ctx.get_settings() + .set_setting("enable_aggregating_index_scan".to_string(), "1".to_string())?; + ctx.set_can_scan_from_agg_index(true); + ctx.register_table_sql("create table t(a int, b int, c int)") + .await?; + ctx.register_agg_index( + "default", + "t", + "idx1", + "select b, sum(a) from t where b > 1 group by b", + )?; + ctx.register_agg_index( + "default", + "t", + "idx2", + "select b, max(c) from t where b > 2 group by b", + )?; + + let plan = ctx + .bind_sql("select sum(a), b from t where b > 1 group by b") + .await?; + let metadata = match &plan { + Plan::Query { metadata, .. } => metadata.read(), + _ => unreachable!("debug test must bind to query plan"), + }; + let agg_indices = metadata + .get_agg_indices( + find_scan(match &plan { + Plan::Query { s_expr, .. } => s_expr, + _ => unreachable!("debug test must bind to query plan"), + })? 
+ .expect("query scan should exist") + .table_index, + ) + .expect("agg index should be bound from catalog"); + + assert_eq!(agg_indices.len(), 2); + assert!( + Arc::ptr_eq(&agg_indices[0].metadata, &agg_indices[1].metadata), + "catalog-bound agg indices should share the same child metadata", + ); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_lite_agg_index_auto_bound_for_cross_database_table() -> Result<()> { + let ctx = LiteTableContext::create_isolated().await?; + ctx.configure_for_optimizer_case(false)?; + ctx.get_settings() + .set_setting("enable_aggregating_index_scan".to_string(), "1".to_string())?; + ctx.set_can_scan_from_agg_index(true); + ctx.register_table_sql("create table other_db.t(a int, b int, c int)") + .await?; + ctx.register_agg_index( + "other_db", + "t", + "idx1", + "select b, sum(a) from other_db.t where b > 1 group by b", + )?; + + let plan = ctx + .bind_sql("select sum(a), b from other_db.t where b > 1 group by b") + .await?; + let (query_scan, metadata_ref) = match &plan { + Plan::Query { + s_expr, metadata, .. + } => ( + find_scan(s_expr)?.expect("query scan should exist"), + metadata.clone(), + ), + _ => unreachable!("debug test must bind to query plan"), + }; + let metadata = metadata_ref.read(); + + assert_eq!( + metadata.table(query_scan.table_index).database(), + "other_db" + ); + assert!( + metadata.get_agg_indices(query_scan.table_index).is_some(), + "cross-database table should keep its bound agg index", + ); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_lite_agg_index_auto_bound_for_repeated_table_scans() -> Result<()> { + let ctx = + create_auto_bound_agg_index_ctx("select b, sum(a) from t where b > 1 group by b").await?; + let plan = ctx + .bind_sql("select l.a, r.c from t as l join t as r on l.b = r.b") + .await?; + let (s_expr, metadata_ref) = match &plan { + Plan::Query { + s_expr, metadata, .. 
+ } => (s_expr, metadata.clone()), + _ => unreachable!("debug test must bind to query plan"), + }; + let metadata = metadata_ref.read(); + let mut table_indexes = Vec::new(); + collect_scan_table_indexes(s_expr, &mut table_indexes); + + assert_eq!( + table_indexes.len(), + 2, + "self join should bind two table scans" + ); + for table_index in table_indexes { + assert!( + metadata.get_agg_indices(table_index).is_some(), + "each repeated table scan should keep its bound agg index", + ); + } + + Ok(()) +} diff --git a/src/query/sql/tests/it/optimizer/mod.rs b/src/query/sql/tests/it/optimizer/mod.rs index 9cb7859cc30c8..f8d23c0c9301c 100644 --- a/src/query/sql/tests/it/optimizer/mod.rs +++ b/src/query/sql/tests/it/optimizer/mod.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +mod agg_index; mod eager_aggregation; mod normalize_scalar; mod push_down_filter_project_set;