Skip to content

Commit 510b5bc

Browse files
committed
Add tests for additional cases
1 parent 9ae434e commit 510b5bc

File tree

1 file changed

+156
-61
lines changed

1 file changed

+156
-61
lines changed

datafusion/optimizer/src/simplify_expressions/udf_preimage.rs

Lines changed: 156 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616
// under the License.
1717

1818
use datafusion_common::{Result, internal_err, tree_node::Transformed};
19-
use datafusion_expr::{
20-
Expr, Operator, and, binary_expr, lit, or, simplify::SimplifyContext,
21-
};
19+
use datafusion_expr::{Expr, Operator, and, lit, or, simplify::SimplifyContext};
2220
use datafusion_expr_common::interval_arithmetic::Interval;
2321

2422
/// Rewrites a binary expression using its "preimage"
@@ -46,32 +44,32 @@ pub(super) fn rewrite_with_preimage(
4644
) -> Result<Transformed<Expr>> {
4745
let (lower, upper) = preimage_interval.into_bounds();
4846
let (lower, upper) = (lit(lower), lit(upper));
47+
let expr = *expr;
4948

5049
let rewritten_expr = match op {
5150
// <expr> < x ==> <expr> < lower
5251
// <expr> >= x ==> <expr> >= lower
53-
Operator::Lt | Operator::GtEq => binary_expr(*expr, op, lower),
52+
Operator::Lt => expr.lt(lower),
53+
Operator::GtEq => expr.gt_eq(lower),
5454
// <expr> > x ==> <expr> >= upper
55-
Operator::Gt => binary_expr(*expr, Operator::GtEq, upper),
55+
Operator::Gt => expr.gt_eq(upper),
5656
// <expr> <= x ==> <expr> < upper
57-
Operator::LtEq => binary_expr(*expr, Operator::Lt, upper),
57+
Operator::LtEq => expr.lt(upper),
5858
// <expr> = x ==> (<expr> >= lower) and (<expr> < upper)
5959
//
6060
// <expr> is not distinct from x ==> (<expr> is NULL and x is NULL) or ((<expr> >= lower) and (<expr> < upper))
6161
// but since x is never NULL => (<expr> >= lower) and (<expr> < upper)
62-
Operator::Eq | Operator::IsNotDistinctFrom => and(
63-
binary_expr(*expr.clone(), Operator::GtEq, lower),
64-
binary_expr(*expr, Operator::Lt, upper),
65-
),
62+
Operator::Eq | Operator::IsNotDistinctFrom => {
63+
and(expr.clone().gt_eq(lower), expr.lt(upper))
64+
}
6665
// <expr> != x ==> (<expr> < lower) or (<expr> >= upper)
67-
Operator::NotEq => or(
68-
binary_expr(*expr.clone(), Operator::Lt, lower),
69-
binary_expr(*expr, Operator::GtEq, upper),
70-
),
66+
Operator::NotEq => or(expr.clone().lt(lower), expr.gt_eq(upper)),
7167
// <expr> is distinct from x ==> (<expr> < lower) or (<expr> >= upper) or (<expr> is NULL and x is not NULL) or (<expr> is not NULL and x is NULL)
7268
// but given that x is never NULL => (<expr> < lower) or (<expr> >= upper) or (<expr> is NULL)
73-
Operator::IsDistinctFrom => binary_expr(*expr.clone(), Operator::Lt, lower)
74-
.or(binary_expr(*expr.clone(), Operator::GtEq, upper))
69+
Operator::IsDistinctFrom => expr
70+
.clone()
71+
.lt(lower)
72+
.or(expr.clone().gt_eq(upper))
7573
.or(expr.is_null()),
7674
_ => return internal_err!("Expect comparison operators"),
7775
};
@@ -86,17 +84,56 @@ mod test {
8684
use arrow::datatypes::{DataType, Field};
8785
use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue};
8886
use datafusion_expr::{
89-
ColumnarValue, Expr, Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
90-
Signature, Volatility, and, binary_expr, col, expr::ScalarFunction, lit,
91-
simplify::SimplifyContext,
87+
BinaryExpr, ColumnarValue, Expr, Operator, ScalarFunctionArgs, ScalarUDF,
88+
ScalarUDFImpl, Signature, Volatility, and, col, lit, simplify::SimplifyContext,
9289
};
9390

9491
use super::Interval;
9592
use crate::simplify_expressions::ExprSimplifier;
9693

94+
fn is_distinct_from(left: Expr, right: Expr) -> Expr {
95+
Expr::BinaryExpr(BinaryExpr {
96+
left: Box::new(left),
97+
op: Operator::IsDistinctFrom,
98+
right: Box::new(right),
99+
})
100+
}
101+
102+
fn is_not_distinct_from(left: Expr, right: Expr) -> Expr {
103+
Expr::BinaryExpr(BinaryExpr {
104+
left: Box::new(left),
105+
op: Operator::IsNotDistinctFrom,
106+
right: Box::new(right),
107+
})
108+
}
109+
97110
#[derive(Debug, PartialEq, Eq, Hash)]
98111
struct PreimageUdf {
112+
/// Defaults to an exact signature with one Int32 argument and Immutable volatility
99113
signature: Signature,
114+
/// If true, returns a preimage; otherwise, returns None
115+
enabled: bool,
116+
}
117+
118+
impl PreimageUdf {
119+
fn new() -> Self {
120+
Self {
121+
signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable),
122+
enabled: true,
123+
}
124+
}
125+
126+
/// Set the enabled flag
127+
fn with_enabled(mut self, enabled: bool) -> Self {
128+
self.enabled = enabled;
129+
self
130+
}
131+
132+
/// Set the volatility
133+
fn with_volatility(mut self, volatility: Volatility) -> Self {
134+
self.signature.volatility = volatility;
135+
self
136+
}
100137
}
101138

