From 1676c0a163e26d1edfe0563e803a7d8953a4e9ae Mon Sep 17 00:00:00 2001 From: theirix Date: Wed, 17 Sep 2025 19:58:34 +0100 Subject: [PATCH 01/17] Add ctor, rand_distr dependencies --- Cargo.lock | 2 ++ Cargo.toml | 1 + datafusion-examples/Cargo.toml | 2 ++ 3 files changed, 5 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 9dd40437fba9..60eb1886b27a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2164,6 +2164,7 @@ dependencies = [ "async-trait", "base64 0.22.1", "bytes", + "ctor", "dashmap", "datafusion", "datafusion-ffi", @@ -2177,6 +2178,7 @@ dependencies = [ "object_store", "prost", "rand 0.9.2", + "rand_distr", "serde_json", "tempfile", "test-utils", diff --git a/Cargo.toml b/Cargo.toml index 92392a199107..91e9a375b27e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -169,6 +169,7 @@ pbjson-types = "0.7" insta = { version = "1.43.2", features = ["glob", "filters"] } prost = "0.13.1" rand = "0.9" +rand_distr = "0.5" recursive = "0.1.1" regex = "1.11" rstest = "0.25.0" diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index 68bb5376a1ac..a24b9146fd09 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -63,6 +63,7 @@ arrow-flight = { workspace = true } arrow-schema = { workspace = true } async-trait = { workspace = true } bytes = { workspace = true } +ctor = { workspace = true } dashmap = { workspace = true } # note only use main datafusion crate for examples base64 = "0.22.1" @@ -77,6 +78,7 @@ mimalloc = { version = "0.1", default-features = false } object_store = { workspace = true, features = ["aws", "http"] } prost = { workspace = true } rand = { workspace = true } +rand_distr = { workspace = true } serde_json = { workspace = true } tempfile = { workspace = true } test-utils = { path = "../test-utils" } From f6b73bd472fcd3f9c6960e51e3f1c64a0b0e6b4f Mon Sep 17 00:00:00 2001 From: theirix Date: Wed, 17 Sep 2025 21:39:10 +0100 Subject: [PATCH 02/17] Implement SQL table sample as an extension --- datafusion-examples/examples/table_sample.rs | 1355 ++++++++++++++++++ 1 file changed, 1355 insertions(+) create mode 100644 datafusion-examples/examples/table_sample.rs diff --git a/datafusion-examples/examples/table_sample.rs b/datafusion-examples/examples/table_sample.rs new file mode 100644 index 000000000000..42743186df10 --- /dev/null +++ b/datafusion-examples/examples/table_sample.rs @@ -0,0 +1,1355 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
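+
+// Overview of the moving parts defined below:
+//
+//   SQL text
+//     --(sqlparser)--> AST with a `TableSample` attached to the table factor
+//     --(TableSamplePlanner)--> LogicalPlan::Extension(TableSamplePlanNode)
+//     --(TableSampleQueryPlanner / TableSampleExtensionPlanner)--> SampleExec
+//     --(SampleExec::execute)--> stream of sampled RecordBatches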
+ +#![allow(unused_imports)] + +use datafusion::common::{ + arrow_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, + DFSchema, DFSchemaRef, ResolvedTableReference, Statistics, TableReference, +}; +use datafusion::error::Result; +use datafusion::logical_expr::sqlparser::ast::{ + Query, SetExpr, Statement, TableFactor, TableSample, TableSampleMethod, + TableSampleQuantity, TableSampleUnit, +}; +use datafusion::logical_expr::{ + AggregateUDF, Extension, Filter, LogicalPlan, LogicalPlanBuilder, Projection, + ScalarUDF, TableSource, UserDefinedLogicalNode, UserDefinedLogicalNodeCore, + WindowUDF, +}; +use std::any::Any; +use std::cmp::Ordering; +use std::collections::{BTreeMap, HashMap}; + +use arrow::util::pretty::{pretty_format_batches, pretty_format_batches_with_schema}; +use datafusion::common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRewriter, +}; + +use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef, TimeUnit}; +use async_trait::async_trait; +use datafusion::catalog::cte_worktable::CteWorkTable; +use datafusion::common::file_options::file_type::FileType; +use datafusion::config::ConfigOptions; +use datafusion::datasource::file_format::format_as_file_type; +use datafusion::datasource::{provider_as_source, DefaultTableSource, TableProvider}; +use datafusion::error::DataFusionError; +use datafusion::execution::{ + FunctionRegistry, SendableRecordBatchStream, SessionState, SessionStateBuilder, + TaskContext, +}; +use datafusion::logical_expr::planner::{ + ContextProvider, ExprPlanner, PlannerResult, RawBinaryExpr, TypePlanner, +}; +use datafusion::logical_expr::sqlparser::dialect::PostgreSqlDialect; +use datafusion::logical_expr::sqlparser::parser::Parser; +use datafusion::logical_expr::var_provider::is_system_variables; +use datafusion::optimizer::simplify_expressions::ExprSimplifier; +use datafusion::optimizer::AnalyzerRule; +use datafusion::physical_expr::EquivalenceProperties; +use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion::physical_plan::metrics::{ + BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, +}; +use datafusion::physical_plan::{ + displayable, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, + RecordBatchStream, +}; +use datafusion::physical_planner::{ + DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner, +}; +use datafusion::prelude::*; +use datafusion::sql::planner::{ParserOptions, PlannerContext, SqlToRel}; +use datafusion::sql::sqlparser::ast::{TableSampleKind, TableSampleModifier}; +use datafusion::sql::unparser::ast::{ + DerivedRelationBuilder, QueryBuilder, RelationBuilder, SelectBuilder, +}; +use datafusion::sql::unparser::dialect::CustomDialectBuilder; +use datafusion::sql::unparser::expr_to_sql; +use datafusion::sql::unparser::extension_unparser::UserDefinedLogicalNodeUnparser; +use datafusion::sql::unparser::extension_unparser::{ + UnparseToStatementResult, UnparseWithinStatementResult, +}; +use datafusion::sql::unparser::{plan_to_sql, Unparser}; +use datafusion::variable::VarType; +use log::{debug, info}; +use std::fmt; +use std::fmt::{Debug, Formatter}; +use std::hash::{Hash, Hasher}; +use std::pin::Pin; +use std::str::FromStr; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::array::UInt32Array; +use arrow::compute; +use arrow::record_batch::RecordBatch; + +use datafusion::execution::context::QueryPlanner; +use datafusion::sql::sqlparser; +use datafusion::sql::sqlparser::ast; +use futures::stream::{Stream, StreamExt}; 
+use futures::TryStreamExt; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use rand_distr::{Distribution, Poisson}; + +/// This example demonstrates the table sample support. + +#[derive(Debug, Clone)] +struct TableSamplePlanNode { + inner_plan: LogicalPlan, + + lower_bound: f64, + upper_bound: f64, + with_replacement: bool, + seed: u64, +} + +impl Hash for TableSamplePlanNode { + fn hash(&self, state: &mut H) { + self.inner_plan.hash(state); + self.lower_bound.to_bits().hash(state); + self.upper_bound.to_bits().hash(state); + self.with_replacement.hash(state); + self.seed.hash(state); + } +} + +impl PartialEq for TableSamplePlanNode { + fn eq(&self, other: &Self) -> bool { + self.inner_plan == other.inner_plan + && (self.lower_bound - other.lower_bound).abs() < f64::EPSILON + && (self.upper_bound - other.upper_bound).abs() < f64::EPSILON + && self.with_replacement == other.with_replacement + && self.seed == other.seed + } +} + +impl Eq for TableSamplePlanNode {} + +impl PartialOrd for TableSamplePlanNode { + fn partial_cmp(&self, other: &Self) -> Option { + self.inner_plan + .partial_cmp(&other.inner_plan) + .and_then(|ord| { + if ord != Ordering::Equal { + Some(ord) + } else { + self.lower_bound + .partial_cmp(&other.lower_bound) + .and_then(|ord| { + if ord != Ordering::Equal { + Some(ord) + } else { + self.upper_bound.partial_cmp(&other.upper_bound).and_then( + |ord| { + if ord != Ordering::Equal { + Some(ord) + } else { + self.with_replacement + .partial_cmp(&other.with_replacement) + .and_then(|ord| { + if ord != Ordering::Equal { + Some(ord) + } else { + self.seed.partial_cmp(&other.seed) + } + }) + } + }, + ) + } + }) + } + }) + } +} + +impl UserDefinedLogicalNodeCore for TableSamplePlanNode { + fn name(&self) -> &str { + "TableSample" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.inner_plan] + } + + fn schema(&self) -> &DFSchemaRef { + self.inner_plan.schema() + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> fmt::Result { + f.write_fmt(format_args!( + "Sample: {:?} {:?} {:?}", + self.lower_bound, self.upper_bound, self.seed + )) + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + inputs: Vec, + ) -> Result { + let input = inputs + .first() + .ok_or(DataFusionError::Plan("Should have input".into()))?; + Ok(Self { + inner_plan: input.clone(), + lower_bound: self.lower_bound, + upper_bound: self.upper_bound, + with_replacement: self.with_replacement, + seed: self.seed, + }) + } +} + +/// Execution planner with `SampleExec` for `TableSamplePlanNode` +struct TableSampleExtensionPlanner {} + +impl TableSampleExtensionPlanner { + fn build_execution_plan( + &self, + specific_node: &TableSamplePlanNode, + physical_input: Arc, + ) -> Result> { + Ok(Arc::new(SampleExec { + input: physical_input.clone(), + lower_bound: 0.0, + upper_bound: specific_node.upper_bound, + with_replacement: specific_node.with_replacement, + seed: specific_node.seed, + metrics: Default::default(), + cache: SampleExec::compute_properties(&physical_input), + })) + } +} + +#[async_trait] +impl ExtensionPlanner for TableSampleExtensionPlanner { + /// Create a physical plan for an extension node + async fn plan_extension( + &self, + _planner: &dyn PhysicalPlanner, + node: &dyn UserDefinedLogicalNode, + logical_inputs: &[&LogicalPlan], + physical_inputs: &[Arc], + _session_state: &SessionState, + ) -> Result>> { + if let Some(specific_node) = node.as_any().downcast_ref::() { + info!("Extension planner plan_extension: {:?}", 
&logical_inputs); + assert_eq!(logical_inputs.len(), 1, "Inconsistent number of inputs"); + assert_eq!(physical_inputs.len(), 1, "Inconsistent number of inputs"); + + let exec_plan = + self.build_execution_plan(specific_node, physical_inputs[0].clone())?; + Ok(Some(exec_plan)) + } else { + Ok(None) + } + } +} + +/// Query planner supporting a `TableSampleExtensionPlanner` +#[derive(Debug)] +struct TableSampleQueryPlanner {} + +#[async_trait] +impl QueryPlanner for TableSampleQueryPlanner { + /// Given a `LogicalPlan` created from above, create an + /// `ExecutionPlan` suitable for execution + async fn create_physical_plan( + &self, + logical_plan: &LogicalPlan, + session_state: &SessionState, + ) -> Result> { + // Additional extension for table sample node + let physical_planner = + DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new( + TableSampleExtensionPlanner {}, + )]); + // Delegate most work of physical planning to the default physical planner + physical_planner + .create_physical_plan(logical_plan, session_state) + .await + } +} + +/// Physical plan implementation +trait Sampler: Send + Sync { + fn sample(&mut self, batch: &RecordBatch) -> Result; +} + +struct BernoulliSampler { + lower_bound: f64, + upper_bound: f64, + rng: StdRng, +} + +impl BernoulliSampler { + fn new(lower_bound: f64, upper_bound: f64, seed: u64) -> Self { + Self { + lower_bound, + upper_bound, + rng: StdRng::seed_from_u64(seed), + } + } +} + +impl Sampler for BernoulliSampler { + fn sample(&mut self, batch: &RecordBatch) -> Result { + if self.upper_bound <= self.lower_bound { + return Ok(RecordBatch::new_empty(batch.schema())); + } + + let mut indices = Vec::new(); + + for i in 0..batch.num_rows() { + let rnd: f64 = self.rng.random(); + + if rnd >= self.lower_bound && rnd < self.upper_bound { + indices.push(i as u32); + } + } + + if indices.is_empty() { + return Ok(RecordBatch::new_empty(batch.schema())); + } + let indices = UInt32Array::from(indices); + compute::take_record_batch(batch, &indices).map_err(|e| e.into()) + } +} + +struct PoissonSampler { + ratio: f64, + poisson: Poisson, + rng: StdRng, +} + +impl PoissonSampler { + fn try_new(ratio: f64, seed: u64) -> Result { + let poisson = Poisson::new(ratio).map_err(|e| plan_datafusion_err!("{}", e))?; + Ok(Self { + ratio, + poisson, + rng: StdRng::seed_from_u64(seed), + }) + } +} + +impl Sampler for PoissonSampler { + fn sample(&mut self, batch: &RecordBatch) -> Result { + if self.ratio <= 0.0 { + return Ok(RecordBatch::new_empty(batch.schema())); + } + + let mut indices = Vec::new(); + + for i in 0..batch.num_rows() { + let k = self.poisson.sample(&mut self.rng) as i32; + for _ in 0..k { + indices.push(i as u32); + } + } + + if indices.is_empty() { + return Ok(RecordBatch::new_empty(batch.schema())); + } + + let indices = UInt32Array::from(indices); + compute::take_record_batch(batch, &indices).map_err(|e| e.into()) + } +} + +/// SampleExec samples rows from its input based on a sampling method. +/// This is used to implement SQL `SAMPLE` clause. +#[derive(Debug, Clone)] +pub struct SampleExec { + /// The input plan + input: Arc, + /// The lower bound of the sampling ratio + lower_bound: f64, + /// The upper bound of the sampling ratio + upper_bound: f64, + /// Whether to sample with replacement + with_replacement: bool, + /// Random seed for reproducible sampling + seed: u64, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, + /// Properties equivalence properties, partitioning, etc. 
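+    /// Cached once via `compute_properties` so that `properties()` can hand
+    /// out a cheap reference instead of rebuilding them on every call.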
+ cache: PlanProperties, +} + +impl SampleExec { + /// Create a new SampleExec with a custom sampling method + pub fn try_new( + input: Arc, + lower_bound: f64, + upper_bound: f64, + with_replacement: bool, + seed: u64, + ) -> Result { + if lower_bound < 0.0 || upper_bound > 1.0 || lower_bound > upper_bound { + return internal_err!( + "Sampling bounds must be between 0.0 and 1.0, and lower_bound <= upper_bound, got [{}, {}]", + lower_bound, upper_bound + ); + } + + let cache = Self::compute_properties(&input); + + Ok(Self { + input, + lower_bound, + upper_bound, + with_replacement, + seed, + metrics: ExecutionPlanMetricsSet::new(), + cache, + }) + } + + fn create_sampler(&self, partition: usize) -> Result> { + if self.with_replacement { + Ok(Box::new(PoissonSampler::try_new( + self.upper_bound - self.lower_bound, + self.seed + partition as u64, + )?)) + } else { + Ok(Box::new(BernoulliSampler::new( + self.lower_bound, + self.upper_bound, + self.seed + partition as u64, + ))) + } + } + + /// Whether to sample with replacement + pub fn with_replacement(&self) -> bool { + self.with_replacement + } + + /// The lower bound of the sampling ratio + pub fn lower_bound(&self) -> f64 { + self.lower_bound + } + + /// The upper bound of the sampling ratio + pub fn upper_bound(&self) -> f64 { + self.upper_bound + } + + /// The random seed + pub fn seed(&self) -> u64 { + self.seed + } + + /// The input plan + pub fn input(&self) -> &Arc { + &self.input + } + + /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. + fn compute_properties(input: &Arc) -> PlanProperties { + input + .properties() + .clone() + .with_eq_properties(EquivalenceProperties::new(input.schema())) + } +} + +impl DisplayAs for SampleExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "SampleExec: lower_bound={}, upper_bound={}, with_replacement={}, seed={}", + self.lower_bound, self.upper_bound, self.with_replacement, self.seed + ) + } + DisplayFormatType::TreeRender => { + write!( + f, + "SampleExec: lower_bound={}, upper_bound={}, with_replacement={}, seed={}", + self.lower_bound, self.upper_bound, self.with_replacement, self.seed + ) + } + } + } +} + +impl ExecutionPlan for SampleExec { + fn name(&self) -> &'static str { + "SampleExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn maintains_input_order(&self) -> Vec { + vec![false] // Sampling does not maintain input order + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(SampleExec::try_new( + Arc::clone(&children[0]), + self.lower_bound, + self.upper_bound, + self.with_replacement, + self.seed, + )?)) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let input_stream = self.input.execute(partition, context)?; + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + + Ok(Box::pin(SampleExecStream { + input: input_stream, + sampler: self.create_sampler(partition)?, + baseline_metrics, + })) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn partition_statistics(&self, partition: Option) -> Result { + let input_stats = self.input.partition_statistics(partition)?; + + // Apply sampling ratio to statistics + let mut 
stats = input_stats; + let ratio = self.upper_bound - self.lower_bound; + + stats.num_rows = stats + .num_rows + .map(|nr| (nr as f64 * ratio) as usize) + .to_inexact(); + stats.total_byte_size = stats + .total_byte_size + .map(|tb| (tb as f64 * ratio) as usize) + .to_inexact(); + + Ok(stats) + } +} + +/// Stream for the SampleExec operator +struct SampleExecStream { + /// The input stream + input: SendableRecordBatchStream, + /// The sampling method + sampler: Box, + /// Runtime metrics recording + baseline_metrics: BaselineMetrics, +} + +impl Stream for SampleExecStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let poll = self.input.poll_next_unpin(cx); + let baseline_metrics = &mut self.baseline_metrics; + + match poll { + Poll::Ready(Some(Ok(batch))) => { + let start = baseline_metrics.elapsed_compute().clone(); + let result = self.sampler.sample(&batch); + let _timer = start.timer(); + Poll::Ready(Some(result)) + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +impl RecordBatchStream for SampleExecStream { + fn schema(&self) -> SchemaRef { + self.input.schema() + } +} + +/// Query planner to produce a LogicalPlan from AST with table sampling +struct TableSamplePlanner<'a, S: ContextProvider> { + context_provider: &'a S, +} + +fn evaluate_number< + T: FromStr + Add + Sub + Mul + Div, +>( + expr: &ast::Expr, +) -> Option { + match expr { + ast::Expr::BinaryOp { left, op, right } => { + let left = evaluate_number::(left); + let right = evaluate_number::(right); + match (left, right) { + (Some(left), Some(right)) => match op { + ast::BinaryOperator::Plus => Some(left + right), + ast::BinaryOperator::Minus => Some(left - right), + ast::BinaryOperator::Multiply => Some(left * right), + ast::BinaryOperator::Divide => Some(left / right), + _ => None, + }, + _ => None, + } + } + ast::Expr::Value(value) => match &value.value { + ast::Value::Number(value, _) => { + let value = value.to_string(); + let Ok(value) = value.parse::() else { + return None; + }; + Some(value) + } + _ => None, + }, + _ => None, + } +} + +impl<'a, S: ContextProvider> TableSamplePlanner<'a, S> { + pub fn new(context_provider: &'a S) -> Self { + Self { context_provider } + } + + pub fn new_node( + input: LogicalPlan, + fraction: f64, + with_replacement: Option, + seed: Option, + ) -> Result { + let node = TableSamplePlanNode { + inner_plan: input, + lower_bound: 0.0, + upper_bound: fraction, + with_replacement: with_replacement.unwrap_or(false), + seed: seed.unwrap_or_else(rand::random), + }; + + Ok(LogicalPlan::Extension(Extension { + node: Arc::new(node), + })) + } + + fn sample_to_logical_plan( + &self, + input: LogicalPlan, + sample: TableSampleKind, + ) -> Result { + let sample = match sample { + TableSampleKind::BeforeTableAlias(sample) => sample, + TableSampleKind::AfterTableAlias(sample) => sample, + }; + if let Some(name) = &sample.name { + if *name != TableSampleMethod::Bernoulli && *name != TableSampleMethod::Row { + // Postgres-style sample. Not supported because DataFusion does not have a concept of pages like PostgreSQL. + return not_impl_err!("{} is not supported yet", name); + } + } + if sample.offset.is_some() { + // Clickhouse-style sample. Not supported because it requires knowing the total data size. 
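+            // For example, ClickHouse's `SAMPLE 1/10 OFFSET 1/2` samples a
+            // tenth of the data starting from its midpoint, which only works
+            // when the total data size is known up front.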
+ return not_impl_err!("Offset sample is not supported yet"); + } + + let seed = sample + .seed + .map(|seed| { + let Ok(seed) = seed.value.to_string().parse::() else { + return plan_err!("seed must be a number: {}", seed.value); + }; + Ok(seed) + }) + .transpose()?; + + if let Some(bucket) = sample.bucket { + if bucket.on.is_some() { + // Hive-style sample, only used when the Hive table is defined with CLUSTERED BY + return not_impl_err!("Bucket sample with ON is not supported yet"); + } + + let Ok(bucket_num) = bucket.bucket.to_string().parse::() else { + return plan_err!("bucket must be a number"); + }; + + let Ok(total_num) = bucket.total.to_string().parse::() else { + return plan_err!("total must be a number"); + }; + let value = bucket_num as f64 / total_num as f64; + return Self::new_node(input, value, None, seed); + } + if let Some(quantity) = sample.quantity { + return match quantity.unit { + Some(TableSampleUnit::Rows) => { + let value = evaluate_number::(&quantity.value); + if value.is_none() { + return plan_err!( + "quantity must be a number: {:?}", + quantity.value + ); + } + let value = value.unwrap(); + if value < 0 { + return plan_err!( + "quantity must be a non-negative number: {:?}", + quantity.value + ); + } + LogicalPlanBuilder::from(input) + .limit(0, Some(value as usize))? + .build() + } + Some(TableSampleUnit::Percent) => { + let value = evaluate_number::(&quantity.value); + if value.is_none() { + return plan_err!( + "quantity must be a number: {:?}", + quantity.value + ); + } + let value = value.unwrap() / 100.0; + Self::new_node(input, value, None, seed) + } + None => { + // Clickhouse-style sample + let value = evaluate_number::(&quantity.value); + if value.is_none() { + return plan_err!( + "quantity must be a valid number: {:?}", + quantity.value + ); + } + let value = value.unwrap(); + if value < 0.0 { + return plan_err!( + "quantity must be a non-negative number: {:?}", + quantity.value + ); + } + if value >= 1.0 { + // If value is larger than 1, it is a row limit + LogicalPlanBuilder::from(input) + .limit(0, Some(value as usize))? 
+ .build() + } else { + // If value is between 0.0 and 1.0, it is a fraction + Self::new_node(input, value, None, seed) + } + } + }; + } + plan_err!("Cannot plan sample SQL") + } + + fn create_logical_plan(&self, statement: Statement) -> Result { + let sql_to_rel = SqlToRel::new(self.context_provider); + + let stmt = statement.clone(); + match &stmt { + Statement::Query(query) => { + let inner_plan = sql_to_rel.sql_statement_to_plan(statement)?; + if let SetExpr::Select(select) = &*query.body { + if select.from.len() == 1 { + let table_with_joins = select.from.first().unwrap(); + if let TableFactor::Table { sample: Some(table_sample_kind), ..} = + &table_with_joins.relation + { + debug!("Constructing table sample plan from {:?}", &select); + return self.sample_to_logical_plan( + inner_plan, + table_sample_kind.clone(), + ); + } + } + } + // Pass-through by default + Ok(inner_plan) + } + _ => sql_to_rel.sql_statement_to_plan(statement), + } + } +} + +// Context provider for tests + +struct MockContextProvider<'a> { + state: &'a SessionState, + tables: HashMap>, +} + +impl ContextProvider for MockContextProvider<'_> { + fn get_table_source(&self, name: TableReference) -> Result> { + self.tables + .get(&name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("table '{name}' not found")) + } + + fn get_expr_planners(&self) -> &[Arc] { + self.state.expr_planners() + } + + fn get_type_planner(&self) -> Option> { + None + } + + fn get_function_meta(&self, name: &str) -> Option> { + self.state.scalar_functions().get(name).cloned() + } + + fn get_aggregate_meta(&self, name: &str) -> Option> { + self.state.aggregate_functions().get(name).cloned() + } + + fn get_window_meta(&self, name: &str) -> Option> { + self.state.window_functions().get(name).cloned() + } + + fn get_variable_type(&self, variable_names: &[String]) -> Option { + self.state + .execution_props() + .var_providers + .as_ref() + .and_then(|provider| provider.get(&VarType::System)?.get_type(variable_names)) + } + + fn options(&self) -> &ConfigOptions { + self.state.config_options() + } + + fn udf_names(&self) -> Vec { + self.state.scalar_functions().keys().cloned().collect() + } + + fn udaf_names(&self) -> Vec { + self.state.aggregate_functions().keys().cloned().collect() + } + + fn udwf_names(&self) -> Vec { + self.state.window_functions().keys().cloned().collect() + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let _ = env_logger::try_init(); + + let state = SessionStateBuilder::new() + .with_default_features() + .with_query_planner(Arc::new(TableSampleQueryPlanner {})) + .build(); + + let ctx = SessionContext::new_with_state(state.clone()); + + let testdata = datafusion::test_util::parquet_test_data(); + ctx.register_parquet( + "alltypes_plain", + &format!("{testdata}/alltypes_plain.parquet"), + ParquetReadOptions::default(), + ) + .await?; + + let table_source = provider_as_source(ctx.table_provider("alltypes_plain").await?); + let context_provider = MockContextProvider { + state: &state, + tables: HashMap::>::from([( + "alltypes_plain".into(), + table_source.clone(), + )]), + }; + + let sql = + "SELECT int_col, double_col FROM alltypes_plain TABLESAMPLE 42 PERCENT REPEATABLE(5) WHERE int_col = 1"; + + let dialect = PostgreSqlDialect {}; + let statements = Parser::parse_sql(&dialect, sql)?; + let statement = statements.first().expect("one statement"); + + // Classical way + // let sql_to_rel = SqlToRel::new(&context_provider); + // let logical_plan = sql_to_rel.sql_statement_to_plan(statement.clone())?; + + // Use 
sampling planner to create a logical plan + let table_sample_planner = TableSamplePlanner::new(&context_provider); + let logical_plan = table_sample_planner.create_logical_plan(statement.clone())?; + + let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?; + + // Inspect physical plan + let displayable_plan = displayable(physical_plan.as_ref()) + .indent(false) + .to_string(); + info!("Physical plan:\n{displayable_plan}\n"); + let first_line = displayable_plan.lines().next().unwrap(); + assert_eq!( + first_line, + "SampleExec: lower_bound=0, upper_bound=0.42, with_replacement=false, seed=5" + ); + + // Execute via standard sql call - doesn't work + // let df = ctx.sql(sql).await?; + // let batches = df.collect().await?; + + // Execute directly via physical plan + let task_context = Arc::new(TaskContext::from(&ctx)); + let stream = physical_plan.execute(0, task_context)?; + let batches: Vec<_> = stream.try_collect().await?; + + info!("Batches: {:?}", &batches); + + let result_string = pretty_format_batches(&batches) + // pretty_format_batches_with_schema(table_source.schema(), &batches) + .map_err(|e| arrow_datafusion_err!(e)) + .map(|d| d.to_string())?; + let result_strings = result_string.lines().collect::>(); + info!("Batch result: {:?}", &result_strings); + + assert_eq!(batches.len(), 1); + assert_eq!(batches.first().unwrap().num_rows(), 2); + + info!("Done"); + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Int32Array; + use arrow::datatypes::{Field, Schema}; + use datafusion::assert_batches_eq; + use datafusion::physical_plan::test::TestMemoryExec; + use futures::TryStreamExt; + use std::sync::Arc; + + #[cfg(test)] + #[ctor::ctor] + fn init() { + // Enable RUST_LOG logging configuration for test + let _ = env_logger::try_init(); + } + + async fn parse_to_logical_plan(sql: &str) -> Result<(SessionContext, LogicalPlan)> { + let state = SessionStateBuilder::new() + .with_default_features() + .with_query_planner(Arc::new(TableSampleQueryPlanner {})) + .build(); + + let ctx = SessionContext::new_with_state(state.clone()); + + let testdata = datafusion::test_util::parquet_test_data(); + ctx.register_parquet( + "alltypes_plain", + &format!("{testdata}/alltypes_plain.parquet"), + ParquetReadOptions::default(), + ) + .await?; + + let table_source = + provider_as_source(ctx.table_provider("alltypes_plain").await?); + let context_provider = MockContextProvider { + state: &state, + tables: HashMap::>::from([( + "alltypes_plain".into(), + table_source.clone(), + )]), + }; + + let dialect = PostgreSqlDialect {}; + let statements = Parser::parse_sql(&dialect, sql)?; + let statement = statements.get(0).expect("one statement"); + + // Use sampling + let table_sample_planner = TableSamplePlanner::new(&context_provider); + let logical_plan = table_sample_planner.create_logical_plan(statement.clone())?; + + Ok((ctx, logical_plan)) + } + + fn physical_plan_first_line(physical_plan: Arc) -> String { + let displayable_plan = displayable(physical_plan.as_ref()) + .indent(false) + .to_string(); + info!("Physical plan:\n{}\n", displayable_plan); + let first_line = displayable_plan.lines().next().expect("empty plan"); + first_line.into() + } + + fn as_table_sample_node(logical_plan: LogicalPlan) -> Result { + match logical_plan { + LogicalPlan::Extension(Extension { node }) => { + if let Some(plan) = node.as_any().downcast_ref::() { + Ok(plan.clone()) + } else { + plan_err!("Wrong extension node") + } + } + _ => plan_err!("Not an extension node"), + } + } + + 
#[tokio::test] + async fn test_logical_plan_sample() -> Result<()> { + let sql = + "SELECT int_col, double_col FROM alltypes_plain TABLESAMPLE 0.42 where int_col = 1"; + let (_, logical_plan) = parse_to_logical_plan(sql).await?; + let node = as_table_sample_node(logical_plan)?; + assert_eq!(node.lower_bound, 0.0); + assert_eq!(node.upper_bound, 0.42); + assert_eq!(node.with_replacement, false); + Ok(()) + } + + #[tokio::test] + async fn test_logical_plan_sample_repeatable() -> Result<()> { + let sql = + "SELECT int_col, double_col FROM alltypes_plain TABLESAMPLE 0.42 REPEATABLE(123) where int_col = 1"; + let (_, logical_plan) = parse_to_logical_plan(sql).await?; + let node = as_table_sample_node(logical_plan)?; + assert_eq!(node.lower_bound, 0.0); + assert_eq!(node.upper_bound, 0.42); + assert_eq!(node.with_replacement, false); + assert_eq!(node.seed, 123); + Ok(()) + } + + #[tokio::test] + async fn test_logical_plan_sample_percent() -> Result<()> { + let sql = + "SELECT int_col, double_col FROM alltypes_plain TABLESAMPLE 42 PERCENT where int_col = 1"; + let (_, logical_plan) = parse_to_logical_plan(sql).await?; + let node = as_table_sample_node(logical_plan)?; + assert_eq!(node.lower_bound, 0.0); + assert_eq!(node.upper_bound, 0.42); + assert_eq!(node.with_replacement, false); + Ok(()) + } + + #[tokio::test] + async fn test_logical_plan_sample_rows() -> Result<()> { + let sql = + "SELECT int_col, double_col FROM alltypes_plain TABLESAMPLE 42 ROWS where int_col = 1"; + let (_, logical_plan) = parse_to_logical_plan(sql).await?; + if let LogicalPlan::Limit(limit) = logical_plan { + assert_eq!(limit.fetch, Some(Box::new(lit(42_i64)))); + assert_eq!(limit.skip, None); + } else { + assert!(false, "Expected LogicalPlan::Limit"); + } + Ok(()) + } + + #[tokio::test] + async fn test_logical_plan_method_system_unsupported() -> Result<()> { + let sql = + "SELECT int_col, double_col FROM alltypes_plain TABLESAMPLE SYSTEM 42 where int_col = 1"; + let err = parse_to_logical_plan(sql).await.err().expect("should fail"); + assert!(err + .to_string() + .contains("SYSTEM is not supported yet")); + Ok(()) + } + + #[tokio::test] + async fn test_logical_plan_method_block_unsupported() -> Result<()> { + let sql = + "SELECT int_col, double_col FROM alltypes_plain TABLESAMPLE BLOCK 42 where int_col = 1"; + let err = parse_to_logical_plan(sql).await.err().expect("should fail"); + assert!(err + .to_string() + .contains("BLOCK is not supported yet")); + Ok(()) + } + + #[tokio::test] + async fn test_logical_plan_method_offset_unsupported() -> Result<()> { + let sql = + "SELECT int_col, double_col FROM alltypes_plain SAMPLE 42 OFFSET 5 where int_col = 1"; + let err = parse_to_logical_plan(sql).await.err().expect("should fail"); + assert!(err + .to_string() + .contains("Offset sample is not supported yet")); + Ok(()) + } + + #[tokio::test] + async fn test_logical_plan_sample_clickhouse() -> Result<()> { + let sql = + "SELECT int_col, double_col FROM alltypes_plain SAMPLE 0.42 where int_col = 1"; + let (_, logical_plan) = parse_to_logical_plan(sql).await?; + let node = as_table_sample_node(logical_plan)?; + assert_eq!(node.lower_bound, 0.0); + assert_eq!(node.upper_bound, 0.42); + assert_eq!(node.with_replacement, false); + Ok(()) + } + + #[tokio::test] + async fn test_logical_plan_limit_clickhouse() -> Result<()> { + let sql = + "SELECT int_col, double_col FROM alltypes_plain SAMPLE 42 where int_col = 1"; + let (_, logical_plan) = parse_to_logical_plan(sql).await?; + if let LogicalPlan::Limit(limit) = logical_plan { + 
assert_eq!(limit.fetch, Some(Box::new(lit(42_i64)))); + assert_eq!(limit.skip, None); + } else { + assert!(false, "Expected LogicalPlan::Limit"); + } + Ok(()) + } + + #[tokio::test] + async fn test_logical_plan_bucket() -> Result<()> { + let sql = + "SELECT int_col, double_col FROM alltypes_plain TABLESAMPLE (BUCKET 3 OUT OF 16) where int_col = 1"; + let (_, logical_plan) = parse_to_logical_plan(sql).await?; + let node = as_table_sample_node(logical_plan)?; + assert_eq!(node.lower_bound, 0.0); + assert!((node.upper_bound - 3.0 / 16.0).abs() < f64::EPSILON); + assert_eq!(node.with_replacement, false); + Ok(()) + } + + #[tokio::test] + async fn test_logical_plan_bucket_on_unsupported() -> Result<()> { + let sql = + "SELECT int_col, double_col FROM alltypes_plain TABLESAMPLE (BUCKET 3 OUT OF 16 ON int_col) where int_col = 1"; + let err = parse_to_logical_plan(sql).await.err().expect("should fail"); + assert!(err + .to_string() + .contains("Bucket sample with ON is not supported yet")); + Ok(()) + } + + #[tokio::test] + async fn test_logical_plan_negative_value() -> Result<()> { + let sql = + "SELECT int_col, double_col FROM alltypes_plain TABLESAMPLE 1-5 where int_col = 1"; + let err = parse_to_logical_plan(sql).await.err().expect("should fail"); + assert!( + err.to_string() + .contains("quantity must be a non-negative number"), + "{err:?}" + ); + Ok(()) + } + + #[tokio::test] + async fn test_logical_plan_not_a_number() -> Result<()> { + let sql = + "SELECT int_col, double_col FROM alltypes_plain TABLESAMPLE 42 + int_col where int_col = 1"; + let err = parse_to_logical_plan(sql).await.err().expect("should fail"); + assert!( + err.to_string().contains("quantity must be a valid number"), + "{err:?}" + ); + Ok(()) + } + + #[tokio::test] + async fn test_physical_plan_sample() -> Result<()> { + let sql = + "SELECT int_col, double_col FROM alltypes_plain TABLESAMPLE 0.42 where int_col = 1"; + let (ctx, logical_plan) = parse_to_logical_plan(sql).await?; + let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?; + let physical_plan_repr = physical_plan_first_line(physical_plan); + assert!(physical_plan_repr.starts_with( + "SampleExec: lower_bound=0, upper_bound=0.42, with_replacement=false, seed=" + ), "{physical_plan_repr}"); + Ok(()) + } + + #[tokio::test] + async fn test_physical_plan_sample_repeateable() -> Result<()> { + let sql = + "SELECT int_col, double_col FROM alltypes_plain TABLESAMPLE 0.42 REPEATABLE(123) where int_col = 1"; + let (ctx, logical_plan) = parse_to_logical_plan(sql).await?; + let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?; + let physical_plan_repr = physical_plan_first_line(physical_plan); + assert_eq!( + physical_plan_repr, + "SampleExec: lower_bound=0, upper_bound=0.42, with_replacement=false, seed=123", + "{physical_plan_repr}" + ); + Ok(()) + } + + //noinspection RsTypeCheck + #[tokio::test] + async fn test_execute() -> Result<()> { + let sql = + "SELECT int_col, double_col FROM alltypes_plain TABLESAMPLE 0.42 REPEATABLE(5) where int_col = 1"; + + let (ctx, logical_plan) = parse_to_logical_plan(sql).await?; + + let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?; + + // Execute via physical plan + let task_context = Arc::new(TaskContext::from(&ctx)); + let stream = physical_plan.execute(0, task_context)?; + let batches: Vec<_> = stream.try_collect().await?; + + // Observed with a specific repeatable seed + assert_batches_eq!( + #[rustfmt::skip] + &[ + "+---------+------------+", + "| int_col | 
double_col |", + "+---------+------------+", + "| 1 | 10.1 |", + "| 1 | 10.1 |", + "+---------+------------+" + ], + &batches + ); + Ok(()) + } + + #[tokio::test] + async fn test_sample_exec_bernoulli() -> Result<()> { + let schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], + )?; + + let input = Arc::new(TestMemoryExec::try_new( + &[vec![batch]], + Arc::clone(&schema), + None, + )?); + + let sample_exec = SampleExec::try_new(input, 0.6, 1.0, false, 42)?; + + let context = Arc::new(TaskContext::default()); + let stream = sample_exec.execute(0, context)?; + + let batches = stream.try_collect::>().await?; + assert_batches_eq!( + &["+----+", "| id |", "+----+", "| 3 |", "+----+",], + &batches + ); + + Ok(()) + } + + //noinspection RsTypeCheck + #[tokio::test] + async fn test_sample_exec_poisson() -> Result<()> { + let schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], + )?; + + let input = Arc::new(TestMemoryExec::try_new( + &[vec![batch]], + Arc::clone(&schema), + None, + )?); + + let sample_exec = SampleExec::try_new(input, 0.0, 0.5, true, 42)?; + + let context = Arc::new(TaskContext::default()); + let stream = sample_exec.execute(0, context)?; + + let batches = stream.try_collect::>().await?; + assert_batches_eq!( + #[rustfmt::skip] + &[ + "+----+", + "| id |", + "+----+", + "| 3 |", + "+----+", + ], + &batches + ); + + Ok(()) + } + + //noinspection RsTypeCheck + #[test] + fn test_sampler_trait() { + let mut bernoulli_sampler = BernoulliSampler::new(0.0, 0.5, 42); + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], + ) + .unwrap(); + + let result = bernoulli_sampler.sample(&batch).unwrap(); + assert_batches_eq!( + #[rustfmt::skip] + &[ + "+----+", + "| id |", + "+----+", + "| 4 |", + "| 5 |", + "+----+", + ], + &[result] + ); + + let mut poisson_sampler = PoissonSampler::try_new(0.5, 42).unwrap(); + let result = poisson_sampler.sample(&batch).unwrap(); + assert_batches_eq!( + #[rustfmt::skip] + &[ + "+----+", + "| id |", + "+----+", + "| 3 |", + "+----+", + ], + &[result] + ); + } +} From edceba3e4c4b03bda9cdf2c885fbc626311bd005 Mon Sep 17 00:00:00 2001 From: theirix Date: Wed, 17 Sep 2025 22:16:50 +0100 Subject: [PATCH 03/17] Format --- datafusion-examples/examples/table_sample.rs | 22 +++++++++----------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/datafusion-examples/examples/table_sample.rs b/datafusion-examples/examples/table_sample.rs index 42743186df10..295c6f0ee5d9 100644 --- a/datafusion-examples/examples/table_sample.rs +++ b/datafusion-examples/examples/table_sample.rs @@ -617,7 +617,7 @@ struct TableSamplePlanner<'a, S: ContextProvider> { } fn evaluate_number< - T: FromStr + Add + Sub + Mul + Div, + T: FromStr + Add + Sub + Mul + Div, >( expr: &ast::Expr, ) -> Option { @@ -793,8 +793,10 @@ impl<'a, S: ContextProvider> TableSamplePlanner<'a, S> { if let SetExpr::Select(select) = &*query.body { if select.from.len() == 1 { let table_with_joins = select.from.first().unwrap(); - if let TableFactor::Table { sample: Some(table_sample_kind), ..} = - &table_with_joins.relation + if let TableFactor::Table { + sample: Some(table_sample_kind), + .. 
+ } = &table_with_joins.relation { debug!("Constructing table sample plan from {:?}", &select); return self.sample_to_logical_plan( @@ -889,7 +891,7 @@ async fn main() -> Result<()> { &format!("{testdata}/alltypes_plain.parquet"), ParquetReadOptions::default(), ) - .await?; + .await?; let table_source = provider_as_source(ctx.table_provider("alltypes_plain").await?); let context_provider = MockContextProvider { @@ -984,7 +986,7 @@ mod tests { &format!("{testdata}/alltypes_plain.parquet"), ParquetReadOptions::default(), ) - .await?; + .await?; let table_source = provider_as_source(ctx.table_provider("alltypes_plain").await?); @@ -1085,9 +1087,7 @@ mod tests { let sql = "SELECT int_col, double_col FROM alltypes_plain TABLESAMPLE SYSTEM 42 where int_col = 1"; let err = parse_to_logical_plan(sql).await.err().expect("should fail"); - assert!(err - .to_string() - .contains("SYSTEM is not supported yet")); + assert!(err.to_string().contains("SYSTEM is not supported yet")); Ok(()) } @@ -1096,9 +1096,7 @@ mod tests { let sql = "SELECT int_col, double_col FROM alltypes_plain TABLESAMPLE BLOCK 42 where int_col = 1"; let err = parse_to_logical_plan(sql).await.err().expect("should fail"); - assert!(err - .to_string() - .contains("BLOCK is not supported yet")); + assert!(err.to_string().contains("BLOCK is not supported yet")); Ok(()) } @@ -1322,7 +1320,7 @@ mod tests { Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])), vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], ) - .unwrap(); + .unwrap(); let result = bernoulli_sampler.sample(&batch).unwrap(); assert_batches_eq!( From 2a6f0d681a9f9b4718dfe58f4815cf33d25f2282 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 22 Sep 2025 10:47:43 -0400 Subject: [PATCH 04/17] Use newtype to get hash/eq for f64 --- datafusion-examples/examples/table_sample.rs | 103 ++++++++----------- 1 file changed, 41 insertions(+), 62 deletions(-) diff --git a/datafusion-examples/examples/table_sample.rs b/datafusion-examples/examples/table_sample.rs index 295c6f0ee5d9..17c93ea1953a 100644 --- a/datafusion-examples/examples/table_sample.rs +++ b/datafusion-examples/examples/table_sample.rs @@ -108,78 +108,57 @@ use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use rand_distr::{Distribution, Poisson}; -/// This example demonstrates the table sample support. -#[derive(Debug, Clone)] -struct TableSamplePlanNode { - inner_plan: LogicalPlan, - lower_bound: f64, - upper_bound: f64, - with_replacement: bool, - seed: u64, +/// This example demonstrates the table sample support. 
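+///
+/// `f64` implements neither `Eq` nor `Hash`, yet `UserDefinedLogicalNodeCore`
+/// requires the node type to be `Hash + Eq + PartialOrd`; wrapping the bounds
+/// in the `Bound` newtype below provides those impls in one place, so the plan
+/// node itself can simply derive them.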
+ +/// Hashable and comparible f64 for sampling bounds +#[derive(Debug, Clone, Copy, PartialOrd)] +struct Bound(f64); +impl PartialEq for Bound { + fn eq(&self, other: &Self) -> bool { + (self.0 - other.0).abs() < f64::EPSILON + } } -impl Hash for TableSamplePlanNode { +impl Eq for Bound {} + +impl Hash for Bound { fn hash(&self, state: &mut H) { - self.inner_plan.hash(state); - self.lower_bound.to_bits().hash(state); - self.upper_bound.to_bits().hash(state); - self.with_replacement.hash(state); - self.seed.hash(state); + // Hash the bits of the f64 + self.0.to_bits().hash(state); } } -impl PartialEq for TableSamplePlanNode { - fn eq(&self, other: &Self) -> bool { - self.inner_plan == other.inner_plan - && (self.lower_bound - other.lower_bound).abs() < f64::EPSILON - && (self.upper_bound - other.upper_bound).abs() < f64::EPSILON - && self.with_replacement == other.with_replacement - && self.seed == other.seed +impl From for Bound { + fn from(value: f64) -> Self { + Self(value) + } +} +impl From for f64 { + fn from(value: Bound) -> Self { + value.0 } } -impl Eq for TableSamplePlanNode {} - -impl PartialOrd for TableSamplePlanNode { - fn partial_cmp(&self, other: &Self) -> Option { - self.inner_plan - .partial_cmp(&other.inner_plan) - .and_then(|ord| { - if ord != Ordering::Equal { - Some(ord) - } else { - self.lower_bound - .partial_cmp(&other.lower_bound) - .and_then(|ord| { - if ord != Ordering::Equal { - Some(ord) - } else { - self.upper_bound.partial_cmp(&other.upper_bound).and_then( - |ord| { - if ord != Ordering::Equal { - Some(ord) - } else { - self.with_replacement - .partial_cmp(&other.with_replacement) - .and_then(|ord| { - if ord != Ordering::Equal { - Some(ord) - } else { - self.seed.partial_cmp(&other.seed) - } - }) - } - }, - ) - } - }) - } - }) +impl AsRef for Bound { + fn as_ref(&self) -> &f64 { + &self.0 } } + +#[derive(Debug, Clone, Hash, Eq, PartialEq, PartialOrd)] +struct TableSamplePlanNode { + inner_plan: LogicalPlan, + + lower_bound: Bound, + upper_bound: Bound, + with_replacement: bool, + seed: u64, +} + + impl UserDefinedLogicalNodeCore for TableSamplePlanNode { fn name(&self) -> &str { "TableSample" @@ -234,8 +213,8 @@ impl TableSampleExtensionPlanner { Ok(Arc::new(SampleExec { input: physical_input.clone(), lower_bound: 0.0, - upper_bound: specific_node.upper_bound, - with_replacement: specific_node.with_replacement, + upper_bound: specific_node.upper_bound.into(), + with_replacement: specific_node.with_replacement.into(), seed: specific_node.seed, metrics: Default::default(), cache: SampleExec::compute_properties(&physical_input), @@ -663,8 +642,8 @@ impl<'a, S: ContextProvider> TableSamplePlanner<'a, S> { ) -> Result { let node = TableSamplePlanNode { inner_plan: input, - lower_bound: 0.0, - upper_bound: fraction, + lower_bound: Bound::from(0.0), + upper_bound: Bound::from(fraction), with_replacement: with_replacement.unwrap_or(false), seed: seed.unwrap_or_else(rand::random), }; From 66a68ec451021289ec0eada5fe2f2884c7371c1a Mon Sep 17 00:00:00 2001 From: theirix Date: Tue, 23 Sep 2025 21:26:24 +0300 Subject: [PATCH 05/17] Apply docs suggestions from code review Co-authored-by: Andrew Lamb --- datafusion-examples/examples/table_sample.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion-examples/examples/table_sample.rs b/datafusion-examples/examples/table_sample.rs index 17c93ea1953a..3706232f2f96 100644 --- a/datafusion-examples/examples/table_sample.rs +++ b/datafusion-examples/examples/table_sample.rs @@ -590,7 +590,7 @@ impl 
RecordBatchStream for SampleExecStream { } } -/// Query planner to produce a LogicalPlan from AST with table sampling +/// Custom SQL planner that implements support for TABLESAMPLE struct TableSamplePlanner<'a, S: ContextProvider> { context_provider: &'a S, } From ef4752d9540e934950ad6e791051fcb3ea8e7e79 Mon Sep 17 00:00:00 2001 From: theirix Date: Tue, 23 Sep 2025 20:43:54 +0100 Subject: [PATCH 06/17] Optimise imports Signed-off-by: theirix --- datafusion-examples/examples/table_sample.rs | 55 ++++++-------------- 1 file changed, 15 insertions(+), 40 deletions(-) diff --git a/datafusion-examples/examples/table_sample.rs b/datafusion-examples/examples/table_sample.rs index 3706232f2f96..18fa49e42adb 100644 --- a/datafusion-examples/examples/table_sample.rs +++ b/datafusion-examples/examples/table_sample.rs @@ -15,55 +15,42 @@ // specific language governing permissions and limitations // under the License. -#![allow(unused_imports)] - use datafusion::common::{ arrow_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, - DFSchema, DFSchemaRef, ResolvedTableReference, Statistics, TableReference, + DFSchemaRef, Statistics, TableReference, }; use datafusion::error::Result; use datafusion::logical_expr::sqlparser::ast::{ - Query, SetExpr, Statement, TableFactor, TableSample, TableSampleMethod, - TableSampleQuantity, TableSampleUnit, + SetExpr, Statement, TableFactor, TableSampleMethod, + TableSampleUnit, }; use datafusion::logical_expr::{ - AggregateUDF, Extension, Filter, LogicalPlan, LogicalPlanBuilder, Projection, + AggregateUDF, Extension, LogicalPlan, LogicalPlanBuilder, ScalarUDF, TableSource, UserDefinedLogicalNode, UserDefinedLogicalNodeCore, WindowUDF, }; use std::any::Any; -use std::cmp::Ordering; -use std::collections::{BTreeMap, HashMap}; +use std::collections::{HashMap}; -use arrow::util::pretty::{pretty_format_batches, pretty_format_batches_with_schema}; -use datafusion::common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRewriter, -}; +use arrow::util::pretty::{pretty_format_batches}; -use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef, TimeUnit}; +use arrow_schema::{DataType, SchemaRef}; use async_trait::async_trait; -use datafusion::catalog::cte_worktable::CteWorkTable; -use datafusion::common::file_options::file_type::FileType; use datafusion::config::ConfigOptions; -use datafusion::datasource::file_format::format_as_file_type; -use datafusion::datasource::{provider_as_source, DefaultTableSource, TableProvider}; +use datafusion::datasource::{provider_as_source}; use datafusion::error::DataFusionError; use datafusion::execution::{ - FunctionRegistry, SendableRecordBatchStream, SessionState, SessionStateBuilder, + SendableRecordBatchStream, SessionState, SessionStateBuilder, TaskContext, }; use datafusion::logical_expr::planner::{ - ContextProvider, ExprPlanner, PlannerResult, RawBinaryExpr, TypePlanner, + ContextProvider, ExprPlanner, TypePlanner, }; use datafusion::logical_expr::sqlparser::dialect::PostgreSqlDialect; use datafusion::logical_expr::sqlparser::parser::Parser; -use datafusion::logical_expr::var_provider::is_system_variables; -use datafusion::optimizer::simplify_expressions::ExprSimplifier; -use datafusion::optimizer::AnalyzerRule; use datafusion::physical_expr::EquivalenceProperties; -use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion::physical_plan::metrics::{ - BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, + BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, 
RecordOutput, }; use datafusion::physical_plan::{ displayable, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, @@ -73,18 +60,8 @@ use datafusion::physical_planner::{ DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner, }; use datafusion::prelude::*; -use datafusion::sql::planner::{ParserOptions, PlannerContext, SqlToRel}; -use datafusion::sql::sqlparser::ast::{TableSampleKind, TableSampleModifier}; -use datafusion::sql::unparser::ast::{ - DerivedRelationBuilder, QueryBuilder, RelationBuilder, SelectBuilder, -}; -use datafusion::sql::unparser::dialect::CustomDialectBuilder; -use datafusion::sql::unparser::expr_to_sql; -use datafusion::sql::unparser::extension_unparser::UserDefinedLogicalNodeUnparser; -use datafusion::sql::unparser::extension_unparser::{ - UnparseToStatementResult, UnparseWithinStatementResult, -}; -use datafusion::sql::unparser::{plan_to_sql, Unparser}; +use datafusion::sql::planner::{SqlToRel}; +use datafusion::sql::sqlparser::ast::{TableSampleKind}; use datafusion::variable::VarType; use log::{debug, info}; use std::fmt; @@ -100,10 +77,9 @@ use arrow::compute; use arrow::record_batch::RecordBatch; use datafusion::execution::context::QueryPlanner; -use datafusion::sql::sqlparser; use datafusion::sql::sqlparser::ast; use futures::stream::{Stream, StreamExt}; -use futures::TryStreamExt; +use futures::{ready, TryStreamExt}; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use rand_distr::{Distribution, Poisson}; @@ -115,6 +91,7 @@ use rand_distr::{Distribution, Poisson}; /// Hashable and comparible f64 for sampling bounds #[derive(Debug, Clone, Copy, PartialOrd)] struct Bound(f64); + impl PartialEq for Bound { fn eq(&self, other: &Self) -> bool { (self.0 - other.0).abs() < f64::EPSILON @@ -147,7 +124,6 @@ impl AsRef for Bound { } } - #[derive(Debug, Clone, Hash, Eq, PartialEq, PartialOrd)] struct TableSamplePlanNode { inner_plan: LogicalPlan, @@ -158,7 +134,6 @@ struct TableSamplePlanNode { seed: u64, } - impl UserDefinedLogicalNodeCore for TableSamplePlanNode { fn name(&self) -> &str { "TableSample" From c98add8c93faf5a52a374bcba0066b59193d93ce Mon Sep 17 00:00:00 2001 From: theirix Date: Tue, 23 Sep 2025 20:44:25 +0100 Subject: [PATCH 07/17] Move main Signed-off-by: theirix --- datafusion-examples/examples/table_sample.rs | 167 ++++++++++--------- 1 file changed, 85 insertions(+), 82 deletions(-) diff --git a/datafusion-examples/examples/table_sample.rs b/datafusion-examples/examples/table_sample.rs index 18fa49e42adb..eeb1f225c732 100644 --- a/datafusion-examples/examples/table_sample.rs +++ b/datafusion-examples/examples/table_sample.rs @@ -84,9 +84,93 @@ use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use rand_distr::{Distribution, Poisson}; +/// This example demonstrates how to extend DataFusion's SQL parser to recognize +/// other syntax. +/// This example shows how to extend the DataFusion SQL planner to support the +/// `TABLESAMPLE` clause in SQL queries and then use a custom user defined node +/// to implement the sampling logic in the physical plan. -/// This example demonstrates the table sample support. 
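+
+// `main` drives the pipeline end to end: parse SQL containing a TABLESAMPLE
+// clause, plan it with `TableSamplePlanner`, lower the result to a physical
+// plan rooted at `SampleExec`, and execute that plan directly. With the fixed
+// `REPEATABLE(5)` seed the plan renders as:
+//
+//   SampleExec: lower_bound=0, upper_bound=0.42, with_replacement=false, seed=5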
+#[tokio::main]
+async fn main() -> Result<()> {
+    let _ = env_logger::try_init();
+
+    let state = SessionStateBuilder::new()
+        .with_default_features()
+        .with_query_planner(Arc::new(TableSampleQueryPlanner {}))
+        .build();
+
+    let ctx = SessionContext::new_with_state(state.clone());
+
+    let testdata = datafusion::test_util::parquet_test_data();
+    ctx.register_parquet(
+        "alltypes_plain",
+        &format!("{testdata}/alltypes_plain.parquet"),
+        ParquetReadOptions::default(),
+    )
+    .await?;
+
+    let table_source = provider_as_source(ctx.table_provider("alltypes_plain").await?);
+    let context_provider = MockContextProvider {
+        state: &state,
+        tables: HashMap::<TableReference, Arc<dyn TableSource>>::from([(
+            "alltypes_plain".into(),
+            table_source.clone(),
+        )]),
+    };
+
+    let sql =
+        "SELECT int_col, double_col FROM alltypes_plain TABLESAMPLE 42 PERCENT REPEATABLE(5) WHERE int_col = 1";
+
+    let dialect = PostgreSqlDialect {};
+    let statements = Parser::parse_sql(&dialect, sql)?;
+    let statement = statements.first().expect("one statement");
+
+    // Classical way
+    // let sql_to_rel = SqlToRel::new(&context_provider);
+    // let logical_plan = sql_to_rel.sql_statement_to_plan(statement.clone())?;
+
+    // Use sampling planner to create a logical plan
+    let table_sample_planner = TableSamplePlanner::new(&context_provider);
+    let logical_plan = table_sample_planner.create_logical_plan(statement.clone())?;
+
+    let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;
+
+    // Inspect physical plan
+    let displayable_plan = displayable(physical_plan.as_ref())
+        .indent(false)
+        .to_string();
+    info!("Physical plan:\n{displayable_plan}\n");
+    let first_line = displayable_plan.lines().next().unwrap();
+    assert_eq!(
+        first_line,
+        "SampleExec: lower_bound=0, upper_bound=0.42, with_replacement=false, seed=5"
+    );
+
+    // Execute via standard sql call - doesn't work
+    // let df = ctx.sql(sql).await?;
+    // let batches = df.collect().await?;
+
+    // Execute directly via physical plan
+    let task_context = Arc::new(TaskContext::from(&ctx));
+    let stream = physical_plan.execute(0, task_context)?;
+    let batches: Vec<_> = stream.try_collect().await?;
+
+    info!("Batches: {:?}", &batches);
+
+    let result_string = pretty_format_batches(&batches)
+        // pretty_format_batches_with_schema(table_source.schema(), &batches)
+        .map_err(|e| arrow_datafusion_err!(e))
+        .map(|d| d.to_string())?;
+    let result_strings = result_string.lines().collect::<Vec<_>>();
+    info!("Batch result: {:?}", &result_strings);
+
+    assert_eq!(batches.len(), 1);
+    assert_eq!(batches.first().unwrap().num_rows(), 2);
+
+    info!("Done");
+    Ok(())
+}
 
 /// Hashable and comparible f64 for sampling bounds
 #[derive(Debug, Clone, Copy, PartialOrd)]
 struct Bound(f64);
@@ -828,87 +912,6 @@ impl ContextProvider for MockContextProvider<'_> {
     }
 }
 
-#[tokio::main]
-async fn main() -> Result<()> {
-    let _ = env_logger::try_init();
-
-    let state = SessionStateBuilder::new()
-        .with_default_features()
-        .with_query_planner(Arc::new(TableSampleQueryPlanner {}))
-        .build();
-
-    let ctx = SessionContext::new_with_state(state.clone());
-
-    let testdata = datafusion::test_util::parquet_test_data();
-    ctx.register_parquet(
-        "alltypes_plain",
-        &format!("{testdata}/alltypes_plain.parquet"),
-        ParquetReadOptions::default(),
-    )
-    .await?;
-
-    let table_source = provider_as_source(ctx.table_provider("alltypes_plain").await?);
-    let context_provider = MockContextProvider {
-        state: &state,
-        tables: HashMap::<TableReference, Arc<dyn TableSource>>::from([(
-            "alltypes_plain".into(),
-            table_source.clone(),
-        )]),
-    };
-
-    let sql =
-        "SELECT int_col, double_col FROM alltypes_plain TABLESAMPLE 42 PERCENT REPEATABLE(5) WHERE int_col = 1";
-
-    let dialect = PostgreSqlDialect {};
-    let statements = Parser::parse_sql(&dialect, sql)?;
-    let statement = statements.first().expect("one statement");
-
-    // Classical way
-    // let sql_to_rel = SqlToRel::new(&context_provider);
-    // let logical_plan = sql_to_rel.sql_statement_to_plan(statement.clone())?;
-
-    // Use sampling planner to create a logical plan
-    let table_sample_planner = TableSamplePlanner::new(&context_provider);
-    let logical_plan = table_sample_planner.create_logical_plan(statement.clone())?;
-
-    let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;
-
-    // Inspect physical plan
-    let displayable_plan = displayable(physical_plan.as_ref())
-        .indent(false)
-        .to_string();
-    info!("Physical plan:\n{displayable_plan}\n");
-    let first_line = displayable_plan.lines().next().unwrap();
-    assert_eq!(
-        first_line,
-        "SampleExec: lower_bound=0, upper_bound=0.42, with_replacement=false, seed=5"
-    );
-
-    // Execute via standard sql call - doesn't work
-    // let df = ctx.sql(sql).await?;
-    // let batches = df.collect().await?;
-
-    // Execute directly via physical plan
-    let task_context = Arc::new(TaskContext::from(&ctx));
-    let stream = physical_plan.execute(0, task_context)?;
-    let batches: Vec<_> = stream.try_collect().await?;
-
-    info!("Batches: {:?}", &batches);
-
-    let result_string = pretty_format_batches(&batches)
-        // pretty_format_batches_with_schema(table_source.schema(), &batches)
-        .map_err(|e| arrow_datafusion_err!(e))
-        .map(|d| d.to_string())?;
-    let result_strings = result_string.lines().collect::<Vec<_>>();
-    info!("Batch result: {:?}", &result_strings);
-
-    assert_eq!(batches.len(), 1);
-    assert_eq!(batches.first().unwrap().num_rows(), 2);
-
-    info!("Done");
-    Ok(())
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;

From 6e6e4a2f10370877b98a63b5d3ecab4c909eec2d Mon Sep 17 00:00:00 2001
From: theirix
Date: Tue, 23 Sep 2025 20:45:16 +0100
Subject: [PATCH 08/17] Fixup newtype casts

Signed-off-by: theirix

---
 datafusion-examples/examples/table_sample.rs | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/datafusion-examples/examples/table_sample.rs b/datafusion-examples/examples/table_sample.rs
index eeb1f225c732..05945bc2e5cd 100644
--- a/datafusion-examples/examples/table_sample.rs
+++ b/datafusion-examples/examples/table_sample.rs
@@ -994,8 +994,8 @@ mod tests {
             "SELECT int_col, double_col FROM alltypes_plain TABLESAMPLE 0.42 where int_col = 1";
         let (_, logical_plan) = parse_to_logical_plan(sql).await?;
         let node = as_table_sample_node(logical_plan)?;
-        assert_eq!(node.lower_bound, 0.0);
-        assert_eq!(node.upper_bound, 0.42);
+        assert_eq!(f64::from(node.lower_bound), 0.0);
+        assert_eq!(f64::from(node.upper_bound), 0.42);
         assert_eq!(node.with_replacement, false);
         Ok(())
     }
@@ -1006,8 +1006,8 @@ mod tests {
             "SELECT int_col, double_col FROM alltypes_plain TABLESAMPLE 0.42 REPEATABLE(123) where int_col = 1";
         let (_, logical_plan) = parse_to_logical_plan(sql).await?;
         let node = as_table_sample_node(logical_plan)?;
-        assert_eq!(node.lower_bound, 0.0);
-        assert_eq!(node.upper_bound, 0.42);
+        assert_eq!(f64::from(node.lower_bound), 0.0);
+        assert_eq!(f64::from(node.upper_bound), 0.42);
         assert_eq!(node.with_replacement, false);
         assert_eq!(node.seed, 123);
         Ok(())
     }
@@ -1019,8 +1019,8 @@ mod tests {
             "SELECT int_col, double_col FROM alltypes_plain TABLESAMPLE 42 PERCENT where int_col = 1";
         let (_, logical_plan) = parse_to_logical_plan(sql).await?;
         let node = as_table_sample_node(logical_plan)?;
-        assert_eq!(node.lower_bound, 0.0);
-        assert_eq!(node.upper_bound, 0.42);
+        assert_eq!(f64::from(node.lower_bound), 0.0);
+        assert_eq!(f64::from(node.upper_bound), 0.42);
         assert_eq!(node.with_replacement, false);
         Ok(())
     }
@@ -1074,8 +1074,8 @@ mod tests {
             "SELECT int_col, double_col FROM alltypes_plain SAMPLE 0.42 where int_col = 1";
         let (_, logical_plan) = parse_to_logical_plan(sql).await?;
         let node = as_table_sample_node(logical_plan)?;
-        assert_eq!(node.lower_bound, 0.0);
-        assert_eq!(node.upper_bound, 0.42);
+        assert_eq!(f64::from(node.lower_bound), 0.0);
+        assert_eq!(f64::from(node.upper_bound), 0.42);
         assert_eq!(node.with_replacement, false);
         Ok(())
     }
@@ -1100,8 +1100,8 @@ mod tests {
             "SELECT int_col, double_col FROM alltypes_plain TABLESAMPLE (BUCKET 3 OUT OF 16) where int_col = 1";
         let (_, logical_plan) = parse_to_logical_plan(sql).await?;
         let node = as_table_sample_node(logical_plan)?;
-        assert_eq!(node.lower_bound, 0.0);
-        assert!((node.upper_bound - 3.0 / 16.0).abs() < f64::EPSILON);
+        assert_eq!(f64::from(node.lower_bound), 0.0);
+        assert!((f64::from(node.upper_bound) - 3.0 / 16.0).abs() < f64::EPSILON);
         assert_eq!(node.with_replacement, false);
         Ok(())
     }

From 6482b52f3bb65dfb7291c889d763eacc80f411f0 Mon Sep 17 00:00:00 2001
From: theirix
Date: Tue, 23 Sep 2025 20:45:53 +0100
Subject: [PATCH 09/17] Use poll helper and record metrics

Signed-off-by: theirix

---
 datafusion-examples/examples/table_sample.rs | 21 ++++++++++-----------
 1 file changed, 10 insertions(+), 11 deletions(-)

diff --git a/datafusion-examples/examples/table_sample.rs b/datafusion-examples/examples/table_sample.rs
index 05945bc2e5cd..b8992c3db9bd 100644
--- a/datafusion-examples/examples/table_sample.rs
+++ b/datafusion-examples/examples/table_sample.rs
@@ -626,19 +626,18 @@ impl Stream for SampleExecStream {
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
     ) -> Poll<Option<Result<RecordBatch>>> {
-        let poll = self.input.poll_next_unpin(cx);
-        let baseline_metrics = &mut self.baseline_metrics;
-
-        match poll {
-            Poll::Ready(Some(Ok(batch))) => {
-                let start = baseline_metrics.elapsed_compute().clone();
-                let result = self.sampler.sample(&batch);
-                let _timer = start.timer();
+        match ready!(self.input.poll_next_unpin(cx)) {
+            Some(Ok(batch)) => {
+                // Start the timer before sampling so elapsed_compute actually
+                // measures the work, then count the sampled rows as output
+                let mut timer = self.baseline_metrics.elapsed_compute().timer();
+                let result = self.sampler.sample(&batch);
+                timer.done();
+                let result = result.record_output(&self.baseline_metrics);
                 Poll::Ready(Some(result))
             }
-            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
-            Poll::Ready(None) => Poll::Ready(None),
-            Poll::Pending => Poll::Pending,
+            Some(Err(e)) => Poll::Ready(Some(Err(e))),
+            None => Poll::Ready(None),
         }
     }
 }

From eefeaf698acba8ad164a442ce367466ed9cdb084 Mon Sep 17 00:00:00 2001
From: theirix
Date: Tue, 23 Sep 2025 22:37:41 +0100
Subject: [PATCH 10/17] Make SessionContextProvider public

---
 datafusion/core/src/execution/session_state.rs | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs
index f6582909041a..3acda9ce0c9e 100644
--- a/datafusion/core/src/execution/session_state.rs
+++ b/datafusion/core/src/execution/session_state.rs
@@ -1664,11 +1664,22 @@ impl From<SessionState> for SessionStateBuilder {
 /// This is used so the SQL planner can access the state of the session without
 /// having a direct dependency on the [`SessionState`] struct (and core crate)
 #[cfg(feature = "sql")]
-struct SessionContextProvider<'a> {
+pub struct SessionContextProvider<'a> {
     state: &'a SessionState,
     tables: HashMap<ResolvedTableReference, Arc<dyn TableSource>>,
 }
 
+#[cfg(feature = "sql")]
+impl<'a> SessionContextProvider<'a> {
+    /// Construct the [`SessionContextProvider`] struct
+    pub fn new(state: &'a SessionState, tables: HashMap<ResolvedTableReference, Arc<dyn TableSource>>) -> Self {
+        Self {
+            state,
+            tables
+        }
+    }
+}
+
 #[cfg(feature = "sql")]
 impl ContextProvider for SessionContextProvider<'_> {
     fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {

From 4a8d8d93169bdb264cb51d7e948ed847de43d9f6 Mon Sep 17 00:00:00 2001
From: theirix
Date: Tue, 23 Sep 2025 22:38:10 +0100
Subject: [PATCH 11/17] Reuse SessionContextProvider in table sample code

---
 datafusion-examples/examples/table_sample.rs | 123 +++++--------------
 1 file changed, 34 insertions(+), 89 deletions(-)

diff --git a/datafusion-examples/examples/table_sample.rs b/datafusion-examples/examples/table_sample.rs
index b8992c3db9bd..b1c0bcdb2b0e 100644
--- a/datafusion-examples/examples/table_sample.rs
+++ b/datafusion-examples/examples/table_sample.rs
@@ -17,35 +17,29 @@ use datafusion::common::{
     arrow_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err,
-    DFSchemaRef, Statistics, TableReference,
+    DFSchemaRef, ResolvedTableReference, Statistics, TableReference,
 };
 use datafusion::error::Result;
 use datafusion::logical_expr::sqlparser::ast::{
-    SetExpr, Statement, TableFactor, TableSampleMethod,
-    TableSampleUnit,
+    SetExpr, Statement, TableFactor, TableSampleMethod, TableSampleUnit,
 };
 use datafusion::logical_expr::{
-    AggregateUDF, Extension, LogicalPlan, LogicalPlanBuilder,
-    ScalarUDF, TableSource, UserDefinedLogicalNode, UserDefinedLogicalNodeCore,
-    WindowUDF,
+    Extension, LogicalPlan, LogicalPlanBuilder, TableSource, UserDefinedLogicalNode,
+    UserDefinedLogicalNodeCore,
 };
 use std::any::Any;
-use std::collections::{HashMap};
+use std::collections::HashMap;
 
-use arrow::util::pretty::{pretty_format_batches};
+use arrow::util::pretty::pretty_format_batches;
 
-use arrow_schema::{DataType, SchemaRef};
+use arrow_schema::SchemaRef;
 use async_trait::async_trait;
-use datafusion::config::ConfigOptions;
-use datafusion::datasource::{provider_as_source};
+use datafusion::datasource::provider_as_source;
 use datafusion::error::DataFusionError;
 use datafusion::execution::{
-    SendableRecordBatchStream, SessionState, SessionStateBuilder,
-    TaskContext,
+    SendableRecordBatchStream, SessionState, SessionStateBuilder, TaskContext,
 };
-use datafusion::logical_expr::planner::{
-    ContextProvider, ExprPlanner, TypePlanner,
-};
+use datafusion::logical_expr::planner::ContextProvider;
 use datafusion::logical_expr::sqlparser::dialect::PostgreSqlDialect;
 use datafusion::logical_expr::sqlparser::parser::Parser;
 use datafusion::physical_expr::EquivalenceProperties;
@@ -60,9 +54,8 @@ use datafusion::physical_planner::{
     DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner,
 };
 use datafusion::prelude::*;
-use datafusion::sql::planner::{SqlToRel};
-use datafusion::sql::sqlparser::ast::{TableSampleKind};
-use datafusion::variable::VarType;
+use datafusion::sql::planner::SqlToRel;
+use datafusion::sql::sqlparser::ast::TableSampleKind;
 use log::{debug, info};
 use std::fmt;
 use std::fmt::{Debug, Formatter};
@@ -77,6 +70,7 @@ use arrow::compute;
 use arrow::record_batch::RecordBatch;
 
 use datafusion::execution::context::QueryPlanner;
+use datafusion::execution::session_state::SessionContextProvider;
 use datafusion::sql::sqlparser::ast;
 use futures::stream::{Stream, StreamExt};
 use futures::{ready, TryStreamExt};
@@ -110,14 +104,19 @@ async fn main() -> Result<()> {
     )
     .await?;
 
+    // Construct a context provider from this parquet table source
     let table_source = provider_as_source(ctx.table_provider("alltypes_plain").await?);
-    let context_provider = MockContextProvider {
-        state: &state,
-        tables: HashMap::<TableReference, Arc<dyn TableSource>>::from([(
-            "alltypes_plain".into(),
+    let resolved_table_ref = TableReference::bare("alltypes_plain").resolve(
+        &state.config_options().catalog.default_catalog,
+        &state.config_options().catalog.default_schema,
+    );
+    let context_provider = SessionContextProvider::new(
+        &state,
+        HashMap::<ResolvedTableReference, Arc<dyn TableSource>>::from([(
+            resolved_table_ref,
             table_source.clone(),
         )]),
-    };
+    );
 
     let sql =
         "SELECT int_col, double_col FROM alltypes_plain TABLESAMPLE 42 PERCENT REPEATABLE(5) WHERE int_col = 1";
@@ -849,72 +848,14 @@ impl<'a, S: ContextProvider> TableSamplePlanner<'a, S> {
     }
 }
 
-// Context provider for tests
-
-struct MockContextProvider<'a> {
-    state: &'a SessionState,
-    tables: HashMap<TableReference, Arc<dyn TableSource>>,
-}
-
-impl ContextProvider for MockContextProvider<'_> {
-    fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
-        self.tables
-            .get(&name)
-            .cloned()
-            .ok_or_else(|| plan_datafusion_err!("table '{name}' not found"))
-    }
-
-    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
-        self.state.expr_planners()
-    }
-
-    fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
-        None
-    }
-
-    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
-        self.state.scalar_functions().get(name).cloned()
-    }
-
-    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
-        self.state.aggregate_functions().get(name).cloned()
-    }
-
-    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
-        self.state.window_functions().get(name).cloned()
-    }
-
-    fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
-        self.state
-            .execution_props()
-            .var_providers
-            .as_ref()
-            .and_then(|provider| provider.get(&VarType::System)?.get_type(variable_names))
-    }
-
-    fn options(&self) -> &ConfigOptions {
-        self.state.config_options()
-    }
-
-    fn udf_names(&self) -> Vec<String> {
-        self.state.scalar_functions().keys().cloned().collect()
-    }
-
-    fn udaf_names(&self) -> Vec<String> {
-        self.state.aggregate_functions().keys().cloned().collect()
-    }
-
-    fn udwf_names(&self) -> Vec<String> {
-        self.state.window_functions().keys().cloned().collect()
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
     use arrow::array::Int32Array;
     use arrow::datatypes::{Field, Schema};
     use datafusion::assert_batches_eq;
+    use datafusion::common::ResolvedTableReference;
+    use datafusion::execution::session_state::SessionContextProvider;
     use datafusion::physical_plan::test::TestMemoryExec;
     use futures::TryStreamExt;
     use std::sync::Arc;
@@ -944,11 +885,15 @@ mod tests {
         let table_source = provider_as_source(ctx.table_provider("alltypes_plain").await?);
-        let context_provider = MockContextProvider {
-            state: &state,
-            tables: HashMap::<TableReference, Arc<dyn TableSource>>::from([(
-                "alltypes_plain".into(),
+        let resolved_table_ref = TableReference::bare("alltypes_plain").resolve(
+            &state.config_options().catalog.default_catalog,
+            &state.config_options().catalog.default_schema,
+        );
+        let context_provider = SessionContextProvider::new(
+            &state,
+            HashMap::<ResolvedTableReference, Arc<dyn TableSource>>::from([(
+                resolved_table_ref,
                 table_source.clone(),
             )]),
-        };
+        );
 
         let dialect = PostgreSqlDialect {};
         let statements = Parser::parse_sql(&dialect, sql)?;

From 8b13b0e743abe7aff85271bbd619e6d48acca878 Mon Sep 17 00:00:00 2001
From: theirix
Date: Tue, 23 Sep 2025 22:52:14 +0100
Subject: [PATCH 12/17] Reformat SessionContextProvider

---
 datafusion/core/src/execution/session_state.rs | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs
index 3acda9ce0c9e..ca1377c63065 100644
--- a/datafusion/core/src/execution/session_state.rs
+++ b/datafusion/core/src/execution/session_state.rs
@@ -1672,11 +1672,11 @@ pub struct SessionContextProvider<'a> {
 #[cfg(feature = "sql")]
 impl<'a> SessionContextProvider<'a> {
     /// Construct the [`SessionContextProvider`] struct
-    pub fn new(state: &'a SessionState, tables: HashMap<ResolvedTableReference, Arc<dyn TableSource>>) -> Self {
-        Self {
-            state,
-            tables
-        }
+    pub fn new(
+        state: &'a SessionState,
+        tables: HashMap<ResolvedTableReference, Arc<dyn TableSource>>,
+    ) -> Self {
+        Self { state, tables }
     }
 }

From e63de51944ae4dbb6d3c3b47501766ce7e594799 Mon Sep 17 00:00:00 2001
From: theirix
Date: Tue, 23 Sep 2025 23:08:31 +0100
Subject: [PATCH 13/17] Fix clippy lints

---
 datafusion-examples/examples/table_sample.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/datafusion-examples/examples/table_sample.rs b/datafusion-examples/examples/table_sample.rs
index b1c0bcdb2b0e..55408751e52f 100644
--- a/datafusion-examples/examples/table_sample.rs
+++ b/datafusion-examples/examples/table_sample.rs
@@ -80,7 +80,7 @@ use rand_distr::{Distribution, Poisson};
 
 /// This example demonstrates how to extend DataFusion's SQL parser to recognize
 /// other syntax.
-
+///
 /// This example shows how to extend the DataFusion SQL planner to support the
@@ -272,7 +272,7 @@ impl TableSampleExtensionPlanner {
                 input: physical_input.clone(),
                 lower_bound: 0.0,
                 upper_bound: specific_node.upper_bound.into(),
-                with_replacement: specific_node.with_replacement.into(),
+                with_replacement: specific_node.with_replacement,
                 seed: specific_node.seed,
                 metrics: Default::default(),
                 cache: SampleExec::compute_properties(&physical_input),

From 591c6a0f93a12fb7a55d9f82d4a2b36d0f159599 Mon Sep 17 00:00:00 2001
From: theirix
Date: Sat, 27 Sep 2025 09:34:57 +0100
Subject: [PATCH 14/17] Rework comments

---
 datafusion-examples/examples/table_sample.rs | 12 ++----------
 1 file changed, 2 insertions(+), 10 deletions(-)

diff --git a/datafusion-examples/examples/table_sample.rs b/datafusion-examples/examples/table_sample.rs
index 55408751e52f..a0b4f9215431 100644
--- a/datafusion-examples/examples/table_sample.rs
+++ b/datafusion-examples/examples/table_sample.rs
@@ -125,11 +125,8 @@ async fn main() -> Result<()> {
     let statements = Parser::parse_sql(&dialect, sql)?;
     let statement = statements.first().expect("one statement");
 
-    // Classical way
-    // let sql_to_rel = SqlToRel::new(&context_provider);
-    // let logical_plan = sql_to_rel.sql_statement_to_plan(statement.clone())?;
-
-    // Use sampling planner to create a logical plan
+    // Use a custom sampling planner to create a logical plan
+    // instead of [SqlToRel::sql_statement_to_plan]
     let table_sample_planner = TableSamplePlanner::new(&context_provider);
     let logical_plan = table_sample_planner.create_logical_plan(statement.clone())?;
 
@@ -146,10 +143,6 @@ async fn main() -> Result<()> {
         "SampleExec: lower_bound=0, upper_bound=0.42, with_replacement=false, seed=5"
     );
 
-    // Execute via standard sql call - doesn't work
-    // let df = ctx.sql(sql).await?;
-    // let batches = df.collect().await?;
-
     // Execute directly via physical plan
     let task_context = Arc::new(TaskContext::from(&ctx));
     let stream = physical_plan.execute(0, task_context)?;
@@ -158,7 +151,6 @@ async fn main() -> Result<()> {
     info!("Batches: {:?}", &batches);
 
     let result_string = pretty_format_batches(&batches)
-        // pretty_format_batches_with_schema(table_source.schema(), &batches)
         .map_err(|e| arrow_datafusion_err!(e))
         .map(|d| d.to_string())?;
     let result_strings = result_string.lines().collect::<Vec<_>>();

From 003ccb639a77f77ac40e63125defa8deba0e0852 Mon Sep 17 00:00:00 2001
From: theirix
Date: Sat, 27 Sep 2025 09:34:36 +0100
Subject: [PATCH 15/17] Streamline imports

---
 datafusion-examples/examples/table_sample.rs | 43 ++++++++++----------
 1 file changed, 21 insertions(+), 22 deletions(-)

diff --git a/datafusion-examples/examples/table_sample.rs b/datafusion-examples/examples/table_sample.rs
index a0b4f9215431..10d46ff1c999 100644
--- a/datafusion-examples/examples/table_sample.rs
+++ b/datafusion-examples/examples/table_sample.rs
@@ -15,6 +15,22 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::any::Any;
+use std::collections::HashMap;
+use std::fmt;
+use std::fmt::{Debug, Formatter};
+use std::hash::{Hash, Hasher};
+use std::pin::Pin;
+use std::str::FromStr;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+use arrow::array::UInt32Array;
+use arrow::compute;
+use arrow::record_batch::RecordBatch;
+use arrow::util::pretty::pretty_format_batches;
+use arrow_schema::SchemaRef;
+
 use datafusion::common::{
     arrow_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err,
     DFSchemaRef, ResolvedTableReference, Statistics, TableReference,
@@ -27,15 +43,11 @@ use datafusion::logical_expr::{
     Extension, LogicalPlan, LogicalPlanBuilder, TableSource, UserDefinedLogicalNode,
     UserDefinedLogicalNodeCore,
 };
-use std::any::Any;
-use std::collections::HashMap;
-
-use arrow::util::pretty::pretty_format_batches;
-
-use arrow_schema::SchemaRef;
-use async_trait::async_trait;
 use datafusion::datasource::provider_as_source;
 use datafusion::error::DataFusionError;
+use datafusion::execution::context::QueryPlanner;
+use datafusion::execution::session_state::SessionContextProvider;
 use datafusion::execution::{
     SendableRecordBatchStream, SessionState, SessionStateBuilder, TaskContext,
 };
@@ -55,25 +67,12 @@ use datafusion::physical_planner::{
 };
 use datafusion::prelude::*;
 use datafusion::sql::planner::SqlToRel;
-use datafusion::sql::sqlparser::ast::TableSampleKind;
-use log::{debug, info};
-use std::fmt;
-use std::fmt::{Debug, Formatter};
-use std::hash::{Hash, Hasher};
-use std::pin::Pin;
-use std::str::FromStr;
-use std::sync::Arc;
-use std::task::{Context, Poll};
-
-use arrow::array::UInt32Array;
-use arrow::compute;
-use arrow::record_batch::RecordBatch;
-
-use datafusion::execution::context::QueryPlanner;
-use datafusion::execution::session_state::SessionContextProvider;
 use datafusion::sql::sqlparser::ast;
+
+use async_trait::async_trait;
 use futures::stream::{Stream, StreamExt};
 use futures::{ready, TryStreamExt};
+use log::{debug, info};
 use rand::rngs::StdRng;
 use rand::{Rng, SeedableRng};
 use rand_distr::{Distribution, Poisson};

From 71a1452b4faedf5985fb3889a01151025bd01b1d Mon Sep 17 00:00:00 2001
From: theirix
Date: Sat, 27 Sep 2025 09:35:17 +0100
Subject: [PATCH 16/17] Better separation of ast vs logical expressions

---
 datafusion-examples/examples/table_sample.rs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/datafusion-examples/examples/table_sample.rs b/datafusion-examples/examples/table_sample.rs
index 10d46ff1c999..9781bdb6207c 100644
--- a/datafusion-examples/examples/table_sample.rs
+++ b/datafusion-examples/examples/table_sample.rs
@@ -702,11 +702,11 @@ impl<'a, S: ContextProvider> TableSamplePlanner<'a, S> {
     fn sample_to_logical_plan(
         &self,
         input: LogicalPlan,
-        sample: TableSampleKind,
+        sample: ast::TableSampleKind,
     ) -> Result<LogicalPlan> {
         let sample = match sample {
-            TableSampleKind::BeforeTableAlias(sample) => sample,
-            TableSampleKind::AfterTableAlias(sample) => sample,
+            ast::TableSampleKind::BeforeTableAlias(sample) => sample,
+            ast::TableSampleKind::AfterTableAlias(sample) => sample,
         };
         if let Some(name) = &sample.name {
             if *name != TableSampleMethod::Bernoulli && *name != TableSampleMethod::Row {

From 500bab93737d7ff4bb89c8247b40ef52b2c35df3 Mon Sep 17 00:00:00 2001
From: theirix
Date: Sat, 27 Sep 2025 09:37:28 +0100
Subject: [PATCH 17/17] Fix typo

---
 datafusion-examples/examples/table_sample.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/datafusion-examples/examples/table_sample.rs b/datafusion-examples/examples/table_sample.rs
index 9781bdb6207c..81c7c9345b9d 100644
--- a/datafusion-examples/examples/table_sample.rs
+++ b/datafusion-examples/examples/table_sample.rs
@@ -162,7 +162,7 @@ async fn main() -> Result<()> {
     Ok(())
 }
 
-/// Hashable and comparible f64 for sampling bounds
+/// Hashable and comparable f64 for sampling bounds
 #[derive(Debug, Clone, Copy, PartialOrd)]
 struct Bound(f64);
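
Note: the hunks above show how `TABLESAMPLE` is parsed, planned and wired into `SampleExec`, but the row sampler that `SampleExecStream` invokes via `self.sampler.sample(&batch)` lies outside the visible diff. The following is only a minimal sketch of the Bernoulli (without replacement) case, written against the crates the example already imports (arrow and rand 0.9); the `sample_batch` helper and its signature are illustrative assumptions, not the implementation behind `SampleExec`. The WITH REPLACEMENT path would instead draw a per-row repeat count from `rand_distr::Poisson`.

use arrow::array::UInt32Array;
use arrow::compute;
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

/// Keep each row of `batch` independently with probability `fraction`.
/// The caller owns the seeded RNG so repeated runs stay deterministic.
fn sample_batch(
    batch: &RecordBatch,
    fraction: f64,
    rng: &mut StdRng,
) -> ArrowResult<RecordBatch> {
    // Indices of the rows that survive the Bernoulli trial
    let indices = UInt32Array::from(
        (0..batch.num_rows() as u32)
            .filter(|_| rng.random::<f64>() < fraction)
            .collect::<Vec<u32>>(),
    );
    // `take` gathers the selected rows of every column into new arrays
    let columns = batch
        .columns()
        .iter()
        .map(|column| compute::take(column.as_ref(), &indices, None))
        .collect::<ArrowResult<Vec<_>>>()?;
    RecordBatch::try_new(batch.schema(), columns)
}

fn main() -> ArrowResult<()> {
    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    let schema = Arc::new(Schema::new(vec![Field::new(
        "int_col",
        DataType::Int32,
        false,
    )]));
    let batch = RecordBatch::try_new(
        schema,
        vec![Arc::new(Int32Array::from((0..100).collect::<Vec<i32>>()))],
    )?;
    // REPEATABLE(5) maps naturally onto a fixed seed
    let mut rng = StdRng::seed_from_u64(5);
    let sampled = sample_batch(&batch, 0.42, &mut rng)?;
    println!("kept {} of {} rows", sampled.num_rows(), batch.num_rows());
    Ok(())
}

Seeding the generator once per partition, rather than once per batch, is what lets a `REPEATABLE(5)` query return the same rows on every run while still varying the selection from batch to batch.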