Skip to content

Commit bf04a2f

Browse files
authored
Merge pull request #161 from influxdata/tm/return-type-async
feat: cache return type if function signature is exact
2 parents c586f10 + 017b48e commit bf04a2f

File tree

24 files changed

+187
-109
lines changed

24 files changed

+187
-109
lines changed

host/src/lib.rs

Lines changed: 77 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use ::http::HeaderName;
88
use arrow::datatypes::DataType;
99
use datafusion_common::{DataFusionError, Result as DataFusionResult, config::ConfigOptions};
1010
use datafusion_expr::{
11-
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature,
11+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
1212
async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl},
1313
};
1414
use tokio::{runtime::Handle, sync::Mutex};
@@ -313,21 +313,31 @@ pub struct WasmScalarUdf {
313313

314314
/// Name of the UDF.
315315
///
316-
/// This was pre-fetched during UDF generation because [`ScalarUDFImpl::name`] is sync and requires us to return a
317-
/// reference.
316+
/// This was pre-fetched during UDF generation because
317+
/// [`ScalarUDFImpl::name`] is sync and requires us to return a reference.
318318
name: String,
319319

320320
/// Signature of the UDF.
321321
///
322-
/// This was pre-fetched during UDF generation because [`ScalarUDFImpl::signature`] is sync and requires us to return a
322+
/// This was pre-fetched during UDF generation because
323+
/// [`ScalarUDFImpl::signature`] is sync and requires us to return a
323324
/// reference.
324325
signature: Signature,
326+
327+
/// Return type of the UDF.
328+
///
329+
/// This was pre-fetched during UDF generation because
330+
/// [`ScalarUDFImpl::return_type`] is sync; a cached value can be cloned
331+
/// without entering an async context. It is only pre-computed (`Some`) if the
332+
/// underlying [`TypeSignature`] is [`Exact`](TypeSignature::Exact), and is `None` otherwise.
333+
return_type: Option<DataType>,
325334
}
326335

