Skip to content

Commit bf6f631

Browse files
authored
Add missing Substrait to DataFusion function name mappings (#16950)
* Add missing function mappings * Added roundtrip test * Quick fix * Tests for all mappings, refactored logic * Quick fix to log test * Removed logb and sign from name mappings changes
1 parent 5f26e70 commit bf6f631

File tree

3 files changed

+88
-1
lines changed

3 files changed

+88
-1
lines changed

datafusion/substrait/src/extensions.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ impl Extensions {
4545
// Rename those to match the Substrait extensions for interoperability
4646
let function_name = match function_name.as_str() {
4747
"substr" => "substring".to_string(),
48+
"isnan" => "is_nan".to_string(),
4849
_ => function_name,
4950
};
5051

datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,21 @@ pub async fn from_scalar_function(
4141
f.function_reference
4242
);
4343
};
44+
4445
let fn_name = substrait_fun_name(fn_signature);
4546
let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?;
4647

48+
let udf_func = consumer.get_function_registry().udf(fn_name).or_else(|e| {
49+
if let Some(alt_name) = substrait_to_df_name(fn_name) {
50+
consumer.get_function_registry().udf(alt_name).or(Err(e))
51+
} else {
52+
Err(e)
53+
}
54+
});
55+
4756
// try to first match the requested function into registered udfs, then built-in ops
4857
// and finally built-in expressions
49-
if let Ok(func) = consumer.get_function_registry().udf(fn_name) {
58+
if let Ok(func) = udf_func {
5059
Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf(
5160
func.to_owned(),
5261
args,
@@ -113,6 +122,13 @@ pub fn name_to_op(name: &str) -> Option<Operator> {
113122
}
114123
}
115124

125+
pub fn substrait_to_df_name(name: &str) -> Option<&str> {
126+
match name {
127+
"is_nan" => Some("isnan"),
128+
_ => None,
129+
}
130+
}
131+
116132
/// Build a balanced tree of binary operations from a binary operator and a list of arguments.
117133
///
118134
/// For example, `OR` `(a, b, c, d, e)` will be converted to: `OR(OR(a, OR(b, c)), OR(d, e))`.
@@ -369,4 +385,39 @@ mod tests {
369385
assert_snapshot!(expr.to_string(), @"Int64(1) OR Int64(2) OR Int64(3) OR Int64(4)");
370386
Ok(())
371387
}
388+
389+
//Test that DataFusion can consume scalar functions that have a different name in Substrait
390+
#[tokio::test]
391+
async fn test_substrait_to_df_name_mapping() -> Result<()> {
392+
// Build substrait extensions (we are using only one function)
393+
let mut extensions = Extensions::default();
394+
//is_nan is one of the functions that has a different name in Substrait (mapping is in substrait_to_df_name())
395+
extensions.functions.insert(0, String::from("is_nan:fp32"));
396+
// Build substrait consumer
397+
let consumer = DefaultSubstraitConsumer::new(&extensions, &TEST_SESSION_STATE);
398+
399+
// Build arguments for the function call
400+
let arg = FunctionArgument {
401+
arg_type: Some(ArgType::Value(Expression {
402+
rex_type: Some(RexType::Literal(Literal {
403+
nullable: false,
404+
type_variation_reference: 0,
405+
literal_type: Some(LiteralType::Fp32(1.0)),
406+
})),
407+
})),
408+
};
409+
let arguments = vec![arg];
410+
let func = ScalarFunction {
411+
function_reference: 0,
412+
arguments,
413+
..Default::default()
414+
};
415+
// Trivial input schema
416+
let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);
417+
let df_schema = DFSchema::try_from(schema).unwrap();
418+
419+
// Consume the expression and ensure we don't get an error
420+
let _ = consumer.consume_scalar_function(&func, &df_schema).await?;
421+
Ok(())
422+
}
372423
}

datafusion/substrait/tests/cases/roundtrip_logical_plan.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,41 @@ async fn simple_scalar_function_substr() -> Result<()> {
426426
roundtrip("SELECT SUBSTR(f, 1, 3) FROM data").await
427427
}
428428

429+
// Test that DataFusion functions gets correctly mapped to Substrait names (when the names are diferent)
430+
// Follows the same structure as existing roundtrip tests, but more explicitly tests for name mappings
431+
async fn test_substrait_to_df_name_mapping(
432+
substrait_name: &str,
433+
sql: &str,
434+
) -> Result<()> {
435+
let ctx = create_context().await?;
436+
let df = ctx.sql(sql).await?;
437+
let plan = df.into_optimized_plan()?;
438+
let proto = to_substrait_plan(&plan, &ctx.state())?;
439+
440+
let function_name = match proto.extensions[0].mapping_type.as_ref().unwrap() {
441+
MappingType::ExtensionFunction(ext_f) => &ext_f.name,
442+
_ => unreachable!("Expected function extension"),
443+
};
444+
445+
assert_eq!(function_name, substrait_name);
446+
447+
let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
448+
let plan2 = ctx.state().optimize(&plan2)?;
449+
450+
let plan1str = format!("{plan}");
451+
let plan2str = format!("{plan2}");
452+
assert_eq!(plan1str, plan2str);
453+
454+
assert_eq!(plan.schema(), plan2.schema());
455+
456+
Ok(())
457+
}
458+
459+
#[tokio::test]
460+
async fn scalar_function_is_nan_mapping() -> Result<()> {
461+
test_substrait_to_df_name_mapping("is_nan", "SELECT ISNAN(a) FROM data").await
462+
}
463+
429464
#[tokio::test]
430465
async fn simple_scalar_function_is_null() -> Result<()> {
431466
roundtrip("SELECT * FROM data WHERE a IS NULL").await

0 commit comments

Comments
 (0)