File tree Expand file tree Collapse file tree 3 files changed +47
-7
lines changed
tests/integration_tests/python Expand file tree Collapse file tree 3 files changed +47
-7
lines changed Original file line number Diff line number Diff line change @@ -438,6 +438,35 @@ impl WasmScalarUdf {
438438 pub fn as_async_udf ( self ) -> AsyncScalarUDF {
439439 AsyncScalarUDF :: new ( Arc :: new ( self ) )
440440 }
441+
442+ /// Check that the provided argument types match the UDF signature.
443+ fn check_arg_types ( & self , arg_types : & [ DataType ] ) -> DataFusionResult < ( ) > {
444+ if let TypeSignature :: Exact ( expected_types) = & self . signature . type_signature {
445+ if arg_types. len ( ) != expected_types. len ( ) {
446+ return Err ( DataFusionError :: Plan ( format ! (
447+ "`{}` expects {} parameters but got {}" ,
448+ self . name,
449+ expected_types. len( ) ,
450+ arg_types. len( )
451+ ) ) ) ;
452+ }
453+
454+ for ( i, ( provided, expected) ) in arg_types. iter ( ) . zip ( expected_types. iter ( ) ) . enumerate ( )
455+ {
456+ if provided != expected {
457+ return Err ( DataFusionError :: Plan ( format ! (
458+ "argument {} of `{}` should be {:?}, got {:?}" ,
459+ i + 1 ,
460+ self . name,
461+ expected,
462+ provided
463+ ) ) ) ;
464+ }
465+ }
466+ } ;
467+
468+ Ok ( ( ) )
469+ }
441470}
442471
443472impl std:: fmt:: Debug for WasmScalarUdf {
@@ -476,6 +505,8 @@ impl ScalarUDFImpl for WasmScalarUdf {
476505 }
477506
478507 fn return_type ( & self , arg_types : & [ DataType ] ) -> DataFusionResult < DataType > {
508+ self . check_arg_types ( arg_types) ?;
509+
479510 if let Some ( return_type) = & self . return_type {
480511 return Ok ( return_type. clone ( ) ) ;
481512 }
Original file line number Diff line number Diff line change @@ -32,9 +32,10 @@ def add_one(x: int) -> int:
3232 udf. return_type( & [ DataType :: Int64 ] ) . unwrap( ) ,
3333 DataType :: Int64 ,
3434 ) ;
35- // Signature is exact, so it would have been cached;
36- // thus - no error even with a wrong type of argument.
37- assert_eq ! ( udf. return_type( & [ ] ) . unwrap( ) , DataType :: Int64 , ) ;
35+ assert_eq ! (
36+ udf. return_type( & [ ] ) . unwrap_err( ) . to_string( ) ,
37+ "Error during planning: `add_one` expects 1 parameters but got 0" ,
38+ ) ;
3839
3940 // call with array
4041 let array = udf
Original file line number Diff line number Diff line change @@ -21,11 +21,19 @@ def foo(x: int) -> int:
2121
2222 let udf = python_scalar_udf ( CODE ) . await . unwrap ( ) ;
2323
24- // Signature is exact, so it would have been cached;
25- // thus - no error even with a wrong type of argument.
2624 insta:: assert_snapshot!(
27- udf. return_type( & [ ] ) . unwrap( ) ,
28- @"Int64" ,
25+ udf. return_type( & [ ] ) . unwrap_err( ) ,
26+ @"Error during planning: `foo` expects 1 parameters but got 0" ,
27+ ) ;
28+
29+ insta:: assert_snapshot!( udf. return_type(
30+ & [ DataType :: Int64 , DataType :: Int64 ] ) . unwrap_err( ) ,
31+ @"Error during planning: `foo` expects 1 parameters but got 2" ,
32+ ) ;
33+
34+ insta:: assert_snapshot!(
35+ udf. return_type( & [ DataType :: Float64 ] ) . unwrap_err( ) ,
36+ @"Error during planning: argument 1 of `foo` should be Int64, got Float64" ,
2937 ) ;
3038}
3139
You can’t perform that action at this time.
0 commit comments