327336
impl WasmScalarUdf {
328337
/// Create multiple UDFs from a single WASM VM.
329338
///
330-
/// UDFs bound to the same VM share state, however calling this method multiple times will yield independent WASM VMs.
339+
/// UDFs bound to the same VM share state, however calling this method
340+
/// multiple times will yield independent WASM VMs.
331341
pub async fn new(
332342
component: &WasmComponentPrecompiled,
333343
permissions: &WasmPermissions,
@@ -404,7 +414,7 @@ impl WasmScalarUdf {
404414
)?;
405415

406416
let store2: &mut Store<WasmStateImpl> = &mut store_guard;
407-
let signature = bindings
417+
let signature: Signature = bindings
408418
.datafusion_udf_wasm_udf_types()
409419
.scalar_udf()
410420
.call_signature(store2, resource)
@@ -415,12 +425,36 @@ impl WasmScalarUdf {
415425
)?
416426
.try_into()?;
417427

428+
let return_type = match &signature.type_signature {
429+
TypeSignature::Exact(t) => {
430+
let store2: &mut Store<WasmStateImpl> = &mut store_guard;
431+
let r = bindings
432+
.datafusion_udf_wasm_udf_types()
433+
.scalar_udf()
434+
.call_return_type(
435+
store2,
436+
resource,
437+
&t.iter()
438+
.map(|dt| wit_types::DataType::from(dt.clone()))
439+
.collect::<Vec<_>>(),
440+
)
441+
.await
442+
.context(
443+
"call ScalarUdf::return_type",
444+
Some(&store_guard.data().stderr.contents()),
445+
)??;
446+
Some(r.try_into()?)
447+
}
448+
_ => None,
449+
};
450+
418451
udfs.push(Self {
419452
store: Arc::clone(&store),
420453
bindings: Arc::clone(&bindings),
421454
resource,
422455
name,
423456
signature,
457+
return_type,
424458
});
425459
}
426460

@@ -431,6 +465,35 @@ impl WasmScalarUdf {
431465
pub fn as_async_udf(self) -> AsyncScalarUDF {
432466
AsyncScalarUDF::new(Arc::new(self))
433467
}
468+
469+
/// Check that the provided argument types match the UDF signature.
///
/// Only a [`TypeSignature::Exact`] signature is validated: the number of
/// arguments and every argument type must equal the expected ones, otherwise
/// a [`DataFusionError::Plan`] naming this UDF is returned. Any other kind of
/// signature is accepted without inspection and yields `Ok(())`.
470+
fn check_arg_types(&self, arg_types: &[DataType]) -> DataFusionResult<()> {
471+
if let TypeSignature::Exact(expected_types) = &self.signature.type_signature {
472+
// Compare arity first: `zip` below stops at the shorter slice and would
// otherwise silently ignore missing or surplus arguments.
if arg_types.len() != expected_types.len() {
473+
return Err(DataFusionError::Plan(format!(
474+
"`{}` expects {} parameters but got {}",
475+
self.name,
476+
expected_types.len(),
477+
arg_types.len()
478+
)));
479+
}
480+
481+
for (i, (provided, expected)) in arg_types.iter().zip(expected_types.iter()).enumerate()
482+
{
483+
if provided != expected {
484+
return Err(DataFusionError::Plan(format!(
485+
"argument {} of `{}` should be {:?}, got {:?}",
486+
// Argument positions in the message are 1-based.
i + 1,
487+
self.name,
488+
expected,
489+
provided
490+
)));
491+
}
492+
}
493+
}
494+
495+
Ok(())
496+
}
434497
}
435498

436499
impl std::fmt::Debug for WasmScalarUdf {
@@ -441,6 +504,7 @@ impl std::fmt::Debug for WasmScalarUdf {
441504
resource,
442505
name,
443506
signature,
507+
return_type,
444508
} = self;
445509

446510
f.debug_struct("WasmScalarUdf")
@@ -449,6 +513,7 @@ impl std::fmt::Debug for WasmScalarUdf {
449513
.field("resource", resource)
450514
.field("name", name)
451515
.field("signature", signature)
516+
.field("return_type", return_type)
452517
.finish()
453518
}
454519
}
@@ -467,6 +532,12 @@ impl ScalarUDFImpl for WasmScalarUdf {
467532
}
468533

469534
fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
535+
self.check_arg_types(arg_types)?;
536+
537+
if let Some(return_type) = &self.return_type {
538+
return Ok(return_type.clone());
539+
}
540+
470541
async_in_sync_context(async {
471542
let arg_types = arg_types
472543
.iter()

host/tests/integration_tests/python/argument_forms.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use datafusion_expr::{
1515

1616
use crate::integration_tests::python::test_utils::python_scalar_udf;
1717

18-
#[tokio::test(flavor = "multi_thread")]
18+
#[tokio::test]
1919
async fn test_positional_or_keyword() {
2020
const CODE: &str = "
2121
def foo(x: int) -> int:
@@ -55,7 +55,7 @@ def foo(x: int) -> int:
5555
);
5656
}
5757

58-
#[tokio::test(flavor = "multi_thread")]
58+
#[tokio::test]
5959
async fn test_positional_or_keyword_default() {
6060
const CODE: &str = "
6161
def foo(x: int = 1) -> int:
@@ -77,7 +77,7 @@ def foo(x: int = 1) -> int:
7777
);
7878
}
7979

80-
#[tokio::test(flavor = "multi_thread")]
80+
#[tokio::test]
8181
async fn test_positional_only() {
8282
const CODE: &str = "
8383
def foo(x: int, /) -> int:
@@ -117,7 +117,7 @@ def foo(x: int, /) -> int:
117117
);
118118
}
119119

120-
#[tokio::test(flavor = "multi_thread")]
120+
#[tokio::test]
121121
async fn test_positional_only_default() {
122122
const CODE: &str = "
123123
def foo(x: int = 1, /) -> int:
@@ -139,7 +139,7 @@ def foo(x: int = 1, /) -> int:
139139
);
140140
}
141141

