Skip to content

Commit 70daf88

Browse files
authored
feat: plan-time SQL expression simplifying (#19311)
## Which issue does this PR close? - Closes #19312. ## Rationale for this change Introduce a module for parsing and simplifying an SQL expression to a literal of a given type. This module provides functionality to parse and simplify static SQL expressions used in SQL constructs like `FROM t TABLESAMPLE (10 + 50 * 2)`. If they are required in the planning (not the execution) phase, they need to be reduced to literals of a given type. ## What changes are included in this PR? 1. New module with documentation and unit tests 2. The table sample example is switched to use this module for SQL parsing 3. A small fix-up to run `cargo run --package datafusion-examples --example relation_planner` instead of `cargo run --package datafusion-examples --example relation_planner -- all` ## Are these changes tested? - Unit tests - Table sample example runs ## Are there any user-facing changes? A new API is provided --------- Signed-off-by: theirix <[email protected]>
1 parent 2ac032b commit 70daf88

File tree

3 files changed

+158
-39
lines changed

3 files changed

+158
-39
lines changed

datafusion-examples/examples/relation_planner/table_sample.rs

Lines changed: 6 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,12 @@ use std::{
8383
any::Any,
8484
fmt::{self, Debug, Formatter},
8585
hash::{Hash, Hasher},
86-
ops::{Add, Div, Mul, Sub},
8786
pin::Pin,
88-
str::FromStr,
8987
sync::Arc,
9088
task::{Context, Poll},
9189
};
9290

91+
use arrow::datatypes::{Float64Type, Int64Type};
9392
use arrow::{
9493
array::{ArrayRef, Int32Array, RecordBatch, StringArray, UInt32Array},
9594
compute,
@@ -102,6 +101,7 @@ use futures::{
102101
use rand::{Rng, SeedableRng, rngs::StdRng};
103102
use tonic::async_trait;
104103

104+
use datafusion::optimizer::simplify_expressions::simplify_literal::parse_literal;
105105
use datafusion::{
106106
execution::{
107107
RecordBatchStream, SendableRecordBatchStream, SessionState, SessionStateBuilder,
@@ -410,11 +410,12 @@ impl RelationPlanner for TableSamplePlanner {
410410
"TABLESAMPLE requires a quantity (percentage, fraction, or row count)"
411411
);
412412
};
413+
let quantity_value_expr = context.sql_to_expr(quantity.value, input.schema())?;
413414

414415
match quantity.unit {
415416
// TABLESAMPLE (N ROWS) - exact row limit
416417
Some(TableSampleUnit::Rows) => {
417-
let rows = parse_quantity::<i64>(&quantity.value)?;
418+
let rows: i64 = parse_literal::<Int64Type>(&quantity_value_expr)?;
418419
if rows < 0 {
419420
return plan_err!("row count must be non-negative, got {}", rows);
420421
}
@@ -426,15 +427,15 @@ impl RelationPlanner for TableSamplePlanner {
426427

427428
// TABLESAMPLE (N PERCENT) - percentage sampling
428429
Some(TableSampleUnit::Percent) => {
429-
let percent = parse_quantity::<f64>(&quantity.value)?;
430+
let percent: f64 = parse_literal::<Float64Type>(&quantity_value_expr)?;
430431
let fraction = percent / 100.0;
431432
let plan = TableSamplePlanNode::new(input, fraction, seed).into_plan();
432433
Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias)))
433434
}
434435

435436
// TABLESAMPLE (N) - fraction if <1.0, row limit if >=1.0
436437
None => {
437-
let value = parse_quantity::<f64>(&quantity.value)?;
438+
let value = parse_literal::<Float64Type>(&quantity_value_expr)?;
438439
if value < 0.0 {
439440
return plan_err!("sample value must be non-negative, got {}", value);
440441
}
@@ -453,40 +454,6 @@ impl RelationPlanner for TableSamplePlanner {
453454
}
454455
}
455456

456-
/// Parse a SQL expression as a numeric value (supports basic arithmetic).
457-
fn parse_quantity<T>(expr: &ast::Expr) -> Result<T>
458-
where
459-
T: FromStr + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
460-
{
461-
eval_numeric_expr(expr)
462-
.ok_or_else(|| plan_datafusion_err!("invalid numeric expression: {:?}", expr))
463-
}
464-
465-
/// Recursively evaluate numeric SQL expressions.
466-
fn eval_numeric_expr<T>(expr: &ast::Expr) -> Option<T>
467-
where
468-
T: FromStr + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
469-
{
470-
match expr {
471-
ast::Expr::Value(v) => match &v.value {
472-
ast::Value::Number(n, _) => n.to_string().parse().ok(),
473-
_ => None,
474-
},
475-
ast::Expr::BinaryOp { left, op, right } => {
476-
let l = eval_numeric_expr::<T>(left)?;
477-
let r = eval_numeric_expr::<T>(right)?;
478-
match op {
479-
ast::BinaryOperator::Plus => Some(l + r),
480-
ast::BinaryOperator::Minus => Some(l - r),
481-
ast::BinaryOperator::Multiply => Some(l * r),
482-
ast::BinaryOperator::Divide => Some(l / r),
483-
_ => None,
484-
}
485-
}
486-
_ => None,
487-
}
488-
}
489-
490457
/// Custom logical plan node representing a TABLESAMPLE operation.
491458
///
492459
/// Stores sampling parameters (bounds, seed) and wraps the input plan.

datafusion/optimizer/src/simplify_expressions/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ pub mod expr_simplifier;
2222
mod inlist_simplifier;
2323
mod regex;
2424
pub mod simplify_exprs;
25+
pub mod simplify_literal;
2526
mod simplify_predicates;
2627
mod unwrap_cast;
2728
mod utils;
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Parses and simplifies an expression to a literal of a given type.
19+
//!
20+
//! This module provides functionality to parse and simplify static expressions
21+
//! used in SQL constructs like `FROM t TABLESAMPLE (10 + 50 * 2)`. If they are required
22+
//! in the planning (not the execution) phase, they need to be reduced to literals of a given type.
23+
24+
use crate::simplify_expressions::ExprSimplifier;
25+
use arrow::datatypes::ArrowPrimitiveType;
26+
use datafusion_common::{
27+
DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, plan_datafusion_err,
28+
plan_err,
29+
};
30+
use datafusion_expr::Expr;
31+
use datafusion_expr::execution_props::ExecutionProps;
32+
use datafusion_expr::simplify::SimplifyContext;
33+
use std::sync::Arc;
34+
35+
/// Parses and simplifies an expression to a numeric literal,
36+
/// corresponding to an Arrow primitive type `T` (for example, `Float64Type`).
37+
///
38+
/// This function simplifies and coerces the expression, casts the resulting literal
39+
/// to `T::DATA_TYPE`, and extracts the native value using `TryFrom<ScalarValue>`.
40+
///
41+
/// # Example
42+
/// ```ignore
43+
/// let value: f64 = parse_literal::<Float64Type>(&expr)?;
44+
/// ```
45+
pub fn parse_literal<T>(expr: &Expr) -> Result<T::Native>
46+
where
47+
T: ArrowPrimitiveType,
48+
T::Native: TryFrom<ScalarValue, Error = DataFusionError>,
49+
{
50+
// An empty schema is sufficient because only literal expressions are parsed
51+
let schema = DFSchemaRef::new(DFSchema::empty());
52+
53+
log::debug!("Parsing expr {:?} to type {}", expr, T::DATA_TYPE);
54+
55+
let execution_props = ExecutionProps::new();
56+
let simplifier = ExprSimplifier::new(
57+
SimplifyContext::new(&execution_props).with_schema(Arc::clone(&schema)),
58+
);
59+
60+
// Simplify constant arithmetic (e.g., 10 + 5), then coerce. NOTE(review): coercion runs after simplification, so mixed-type constants may stay unfolded — confirm
61+
let simplified_expr: Expr = simplifier
62+
.simplify(expr.clone())
63+
.map_err(|err| plan_datafusion_err!("Cannot simplify {expr:?}: {err}"))?;
64+
let coerced_expr: Expr = simplifier.coerce(simplified_expr, schema.as_ref())?;
65+
log::debug!("Coerced expression: {:?}", &coerced_expr);
66+
67+
match coerced_expr {
68+
Expr::Literal(scalar_value, _) => {
69+
// Already a literal: proceed to the underlying scalar value
70+
// Cast to the target type `T::DATA_TYPE` if needed
71+
let casted_scalar = scalar_value.cast_to(&T::DATA_TYPE)?;
72+
73+
// Extract the native value; a conversion failure is reported as a planning error
74+
T::Native::try_from(casted_scalar).map_err(|err| {
75+
plan_datafusion_err!(
76+
"Cannot extract {} from scalar value: {err}",
77+
std::any::type_name::<T>()
78+
)
79+
})
80+
}
81+
actual => {
82+
plan_err!(
83+
"Cannot extract literal from coerced {actual:?} expression given {expr:?} expression"
84+
)
85+
}
86+
}
87+
}
88+
89+
#[cfg(test)]
90+
mod tests {
91+
use super::*;
92+
use arrow::datatypes::{Float64Type, Int64Type};
93+
use datafusion_expr::{BinaryExpr, lit};
94+
use datafusion_expr_common::operator::Operator;
95+
96+
#[test]
97+
fn test_parse_sql_float_literal() {
98+
let test_cases = vec![
99+
(Expr::Literal(ScalarValue::Float64(Some(0.0)), None), 0.0),
100+
(Expr::Literal(ScalarValue::Float64(Some(1.0)), None), 1.0),
101+
(
102+
Expr::BinaryExpr(BinaryExpr::new(
103+
Box::new(lit(50.0)),
104+
Operator::Minus,
105+
Box::new(lit(10.0)),
106+
)),
107+
40.0,
108+
),
109+
(
110+
Expr::Literal(ScalarValue::Utf8(Some("1e2".into())), None),
111+
100.0,
112+
),
113+
(
114+
Expr::Literal(ScalarValue::Utf8(Some("2.5e-1".into())), None),
115+
0.25,
116+
),
117+
];
118+
119+
for (expr, expected) in test_cases {
120+
let result: Result<f64> = parse_literal::<Float64Type>(&expr);
121+
122+
match result {
123+
Ok(value) => {
124+
assert!(
125+
(value - expected).abs() < 1e-10,
126+
"For expression '{expr}': expected {expected}, got {value}",
127+
);
128+
}
129+
Err(e) => panic!("Failed to parse expression '{expr}': {e}"),
130+
}
131+
}
132+
}
133+
134+
#[test]
135+
fn test_parse_sql_integer_literal() {
136+
let expr = Expr::BinaryExpr(BinaryExpr::new(
137+
Box::new(lit(2)),
138+
Operator::Plus,
139+
Box::new(lit(4)),
140+
));
141+
142+
let result: Result<i64> = parse_literal::<Int64Type>(&expr);
143+
144+
match result {
145+
Ok(value) => {
146+
assert_eq!(6, value);
147+
}
148+
Err(e) => panic!("Failed to parse expression: {e}"),
149+
}
150+
}
151+
}

0 commit comments

Comments
 (0)