102139
impl ScalarUDFImpl for PreimageUdf {
@@ -126,6 +163,9 @@ mod test {
126163
lit_expr: &Expr,
127164
_info: &SimplifyContext,
128165
) -> Result<Option<Interval>> {
166+
if !self.enabled {
167+
return Ok(None);
168+
}
129169
if args.len() != 1 {
130170
return Ok(None);
131171
}
@@ -146,19 +186,24 @@ mod test {
146186
}
147187

148188
fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
149-
let simplifier = ExprSimplifier::new(
150-
SimplifyContext::default().with_schema(Arc::clone(schema)),
151-
);
152-
153-
simplifier.simplify(expr).unwrap()
189+
let simplify_context = SimplifyContext::default().with_schema(Arc::clone(schema));
190+
ExprSimplifier::new(simplify_context)
191+
.simplify(expr)
192+
.unwrap()
154193
}
155194

156195
fn preimage_udf_expr() -> Expr {
157-
let udf = ScalarUDF::new_from_impl(PreimageUdf {
158-
signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable),
159-
});
196+
ScalarUDF::new_from_impl(PreimageUdf::new()).call(vec![col("x")])
197+
}
160198

161-
Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf), vec![col("x")]))
199+
fn non_immutable_udf_expr() -> Expr {
200+
ScalarUDF::new_from_impl(PreimageUdf::new().with_volatility(Volatility::Volatile))
201+
.call(vec![col("x")])
202+
}
203+
204+
fn no_preimage_udf_expr() -> Expr {
205+
ScalarUDF::new_from_impl(PreimageUdf::new().with_enabled(false))
206+
.call(vec![col("x")])
162207
}
163208

164209
fn test_schema() -> DFSchemaRef {
@@ -171,100 +216,150 @@ mod test {
171216
)
172217
}
173218

