Skip to content

Commit c6bdb87

Browse files
committed
fix: validate arg_types when requesting return_type
1 parent c9b344b commit c6bdb87

File tree

3 files changed

+47
-7
lines changed

3 files changed

+47
-7
lines changed

host/src/lib.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,35 @@ impl WasmScalarUdf {
438438
pub fn as_async_udf(self) -> AsyncScalarUDF {
439439
AsyncScalarUDF::new(Arc::new(self))
440440
}
441+
442+
/// Check that the provided argument types match the UDF signature.
443+
fn check_arg_types(&self, arg_types: &[DataType]) -> DataFusionResult<()> {
444+
if let TypeSignature::Exact(expected_types) = &self.signature.type_signature {
445+
if arg_types.len() != expected_types.len() {
446+
return Err(DataFusionError::Plan(format!(
447+
"`{}` expects {} parameters but got {}",
448+
self.name,
449+
expected_types.len(),
450+
arg_types.len()
451+
)));
452+
}
453+
454+
for (i, (provided, expected)) in arg_types.iter().zip(expected_types.iter()).enumerate()
455+
{
456+
if provided != expected {
457+
return Err(DataFusionError::Plan(format!(
458+
"argument {} of `{}` should be {:?}, got {:?}",
459+
i + 1,
460+
self.name,
461+
expected,
462+
provided
463+
)));
464+
}
465+
}
466+
};
467+
468+
Ok(())
469+
}
441470
}
442471

443472
impl std::fmt::Debug for WasmScalarUdf {
@@ -476,6 +505,8 @@ impl ScalarUDFImpl for WasmScalarUdf {
476505
}
477506

478507
fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
508+
self.check_arg_types(arg_types)?;
509+
479510
if let Some(return_type) = &self.return_type {
480511
return Ok(return_type.clone());
481512
}

host/tests/integration_tests/python/examples.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,10 @@ def add_one(x: int) -> int:
3232
udf.return_type(&[DataType::Int64]).unwrap(),
3333
DataType::Int64,
3434
);
35-
// Signature is exact, so it would have been cached;
36-
// thus - no error even with a wrong type of argument.
37-
assert_eq!(udf.return_type(&[]).unwrap(), DataType::Int64,);
35+
assert_eq!(
36+
udf.return_type(&[]).unwrap_err().to_string(),
37+
"Error during planning: `add_one` expects 1 parameters but got 0",
38+
);
3839

3940
// call with array
4041
let array = udf

host/tests/integration_tests/python/runtime/errors.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,19 @@ def foo(x: int) -> int:
2121

2222
let udf = python_scalar_udf(CODE).await.unwrap();
2323

24-
// Signature is exact, so it would have been cached;
25-
// thus - no error even with a wrong type of argument.
2624
insta::assert_snapshot!(
27-
udf.return_type(&[]).unwrap(),
28-
@"Int64",
25+
udf.return_type(&[]).unwrap_err(),
26+
@"Error during planning: `foo` expects 1 parameters but got 0",
27+
);
28+
29+
insta::assert_snapshot!(udf.return_type(
30+
&[DataType::Int64, DataType::Int64]).unwrap_err(),
31+
@"Error during planning: `foo` expects 1 parameters but got 2",
32+
);
33+
34+
insta::assert_snapshot!(
35+
udf.return_type(&[DataType::Float64]).unwrap_err(),
36+
@"Error during planning: argument 1 of `foo` should be Int64, got Float64",
2937
);
3038
}
3139

0 commit comments

Comments
 (0)