Skip to content

Commit 8147565

Browse files
authored
Support Substrait functions and_not, xor, and between in consumer built-in expression builder (#16984)
* Added support for functions with consumer's build in expression builder * Format fix * Support for logb * Use built in functions for binary expressions * Added a comment * Quick fix * Simplified producer logic * quick fix
1 parent 2bbd6a1 commit 8147565

10 files changed

+837
-43
lines changed

datafusion/substrait/src/extensions.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ impl Extensions {
4545
// Rename those to match the Substrait extensions for interoperability
4646
let function_name = match function_name.as_str() {
4747
"substr" => "substring".to_string(),
48+
"log" => "logb".to_string(),
4849
"isnan" => "is_nan".to_string(),
4950
_ => function_name,
5051
};

datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs

Lines changed: 99 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,9 @@ use datafusion::common::{
2121
not_impl_err, plan_err, substrait_err, DFSchema, DataFusionError, ScalarValue,
2222
};
2323
use datafusion::execution::FunctionRegistry;
24-
use datafusion::logical_expr::{expr, BinaryExpr, Expr, Like, Operator};
24+
use datafusion::logical_expr::{expr, Between, BinaryExpr, Expr, Like, Operator};
2525
use std::vec::Drain;
2626
use substrait::proto::expression::ScalarFunction;
27-
use substrait::proto::function_argument::ArgType;
2827

2928
pub async fn from_scalar_function(
3029
consumer: &impl SubstraitConsumer,
@@ -70,7 +69,7 @@ pub async fn from_scalar_function(
7069
// In those cases we build a balanced tree of BinaryExprs
7170
arg_list_to_binary_op_tree(op, args)
7271
} else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) {
73-
builder.build(consumer, f, input_schema).await
72+
builder.build(consumer, f, args).await
7473
} else {
7574
not_impl_err!("Unsupported function name: {fn_name:?}")
7675
}
@@ -180,7 +179,8 @@ impl BuiltinExprBuilder {
180179
match name {
181180
"not" | "like" | "ilike" | "is_null" | "is_not_null" | "is_true"
182181
| "is_false" | "is_not_true" | "is_not_false" | "is_unknown"
183-
| "is_not_unknown" | "negative" | "negate" => Some(Self {
182+
| "is_not_unknown" | "negative" | "negate" | "and_not" | "xor"
183+
| "between" | "logb" => Some(Self {
184184
expr_name: name.to_string(),
185185
}),
186186
_ => None,
@@ -191,37 +191,30 @@ impl BuiltinExprBuilder {
191191
self,
192192
consumer: &impl SubstraitConsumer,
193193
f: &ScalarFunction,
194-
input_schema: &DFSchema,
194+
args: Vec<Expr>,
195195
) -> Result<Expr> {
196196
match self.expr_name.as_str() {
197-
"like" => Self::build_like_expr(consumer, false, f, input_schema).await,
198-
"ilike" => Self::build_like_expr(consumer, true, f, input_schema).await,
197+
"like" => Self::build_like_expr(false, f, args).await,
198+
"ilike" => Self::build_like_expr(true, f, args).await,
199199
"not" | "negative" | "negate" | "is_null" | "is_not_null" | "is_true"
200200
| "is_false" | "is_not_true" | "is_not_false" | "is_unknown"
201-
| "is_not_unknown" => {
202-
Self::build_unary_expr(consumer, &self.expr_name, f, input_schema).await
201+
| "is_not_unknown" => Self::build_unary_expr(&self.expr_name, args).await,
202+
"and_not" | "xor" => Self::build_binary_expr(&self.expr_name, args).await,
203+
"between" => Self::build_between_expr(&self.expr_name, args).await,
204+
"logb" => {
205+
Self::build_custom_handling_expr(consumer, &self.expr_name, args).await
203206
}
204207
_ => {
205208
not_impl_err!("Unsupported builtin expression: {}", self.expr_name)
206209
}
207210
}
208211
}
209212

210-
async fn build_unary_expr(
211-
consumer: &impl SubstraitConsumer,
212-
fn_name: &str,
213-
f: &ScalarFunction,
214-
input_schema: &DFSchema,
215-
) -> Result<Expr> {
216-
if f.arguments.len() != 1 {
217-
return substrait_err!("Expect one argument for {fn_name} expr");
218-
}
219-
let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else {
220-
return substrait_err!("Invalid arguments type for {fn_name} expr");
213+
async fn build_unary_expr(fn_name: &str, args: Vec<Expr>) -> Result<Expr> {
214+
let [arg] = match args.try_into() {
215+
Ok(args_arr) => args_arr,
216+
Err(_) => return substrait_err!("Expected one argument for {fn_name} expr"),
221217
};
222-
let arg = consumer
223-
.consume_expression(expr_substrait, input_schema)
224-
.await?;
225218
let arg = Box::new(arg);
226219

227220
let expr = match fn_name {
@@ -242,40 +235,29 @@ impl BuiltinExprBuilder {
242235
}
243236

244237
async fn build_like_expr(
245-
consumer: &impl SubstraitConsumer,
246238
case_insensitive: bool,
247239
f: &ScalarFunction,
248-
input_schema: &DFSchema,
240+
args: Vec<Expr>,
249241
) -> Result<Expr> {
250242
let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" };
251-
if f.arguments.len() != 2 && f.arguments.len() != 3 {
243+
if args.len() != 2 && args.len() != 3 {
252244
return substrait_err!("Expect two or three arguments for `{fn_name}` expr");
253245
}
254246

255-
let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else {
256-
return substrait_err!("Invalid arguments type for `{fn_name}` expr");
247+
let mut args_iter = args.into_iter();
248+
let Some(expr) = args_iter.next() else {
249+
return substrait_err!("Missing first argument for {fn_name} expression");
257250
};
258-
let expr = consumer
259-
.consume_expression(expr_substrait, input_schema)
260-
.await?;
261-
let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else {
262-
return substrait_err!("Invalid arguments type for `{fn_name}` expr");
251+
let Some(pattern) = args_iter.next() else {
252+
return substrait_err!("Missing second argument for {fn_name} expression");
263253
};
264-
let pattern = consumer
265-
.consume_expression(pattern_substrait, input_schema)
266-
.await?;
267254

268255
// Default case: escape character is Literal(Utf8(None))
269256
let escape_char = if f.arguments.len() == 3 {
270-
let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type
271-
else {
272-
return substrait_err!("Invalid arguments type for `{fn_name}` expr");
257+
let Some(escape_char_expr) = args_iter.next() else {
258+
return substrait_err!("Missing third argument for {fn_name} expression");
273259
};
274260

275-
let escape_char_expr = consumer
276-
.consume_expression(escape_char_substrait, input_schema)
277-
.await?;
278-
279261
match escape_char_expr {
280262
Expr::Literal(ScalarValue::Utf8(escape_char_string), _) => {
281263
// Convert Option<String> to Option<char>
@@ -299,6 +281,80 @@ impl BuiltinExprBuilder {
299281
case_insensitive,
300282
}))
301283
}
284+
285+
async fn build_binary_expr(fn_name: &str, args: Vec<Expr>) -> Result<Expr> {
286+
let [a, b] = match args.try_into() {
287+
Ok(args_arr) => args_arr,
288+
Err(_) => {
289+
return substrait_err!("Expected two arguments for `{fn_name}` expr")
290+
}
291+
};
292+
match fn_name {
293+
"and_not" => Ok(Self::build_and_not_expr(a, b)),
294+
"xor" => Ok(Self::build_xor_expr(a, b)),
295+
_ => not_impl_err!("Unsupported builtin expression: {}", fn_name),
296+
}
297+
}
298+
299+
fn build_and_not_expr(a: Expr, b: Expr) -> Expr {
300+
a.and(Expr::Not(Box::new(b)))
301+
}
302+
303+
fn build_xor_expr(a: Expr, b: Expr) -> Expr {
304+
let or_expr = a.clone().or(b.clone());
305+
let and_expr = a.and(b);
306+
Self::build_and_not_expr(or_expr, and_expr)
307+
}
308+
309+
async fn build_between_expr(fn_name: &str, args: Vec<Expr>) -> Result<Expr> {
310+
let [expression, low, high] = match args.try_into() {
311+
Ok(args_arr) => args_arr,
312+
Err(_) => {
313+
return substrait_err!("Expected three arguments for `{fn_name}` expr")
314+
}
315+
};
316+
317+
Ok(Expr::Between(Between {
318+
expr: Box::new(expression),
319+
negated: false,
320+
low: Box::new(low),
321+
high: Box::new(high),
322+
}))
323+
}
324+
325+
//This handles any functions that require custom handling
326+
async fn build_custom_handling_expr(
327+
consumer: &impl SubstraitConsumer,
328+
fn_name: &str,
329+
args: Vec<Expr>,
330+
) -> Result<Expr> {
331+
match fn_name {
332+
"logb" => Self::build_logb_expr(consumer, args).await,
333+
_ => not_impl_err!("Unsupported custom handled expression: {}", fn_name),
334+
}
335+
}
336+
337+
async fn build_logb_expr(
338+
consumer: &impl SubstraitConsumer,
339+
args: Vec<Expr>,
340+
) -> Result<Expr> {
341+
if args.len() != 2 {
342+
return substrait_err!("Expect two arguments for logb function");
343+
}
344+
345+
let mut args = args;
346+
args.swap(0, 1);
347+
348+
//The equivalent of logb in DataFusion is the log function (which has its arguments in reverse order)
349+
if let Ok(func) = consumer.get_function_registry().udf("log") {
350+
Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf(
351+
func.to_owned(),
352+
args,
353+
)))
354+
} else {
355+
not_impl_err!("Unsupported function name: logb")
356+
}
357+
}
302358
}
303359

304360
#[cfg(test)]

datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ pub fn from_scalar_function(
3434
});
3535
}
3636

37+
let arguments = custom_argument_handler(fun.name(), arguments);
38+
3739
let function_anchor = producer.register_function(fun.name().to_string());
3840
#[allow(deprecated)]
3941
Ok(Expression {
@@ -47,6 +49,25 @@ pub fn from_scalar_function(
4749
})
4850
}
4951

52+
// Handle functions that require custom handling for their arguments (e.g. log)
53+
pub fn custom_argument_handler(
54+
name: &str,
55+
args: Vec<FunctionArgument>,
56+
) -> Vec<FunctionArgument> {
57+
match name {
58+
"log" => {
59+
if args.len() == 2 {
60+
let mut args = args;
61+
args.swap(0, 1);
62+
args
63+
} else {
64+
args
65+
}
66+
}
67+
_ => args,
68+
}
69+
}
70+
5071
pub fn from_unary_expr(
5172
producer: &mut impl SubstraitProducer,
5273
expr: &Expr,
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! There are some Substrait functions that are semantically equivalent to nested built-in expressions, such as xor:bool_bool and and_not:bool_bool
19+
//! This module tests that the semantics of these functions are correct roundtripped
20+
21+
#[cfg(test)]
22+
mod tests {
23+
use crate::utils::test::add_plan_schemas_to_ctx;
24+
use datafusion::arrow::util::pretty;
25+
use datafusion::common::Result;
26+
use datafusion::prelude::DataFrame;
27+
use datafusion::prelude::SessionContext;
28+
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
29+
use datafusion_substrait::logical_plan::producer::to_substrait_plan;
30+
use std::fs::File;
31+
use std::io::BufReader;
32+
use substrait::proto::Plan;
33+
34+
// Helper function to test scalar function semantics and roundtrip conversion
35+
async fn test_scalar_fn_semantics(
36+
file_path: &str,
37+
expected_results: Vec<&str>,
38+
) -> Result<()> {
39+
let path = format!("tests/testdata/test_plans/{file_path}");
40+
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
41+
File::open(path).expect("file not found"),
42+
))
43+
.expect("failed to parse json");
44+
45+
let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?;
46+
let plan = from_substrait_plan(&ctx.state(), &proto).await?;
47+
48+
// Test correct semantics of function
49+
let df = DataFrame::new(ctx.state().clone(), plan.clone());
50+
let results = df.collect().await?;
51+
let pretty_results = pretty::pretty_format_batches(&results)?.to_string();
52+
assert_eq!(
53+
pretty_results.trim().lines().collect::<Vec<_>>(),
54+
expected_results
55+
);
56+
57+
// Test roundtrip semantics
58+
let proto = to_substrait_plan(&plan, &ctx.state())?;
59+
let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
60+
let df2 = DataFrame::new(ctx.state().clone(), plan2.clone());
61+
let results2 = df2.collect().await?;
62+
let pretty_results2 = pretty::pretty_format_batches(&results2)?.to_string();
63+
assert_eq!(
64+
pretty_results2.trim().lines().collect::<Vec<_>>(),
65+
expected_results
66+
);
67+
68+
Ok(())
69+
}
70+
71+
#[tokio::test]
72+
async fn test_xor_semantics() -> Result<()> {
73+
let expected = vec![
74+
"+-------+-------+--------+",
75+
"| a | b | result |",
76+
"+-------+-------+--------+",
77+
"| true | true | false |",
78+
"| true | false | true |",
79+
"| false | true | true |",
80+
"| false | false | false |",
81+
"+-------+-------+--------+",
82+
];
83+
84+
test_scalar_fn_semantics(
85+
"scalar_fn_to_built_in_binary_expr_xor.substrait.json",
86+
expected,
87+
)
88+
.await
89+
}
90+
91+
#[tokio::test]
92+
async fn test_and_not_semantics() -> Result<()> {
93+
let expected = vec![
94+
"+-------+-------+--------+",
95+
"| a | b | result |",
96+
"+-------+-------+--------+",
97+
"| true | true | false |",
98+
"| true | false | true |",
99+
"| false | true | false |",
100+
"| false | false | false |",
101+
"+-------+-------+--------+",
102+
];
103+
104+
test_scalar_fn_semantics(
105+
"scalar_fn_to_built_in_binary_expr_and_not.substrait.json",
106+
expected,
107+
)
108+
.await
109+
}
110+
111+
#[tokio::test]
112+
async fn test_logb_semantics() -> Result<()> {
113+
let expected = vec![
114+
"+-------+------+--------+",
115+
"| x | base | result |",
116+
"+-------+------+--------+",
117+
"| 1.0 | 10.0 | 0.0 |",
118+
"| 100.0 | 10.0 | 2.0 |",
119+
"+-------+------+--------+",
120+
];
121+
122+
test_scalar_fn_semantics("scalar_fn_logb_expr.substrait.json", expected).await
123+
}
124+
}

0 commit comments

Comments
 (0)