142-
#[tokio::test(flavor = "multi_thread")]
142+
#[tokio::test]
143143
async fn test_positional_or_keyword_and_positional_only() {
144144
const CODE: &str = "
145145
def foo(x: int, /, y: int) -> int:
@@ -182,7 +182,7 @@ def foo(x: int, /, y: int) -> int:
182182
);
183183
}
184184

185-
#[tokio::test(flavor = "multi_thread")]
185+
#[tokio::test]
186186
async fn test_var_positional() {
187187
const CODE: &str = "
188188
def foo(*x: int) -> int:
@@ -204,7 +204,7 @@ def foo(*x: int) -> int:
204204
);
205205
}
206206

207-
#[tokio::test(flavor = "multi_thread")]
207+
#[tokio::test]
208208
async fn test_keyword_only() {
209209
const CODE: &str = "
210210
def foo(*, x: int) -> int:
@@ -226,7 +226,7 @@ def foo(*, x: int) -> int:
226226
);
227227
}
228228

229-
#[tokio::test(flavor = "multi_thread")]
229+
#[tokio::test]
230230
async fn test_var_keyword() {
231231
const CODE: &str = "
232232
def foo(**x: int) -> int:

host/tests/integration_tests/python/examples.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use datafusion_expr::{
1313

1414
use crate::integration_tests::python::test_utils::python_scalar_udf;
1515

16-
#[tokio::test(flavor = "multi_thread")]
16+
#[tokio::test]
1717
async fn test_add_one() {
1818
const CODE: &str = "
1919
def add_one(x: int) -> int:

host/tests/integration_tests/python/inspection/errors.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::integration_tests::python::test_utils::python_scalar_udfs;
22
use datafusion_common::DataFusionError;
33

4-
#[tokio::test(flavor = "multi_thread")]
4+
#[tokio::test]
55
async fn test_invalid_syntax() {
66
const CODE: &str = ")";
77

@@ -18,7 +18,7 @@ async fn test_invalid_syntax() {
1818
);
1919
}
2020

21-
#[tokio::test(flavor = "multi_thread")]
21+
#[tokio::test]
2222
async fn test_missing_return_type() {
2323
const CODE: &str = "
2424
def add_one(x: int):
@@ -43,7 +43,7 @@ def add_one(x: int):
4343
);
4444
}
4545

46-
#[tokio::test(flavor = "multi_thread")]
46+
#[tokio::test]
4747
async fn test_missing_arg_type() {
4848
const CODE: &str = "
4949
def add_one(x) -> int:
@@ -68,7 +68,7 @@ def add_one(x) -> int:
6868
);
6969
}
7070

71-
#[tokio::test(flavor = "multi_thread")]
71+
#[tokio::test]
7272
async fn test_union_type_2() {
7373
const CODE: &str = "
7474
def add_one(x: int | str) -> int:
@@ -93,7 +93,7 @@ def add_one(x: int | str) -> int:
9393
);
9494
}
9595

96-
#[tokio::test(flavor = "multi_thread")]
96+
#[tokio::test]
9797
async fn test_union_type_2_and_none() {
9898
const CODE: &str = "
9999
def add_one(x: int | str | None) -> int:
@@ -118,7 +118,7 @@ def add_one(x: int | str | None) -> int:
118118
);
119119
}
120120

121-
#[tokio::test(flavor = "multi_thread")]
121+
#[tokio::test]
122122
async fn test_union_type_2_identical() {
123123
const CODE: &str = "
124124
def add_one(x: int | str | int) -> int:
@@ -143,7 +143,7 @@ def add_one(x: int | str | int) -> int:
143143
);
144144
}
145145

