diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index eca40c553280..08272a830379 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -52,6 +52,7 @@ arrow = { workspace = true } bigdecimal = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +datafusion-functions = { workspace = true, default-features = true } indexmap = { workspace = true } log = { workspace = true } recursive = { workspace = true, optional = true } diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index b50fbf68129c..d58e3c7448aa 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -31,7 +31,9 @@ use datafusion_common::error::DataFusionErrorBuilder; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{not_impl_err, plan_err, Result}; use datafusion_common::{RecursionUnnestOption, UnnestOptions}; -use datafusion_expr::expr::{Alias, PlannedReplaceSelectItem, WildcardOptions}; +use datafusion_expr::expr::{ + Alias, PlannedReplaceSelectItem, ScalarFunction, WildcardOptions, +}; use datafusion_expr::expr_rewriter::{ normalize_col, normalize_col_with_schemas_and_ambiguity_check, normalize_sorts, }; @@ -40,14 +42,17 @@ use datafusion_expr::utils::{ expr_as_column_expr, expr_to_columns, find_aggregate_exprs, find_window_exprs, }; use datafusion_expr::{ - Aggregate, Expr, Filter, GroupingSet, LogicalPlan, LogicalPlanBuilder, - LogicalPlanBuilderOptions, Partitioning, + lit, Aggregate, BinaryExpr, Expr, Filter, GroupingSet, LogicalPlan, + LogicalPlanBuilder, LogicalPlanBuilderOptions, Operator, Partitioning, ScalarUDF, }; +use datafusion_functions::math::random::RandomFunc; use indexmap::IndexMap; use sqlparser::ast::{ visit_expressions_mut, Distinct, Expr as SQLExpr, GroupByExpr, NamedWindowExpr, - OrderBy, SelectItemQualifiedWildcardKind, WildcardAdditionalOptions, WindowType, + OrderBy, SelectItemQualifiedWildcardKind, TableFactor, TableSampleKind, + TableSampleModifier, TableSampleQuantity, TableSampleUnit, WildcardAdditionalOptions, + WindowType, }; use sqlparser::ast::{NamedWindowDefinition, Select, SelectItem, TableWithJoins}; @@ -77,11 +82,29 @@ impl SqlToRel<'_, S> { } // Process `from` clause - let plan = self.plan_from_tables(select.from, planner_context)?; + let plan = self.plan_from_tables(select.from.clone(), planner_context)?; let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_)); // Process `where` clause - let base_plan = self.plan_selection(select.selection, plan, planner_context)?; + let mut base_plan = + self.plan_selection(select.selection, plan, planner_context)?; + + // Now `base_plan` is a LogicalPlan::Filter + if let Some(from_first) = select.from.first() { + if let TableFactor::Table { + sample: Some(sample), + .. + } = &from_first.relation + { + // Rewrite SAMPLE / TABLESAMPLE clause to additional filters + // TODO: handle samples from joined tables + base_plan = self.sample_to_where_random_clause( + base_plan, + sample, + planner_context, + )?; + } + } // Handle named windows before processing the projection expression check_conflicting_windows(&select.named_window)?; @@ -556,6 +579,122 @@ impl SqlToRel<'_, S> { } } + /// Extract an expression for table sample quantity + fn sample_quanitity_value( + &self, + quantity: &TableSampleQuantity, + planner_context: &mut PlannerContext, + ) -> Result { + match &quantity.value { + // Support only numeric literals now + SQLExpr::Value(value_with_span) => Ok(self.parse_value( + value_with_span.value.clone(), + planner_context.prepare_param_data_types(), + )?), + _ => not_impl_err!( + "Table quantity value {:?} is not supported", + &quantity.value + ), + } + } + + /// Compose expression for TABLE SAMPLE filter + fn sample_to_threshold_expr( + &self, + table_sample_kind: &TableSampleKind, + planner_context: &mut PlannerContext, + ) -> Result> { + // Support both before and after syntax + let table_sample = match table_sample_kind { + // Standard syntax + TableSampleKind::BeforeTableAlias(kind) => kind, + // Hive syntax + TableSampleKind::AfterTableAlias(kind) => kind, + }; + + // These features are not part of a common SQL specification, + // not implemented yet + if table_sample.seed.is_some() { + return not_impl_err!("Table sample seed is not supported"); + } + if table_sample.bucket.is_some() { + return not_impl_err!("Table sample bucket is not supported"); + } + if table_sample.offset.is_some() { + return not_impl_err!("Table sample offset is not supported"); + } + + if let Some(table_sample_quantity) = &table_sample.quantity { + match table_sample_quantity.unit { + Some(TableSampleUnit::Rows) => { + // Fixed size row sampling is not supported + not_impl_err!("Table sample with rows unit is not supported") + } + Some(TableSampleUnit::Percent) | None => { + // There are two flavors of sampling (`TableSampleMethod`): + // - Block-level sampling (SYSTEM or BLOCK keywords) + // - Row-level sampling (BERNOULLI or ROW keywords) + // `random()` filter pushdown allows only block-level sampling, + // not row-level. However, we do not forbid using BERNOULLI/ROW; + + // Extract quantity from SQLExpr + let quantity_value: Expr = self + .sample_quanitity_value(table_sample_quantity, planner_context)?; + + let ratio: Expr = match table_sample.modifier { + TableSampleModifier::TableSample => + // SELECT * FROM tbl TABLESAMPLE SYSTEM (10), + // Value is percentage + { + Expr::BinaryExpr(BinaryExpr::new( + Box::new(quantity_value), + Operator::Divide, + Box::new(lit(100.0)), + )) + } + TableSampleModifier::Sample => + // SELECT * FROM tbl SAMPLE 0.1 + // Value is floating ratio, pass as is + { + quantity_value + } + }; + + let random_threshold = Box::new(ratio); + Ok(random_threshold) + } + } + } else { + plan_err!("Table sample quantity must be specified") + } + } + + /// Compose a logical plan with a static Filter based on TABLE SAMPLE expression + fn sample_to_where_random_clause( + &self, + plan: LogicalPlan, + sample_kind: &TableSampleKind, + planner_context: &mut PlannerContext, + ) -> Result { + // `random()` call + let random_udf = ScalarUDF::new_from_impl(RandomFunc::new()); + let random_expr_call = Box::new(Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::new(random_udf), + vec![], + ))); + let random_threshold = + self.sample_to_threshold_expr(sample_kind, planner_context)?; + // New filter predicate: `random() < 0.1` + let predicate = Expr::BinaryExpr(BinaryExpr::new( + random_expr_call, + Operator::Lt, + random_threshold, + )); + let random_filter = Filter::try_new(predicate, Arc::new(plan)); + + Ok(LogicalPlan::Filter(random_filter?)) + } + pub(crate) fn plan_from_tables( &self, mut from: Vec, diff --git a/datafusion/sqllogictest/test_files/tablesample.slt b/datafusion/sqllogictest/test_files/tablesample.slt new file mode 100644 index 000000000000..fe9d3acd8ecc --- /dev/null +++ b/datafusion/sqllogictest/test_files/tablesample.slt @@ -0,0 +1,239 @@ + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +create table t(a int) as values (1), (2), (3), (1); + +statement ok +create table t2(a int, b int) as values (1, 10), (2, 20), (3, 30), (4, 40); + +query II +select *, count(*) over() as ta from t; +---- +1 4 +2 4 +3 4 +1 4 + +statement ok +set datafusion.explain.logical_plan_only = true; + +# tablesample value +query TT +EXPLAIN SELECT COUNT(*) from t TABLESAMPLE 42 WHERE a < 10; +---- +logical_plan +01)Projection: count(Int64(1)) AS count(*) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +03)----Projection: +04)------Filter: t.a < Int32(10) AND random() < Float64(0.42) +05)--------TableScan: t projection=[a] + + +# tablesample value float +query TT +EXPLAIN SELECT COUNT(*) from t TABLESAMPLE 42.3 WHERE a < 10; +---- +logical_plan +01)Projection: count(Int64(1)) AS count(*) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +03)----Projection: +04)------Filter: t.a < Int32(10) AND random() < Float64(0.423) +05)--------TableScan: t projection=[a] + + +# tablesample system(value) +query TT +EXPLAIN SELECT COUNT(*) from t TABLESAMPLE SYSTEM (42) WHERE a < 10; +---- +logical_plan +01)Projection: count(Int64(1)) AS count(*) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +03)----Projection: +04)------Filter: t.a < Int32(10) AND random() < Float64(0.42) +05)--------TableScan: t projection=[a] + +# tablesample system percent +query TT +EXPLAIN SELECT COUNT(*) from t TABLESAMPLE SYSTEM (42 PERCENT) WHERE a < 10; +---- +logical_plan +01)Projection: count(Int64(1)) AS count(*) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +03)----Projection: +04)------Filter: t.a < Int32(10) AND random() < Float64(0.42) +05)--------TableScan: t projection=[a] + +# tablesample block(value) +query TT +EXPLAIN SELECT COUNT(*) from t TABLESAMPLE BLOCK (42) WHERE a < 10; +---- +logical_plan +01)Projection: count(Int64(1)) AS count(*) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +03)----Projection: +04)------Filter: t.a < Int32(10) AND random() < Float64(0.42) +05)--------TableScan: t projection=[a] + +# tablesample after alias +query TT +EXPLAIN SELECT COUNT(*) from t as talias TABLESAMPLE SYSTEM (42) WHERE a < 10; +---- +logical_plan +01)Projection: count(Int64(1)) AS count(*) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +03)----SubqueryAlias: talias +04)------Projection: +05)--------Filter: t.a < Int32(10) AND random() < Float64(0.42) +06)----------TableScan: t projection=[a] + +# sample random +query TT +EXPLAIN SELECT COUNT(*) from t SAMPLE 0.42 WHERE a < 10; +---- +logical_plan +01)Projection: count(Int64(1)) AS count(*) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +03)----Projection: +04)------Filter: t.a < Int32(10) AND random() < Float64(0.42) +05)--------TableScan: t projection=[a] + +# tablesample system percent with BERNOULLI method +query TT +EXPLAIN SELECT COUNT(*) from t TABLESAMPLE BERNOULLI (42 PERCENT) WHERE a < 10; +---- +logical_plan +01)Projection: count(Int64(1)) AS count(*) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +03)----Projection: +04)------Filter: t.a < Int32(10) AND random() < Float64(0.42) +05)--------TableScan: t projection=[a] + +# tablesample system percent with ROW method (percentage), Snowflake syntax +query TT +EXPLAIN SELECT COUNT(*) from t TABLESAMPLE ROW (42 PERCENT) WHERE a < 10; +---- +logical_plan +01)Projection: count(Int64(1)) AS count(*) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +03)----Projection: +04)------Filter: t.a < Int32(10) AND random() < Float64(0.42) +05)--------TableScan: t projection=[a] + +# tablesample system percent with ROW method (rows), Snowflake syntax +query error DataFusion error: This feature is not implemented: Table sample with rows unit is not supported +EXPLAIN SELECT COUNT(*) from t TABLESAMPLE ROW (20 ROWS) WHERE a < 10; + +# unsupported: fixed row sampling +query error DataFusion error: This feature is not implemented: Table sample with rows unit is not supported +EXPLAIN SELECT COUNT(*) from t TABLESAMPLE (5 ROWS); + +# unsupported: buckets +query error DataFusion error: This feature is not implemented: Table sample bucket is not supported +EXPLAIN SELECT COUNT(*) from t TABLESAMPLE (BUCKET 3 OUT OF 16 ON id) + +# unsupported: seed +query error DataFusion error: This feature is not implemented: Table sample seed is not supported +EXPLAIN SELECT COUNT(*) from t TABLESAMPLE SYSTEM (3) REPEATABLE (82) + + +# smoke test for joining tables +query III +SELECT t.a, t2.a, t2.b FROM t JOIN t2 on t.a = t2.a; +---- +1 1 10 +1 1 10 +2 2 20 +3 3 30 + +# multiple tables with join +# sampling is applied only to the first table +query TT +EXPLAIN SELECT COUNT(*) from t SAMPLE 0.42 JOIN t2 TABLESAMPLE 10 PERCENT on t.a = t2.a; +---- +logical_plan +01)Projection: count(Int64(1)) AS count(*) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +03)----Projection: +04)------Inner Join: t.a = t2.a +05)--------Filter: random() < Float64(0.42) +06)----------TableScan: t projection=[a] +07)--------TableScan: t2 projection=[a] + +# multiple tables with subquery +query TT +EXPLAIN SELECT COUNT(*) from t SAMPLE 0.42 WHERE a IN (SELECT b from t2 TABLESAMPLE 10 PERCENT) and a < 10; +---- +logical_plan +01)Projection: count(Int64(1)) AS count(*) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +03)----Projection: +04)------LeftSemi Join: t.a = __correlated_sq_1.b +05)--------Filter: t.a < Int32(10) AND random() < Float64(0.42) +06)----------TableScan: t projection=[a] +07)--------SubqueryAlias: __correlated_sq_1 +08)----------Filter: random() < Float64(0.1) +09)------------TableScan: t2 projection=[b] + +statement ok +set datafusion.sql_parser.dialect = 'Hive'; + +# tablesample before alias, Hive syntax +query TT +EXPLAIN SELECT COUNT(*) from t TABLESAMPLE SYSTEM (42) as talias WHERE a < 10; +---- +logical_plan +01)Projection: count(Int64(1)) AS count(*) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +03)----SubqueryAlias: talias +04)------Projection: +05)--------Filter: t.a < Int32(10) AND random() < Float64(0.42) +06)----------TableScan: t projection=[a] + +statement ok +set datafusion.sql_parser.dialect = 'Generic'; + +statement ok +set datafusion.explain.logical_plan_only = false; + +# verify that `random()` filter is not pushed down to executor as volatile +statement ok +set datafusion.execution.parquet.pushdown_filters=true; + +query TT +EXPLAIN SELECT COUNT(*) from t TABLESAMPLE SYSTEM (42 PERCENT) WHERE a < 10; +---- +logical_plan +01)Projection: count(Int64(1)) AS count(*) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +03)----Projection: +04)------Filter: t.a < Int32(10) AND random() < Float64(0.42) +05)--------TableScan: t projection=[a] +physical_plan +01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] +02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] +03)----CoalescePartitionsExec +04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] +05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +06)----------ProjectionExec: expr=[] +07)------------CoalesceBatchesExec: target_batch_size=8192 +08)--------------FilterExec: a@0 < 10 AND random() < 0.42 +09)----------------DataSourceExec: partitions=1, partition_sizes=[1] + +statement count 0 +drop table t;