219+
fn test_schema_xy() -> DFSchemaRef {
220+
Arc::new(
221+
DFSchema::from_unqualified_fields(
222+
vec![
223+
Field::new("x", DataType::Int32, false),
224+
Field::new("y", DataType::Int32, false),
225+
]
226+
.into(),
227+
Default::default(),
228+
)
229+
.unwrap(),
230+
)
231+
}
232+
174233
#[test]
175234
fn test_preimage_eq_rewrite() {
235+
// Equality rewrite when preimage and column expression are available.
176236
let schema = test_schema();
177-
let expr = binary_expr(preimage_udf_expr(), Operator::Eq, lit(500));
178-
let expected = and(
179-
binary_expr(col("x"), Operator::GtEq, lit(100)),
180-
binary_expr(col("x"), Operator::Lt, lit(200)),
181-
);
237+
let expr = preimage_udf_expr().eq(lit(500));
238+
let expected = and(col("x").gt_eq(lit(100)), col("x").lt(lit(200)));
182239

183240
assert_eq!(optimize_test(expr, &schema), expected);
184241
}
185242

186243
#[test]
187244
fn test_preimage_noteq_rewrite() {
245+
// Inequality rewrite expands to disjoint ranges.
188246
let schema = test_schema();
189-
let expr = binary_expr(preimage_udf_expr(), Operator::NotEq, lit(500));
190-
let expected = binary_expr(col("x"), Operator::Lt, lit(100)).or(binary_expr(
191-
col("x"),
192-
Operator::GtEq,
193-
lit(200),
194-
));
247+
let expr = preimage_udf_expr().not_eq(lit(500));
248+
let expected = col("x").lt(lit(100)).or(col("x").gt_eq(lit(200)));
195249

196250
assert_eq!(optimize_test(expr, &schema), expected);
197251
}
198252

199253
#[test]
200254
fn test_preimage_eq_rewrite_swapped() {
255+
// Equality rewrite works when the literal appears on the left.
201256
let schema = test_schema();
202-
let expr = binary_expr(lit(500), Operator::Eq, preimage_udf_expr());
203-
let expected = and(
204-
binary_expr(col("x"), Operator::GtEq, lit(100)),
205-
binary_expr(col("x"), Operator::Lt, lit(200)),
206-
);
257+
let expr = lit(500).eq(preimage_udf_expr());
258+
let expected = and(col("x").gt_eq(lit(100)), col("x").lt(lit(200)));
207259

208260
assert_eq!(optimize_test(expr, &schema), expected);
209261
}
210262

211263
#[test]
212264
fn test_preimage_lt_rewrite() {
265+
// Less-than comparison rewrites to the lower bound.
213266
let schema = test_schema();
214-
let expr = binary_expr(preimage_udf_expr(), Operator::Lt, lit(500));
215-
let expected = binary_expr(col("x"), Operator::Lt, lit(100));
267+
let expr = preimage_udf_expr().lt(lit(500));
268+
let expected = col("x").lt(lit(100));
216269

217270
assert_eq!(optimize_test(expr, &schema), expected);
218271
}
219272

220273
#[test]
221274
fn test_preimage_lteq_rewrite() {
275+
// Less-than-or-equal comparison rewrites to the upper bound.
222276
let schema = test_schema();
223-
let expr = binary_expr(preimage_udf_expr(), Operator::LtEq, lit(500));
224-
let expected = binary_expr(col("x"), Operator::Lt, lit(200));
277+
let expr = preimage_udf_expr().lt_eq(lit(500));
278+
let expected = col("x").lt(lit(200));
225279

226280
assert_eq!(optimize_test(expr, &schema), expected);
227281
}
228282

229283
#[test]
230284
fn test_preimage_gt_rewrite() {
285+
// Greater-than comparison rewrites to the upper bound (inclusive).
231286
let schema = test_schema();
232-
let expr = binary_expr(preimage_udf_expr(), Operator::Gt, lit(500));
233-
let expected = binary_expr(col("x"), Operator::GtEq, lit(200));
287+
let expr = preimage_udf_expr().gt(lit(500));
288+
let expected = col("x").gt_eq(lit(200));
234289

235290
assert_eq!(optimize_test(expr, &schema), expected);
236291
}
237292