146-
#[tokio::test(flavor = "multi_thread")]
146+
#[tokio::test]
147147
async fn test_union_type_2_identical_and_none() {
148148
const CODE: &str = "
149149
def add_one(x: int | None | str | int) -> int:
@@ -168,7 +168,7 @@ def add_one(x: int | None | str | int) -> int:
168168
);
169169
}
170170

171-
#[tokio::test(flavor = "multi_thread")]
171+
#[tokio::test]
172172
async fn test_union_type_3() {
173173
const CODE: &str = "
174174
def add_one(x: int | str | float) -> int:
@@ -193,7 +193,7 @@ def add_one(x: int | str | float) -> int:
193193
);
194194
}
195195

196-
#[tokio::test(flavor = "multi_thread")]
196+
#[tokio::test]
197197
async fn test_type_annotation_is_not_a_type() {
198198
const CODE: &str = "
199199
def add_one(x: 1337) -> int:
@@ -218,7 +218,7 @@ def add_one(x: 1337) -> int:
218218
);
219219
}
220220

221-
#[tokio::test(flavor = "multi_thread")]
221+
#[tokio::test]
222222
async fn test_unsupported_type() {
223223
const CODE: &str = "
224224
def add_one(x: list[int]) -> int:
@@ -243,7 +243,7 @@ def add_one(x: list[int]) -> int:
243243
);
244244
}
245245

246-
#[tokio::test(flavor = "multi_thread")]
246+
#[tokio::test]
247247
async fn test_custom_type() {
248248
const CODE: &str = "
249249
class C:
@@ -271,7 +271,7 @@ def add_one(x: C) -> int:
271271
);
272272
}
273273

274-
#[tokio::test(flavor = "multi_thread")]
274+
#[tokio::test]
275275
async fn test_exception() {
276276
const CODE: &str = "
277277
raise Exception('foo')

host/tests/integration_tests/python/inspection/filter.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::integration_tests::python::test_utils::python_scalar_udfs;
22
use datafusion_expr::ScalarUDFImpl;
33

4-
#[tokio::test(flavor = "multi_thread")]
4+
#[tokio::test]
55
async fn test_underscore() {
66
const CODE: &str = "
77
def foo(x: int) -> int:
@@ -14,7 +14,7 @@ def _bar(x: int) -> int:
1414
assert_eq!(found_udfs(CODE).await, ["foo".to_owned()]);
1515
}
1616

17-
#[tokio::test(flavor = "multi_thread")]
17+
#[tokio::test]
1818
async fn test_non_callalbes() {
1919
const CODE: &str = "
2020
variable = 1
@@ -26,7 +26,7 @@ def foo(x: int) -> int:
2626
assert_eq!(found_udfs(CODE).await, ["foo".to_owned()]);
2727
}
2828

29-
#[tokio::test(flavor = "multi_thread")]
29+
#[tokio::test]
3030
async fn test_imports() {
3131
const CODE: &str = "
3232
from sys import exit
@@ -38,7 +38,7 @@ def foo(x: int) -> int:
3838
assert_eq!(found_udfs(CODE).await, ["foo".to_owned()]);
3939
}
4040

41-
#[tokio::test(flavor = "multi_thread")]
41+
#[tokio::test]
4242
async fn test_classes() {
4343
const CODE: &str = "
4444
class C:

host/tests/integration_tests/python/runtime/dependencies.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, async_udf::AsyncScalarU
99

1010
use crate::integration_tests::python::test_utils::python_scalar_udf;
1111

12-
#[tokio::test(flavor = "multi_thread")]
12+
#[tokio::test]
1313
async fn call_other_function() {
1414
const CODE: &str = "
1515
def _sub1(x: int) -> int:
@@ -45,7 +45,7 @@ def foo(x: int) -> int:
4545
);
4646
}
4747

48-
#[tokio::test(flavor = "multi_thread")]
48+
#[tokio::test]
4949
async fn functools_cache() {
5050
const CODE: &str = "
5151
from functools import cache

0 commit comments

Comments
 (0)