@@ -41,12 +41,21 @@ pub async fn from_scalar_function(
4141 f. function_reference
4242 ) ;
4343 } ;
44+
4445 let fn_name = substrait_fun_name ( fn_signature) ;
4546 let args = from_substrait_func_args ( consumer, & f. arguments , input_schema) . await ?;
4647
48+ let udf_func = consumer. get_function_registry ( ) . udf ( fn_name) . or_else ( |e| {
49+ if let Some ( alt_name) = substrait_to_df_name ( fn_name) {
50+ consumer. get_function_registry ( ) . udf ( alt_name) . or ( Err ( e) )
51+ } else {
52+ Err ( e)
53+ }
54+ } ) ;
55+
4756 // try to first match the requested function into registered udfs, then built-in ops
4857 // and finally built-in expressions
49- if let Ok ( func) = consumer . get_function_registry ( ) . udf ( fn_name ) {
58+ if let Ok ( func) = udf_func {
5059 Ok ( Expr :: ScalarFunction ( expr:: ScalarFunction :: new_udf (
5160 func. to_owned ( ) ,
5261 args,
@@ -113,6 +122,13 @@ pub fn name_to_op(name: &str) -> Option<Operator> {
113122 }
114123}
115124
/// Translates a Substrait scalar function name into the alternative name used
/// for the DataFusion function lookup, when the two differ.
///
/// Returns `Some(alternative)` for names listed in the table below (currently
/// only `is_nan`, which maps to `isnan`) and `None` for every other input,
/// including names that are already in their DataFusion form.
///
/// The returned string is a static literal, so it does not borrow from `name`
/// and may outlive it.
pub fn substrait_to_df_name(name: &str) -> Option<&'static str> {
    match name {
        "is_nan" => Some("isnan"),
        _ => None,
    }
}
131+
116132/// Build a balanced tree of binary operations from a binary operator and a list of arguments.
117133///
118134/// For example, `OR` `(a, b, c, d, e)` will be converted to: `OR(OR(a, OR(b, c)), OR(d, e))`.
@@ -369,4 +385,39 @@ mod tests {
369385 assert_snapshot ! ( expr. to_string( ) , @"Int64(1) OR Int64(2) OR Int64(3) OR Int64(4)" ) ;
370386 Ok ( ( ) )
371387 }
388+
389+ //Test that DataFusion can consume scalar functions that have a different name in Substrait
390+ #[ tokio:: test]
391+ async fn test_substrait_to_df_name_mapping ( ) -> Result < ( ) > {
392+ // Build substrait extensions (we are using only one function)
393+ let mut extensions = Extensions :: default ( ) ;
394+ //is_nan is one of the functions that has a different name in Substrait (mapping is in substrait_to_df_name())
395+ extensions. functions . insert ( 0 , String :: from ( "is_nan:fp32" ) ) ;
396+ // Build substrait consumer
397+ let consumer = DefaultSubstraitConsumer :: new ( & extensions, & TEST_SESSION_STATE ) ;
398+
399+ // Build arguments for the function call
400+ let arg = FunctionArgument {
401+ arg_type : Some ( ArgType :: Value ( Expression {
402+ rex_type : Some ( RexType :: Literal ( Literal {
403+ nullable : false ,
404+ type_variation_reference : 0 ,
405+ literal_type : Some ( LiteralType :: Fp32 ( 1.0 ) ) ,
406+ } ) ) ,
407+ } ) ) ,
408+ } ;
409+ let arguments = vec ! [ arg] ;
410+ let func = ScalarFunction {
411+ function_reference : 0 ,
412+ arguments,
413+ ..Default :: default ( )
414+ } ;
415+ // Trivial input schema
416+ let schema = Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Float32 , false ) ] ) ;
417+ let df_schema = DFSchema :: try_from ( schema) . unwrap ( ) ;
418+
419+ // Consume the expression and ensure we don't get an error
420+ let _ = consumer. consume_scalar_function ( & func, & df_schema) . await ?;
421+ Ok ( ( ) )
422+ }
372423}
0 commit comments