238293
#[test]
239294
fn test_preimage_gteq_rewrite() {
295+
// Greater-than-or-equal comparison rewrites to the lower bound.
240296
let schema = test_schema();
241-
let expr = binary_expr(preimage_udf_expr(), Operator::GtEq, lit(500));
242-
let expected = binary_expr(col("x"), Operator::GtEq, lit(100));
297+
let expr = preimage_udf_expr().gt_eq(lit(500));
298+
let expected = col("x").gt_eq(lit(100));
243299

244300
assert_eq!(optimize_test(expr, &schema), expected);
245301
}
246302

247303
#[test]
248304
fn test_preimage_is_not_distinct_from_rewrite() {
305+
// IS NOT DISTINCT FROM is treated like equality for non-null literal RHS.
249306
let schema = test_schema();
250-
let expr =
251-
binary_expr(preimage_udf_expr(), Operator::IsNotDistinctFrom, lit(500));
252-
let expected = and(
253-
binary_expr(col("x"), Operator::GtEq, lit(100)),
254-
binary_expr(col("x"), Operator::Lt, lit(200)),
255-
);
307+
let expr = is_not_distinct_from(preimage_udf_expr(), lit(500));
308+
let expected = and(col("x").gt_eq(lit(100)), col("x").lt(lit(200)));
256309

257310
assert_eq!(optimize_test(expr, &schema), expected);
258311
}
259312

260313
#[test]
261314
fn test_preimage_is_distinct_from_rewrite() {
315+
// IS DISTINCT FROM adds an explicit NULL branch for the column.
262316
let schema = test_schema();
263-
let expr = binary_expr(preimage_udf_expr(), Operator::IsDistinctFrom, lit(500));
264-
let expected = binary_expr(col("x"), Operator::Lt, lit(100))
265-
.or(binary_expr(col("x"), Operator::GtEq, lit(200)))
317+
let expr = is_distinct_from(preimage_udf_expr(), lit(500));
318+
let expected = col("x")
319+
.lt(lit(100))
320+
.or(col("x").gt_eq(lit(200)))
266321
.or(col("x").is_null());
267322

268323
assert_eq!(optimize_test(expr, &schema), expected);
269324
}
325+
326+
#[test]
327+
fn test_preimage_non_literal_rhs_no_rewrite() {
328+
// Non-literal RHS should not be rewritten.
329+
let schema = test_schema_xy();
330+
let expr = preimage_udf_expr().eq(col("y"));
331+
let expected = expr.clone();
332+
333+
assert_eq!(optimize_test(expr, &schema), expected);
334+
}
335+
336+
#[test]
337+
fn test_preimage_null_literal_no_rewrite() {
338+
// NULL literal RHS should not be rewritten.
339+
let schema = test_schema();
340+
let expr = preimage_udf_expr().eq(lit(ScalarValue::Int32(None)));
341+
let expected = expr.clone();
342+
343+
assert_eq!(optimize_test(expr, &schema), expected);
344+
}
345+
346+
#[test]
347+
fn test_preimage_non_immutable_no_rewrite() {
348+
// Non-immutable UDFs should not participate in preimage rewrites.
349+
let schema = test_schema();
350+
let expr = non_immutable_udf_expr().eq(lit(500));
351+
let expected = expr.clone();
352+
353+
assert_eq!(optimize_test(expr, &schema), expected);
354+
}
355+
356+
#[test]
357+
fn test_preimage_no_preimage_no_rewrite() {
358+
// If the UDF provides no preimage, the expression should remain unchanged.
359+
let schema = test_schema();
360+
let expr = no_preimage_udf_expr().eq(lit(500));
361+
let expected = expr.clone();
362+
363+
assert_eq!(optimize_test(expr, &schema), expected);
364+
}
270365
}

0 commit comments

Comments
 (0)