Skip to content

Commit 53a7944

Browse files
authored
Merge pull request #160 from influxdata/tm/invoke-async
feat: use AsyncScalarUDFImpl instead of ScalarUDFImpl
2 parents ca932b3 + a7c2620 commit 53a7944

File tree

23 files changed

+1011
-750
lines changed

23 files changed

+1011
-750
lines changed

host/src/lib.rs

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,19 @@ use std::{any::Any, ops::DerefMut, sync::Arc};
66

77
use ::http::HeaderName;
88
use arrow::datatypes::DataType;
9-
use datafusion_common::{DataFusionError, Result as DataFusionResult};
10-
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature};
9+
use datafusion_common::{DataFusionError, Result as DataFusionResult, config::ConfigOptions};
10+
use datafusion_expr::{
11+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature,
12+
async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl},
13+
};
1114
use tokio::{runtime::Handle, sync::Mutex};
1215
use wasmtime::{
1316
Engine, Store,
1417
component::{Component, ResourceAny},
1518
};
16-
use wasmtime_wasi::{ResourceTable, WasiCtx, WasiCtxView, WasiView, p2::pipe::MemoryOutputPipe};
19+
use wasmtime_wasi::{
20+
ResourceTable, WasiCtx, WasiCtxView, WasiView, async_trait, p2::pipe::MemoryOutputPipe,
21+
};
1722
use wasmtime_wasi_http::{
1823
HttpResult, WasiHttpCtx, WasiHttpView,
1924
bindings::http::types::ErrorCode as HttpErrorCode,
@@ -394,6 +399,11 @@ impl WasmScalarUdf {
394399

395400
Ok(udfs)
396401
}
402+
403+
/// Convert this [WasmScalarUdf] into an [AsyncScalarUDF].
404+
pub fn as_async_udf(self) -> AsyncScalarUDF {
405+
AsyncScalarUDF::new(Arc::new(self))
406+
}
397407
}
398408

399409
impl std::fmt::Debug for WasmScalarUdf {
@@ -450,21 +460,43 @@ impl ScalarUDFImpl for WasmScalarUdf {
450460
})
451461
}
452462

453-
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
454-
async_in_sync_context(async {
455-
let args = args.try_into()?;
456-
let mut store_guard = self.store.lock().await;
457-
let return_type = self
458-
.bindings
459-
.datafusion_udf_wasm_udf_types()
460-
.scalar_udf()
461-
.call_invoke_with_args(store_guard.deref_mut(), self.resource, &args)
462-
.await
463-
.context(
464-
"call ScalarUdf::invoke_with_args",
465-
Some(&store_guard.data().stderr.contents()),
466-
)??;
467-
return_type.try_into()
468-
})
463+
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
464+
Err(DataFusionError::NotImplemented(
465+
"synchronous invocation of WasmScalarUdf is not supported, use invoke_async_with_args instead".to_string(),
466+
))
467+
}
468+
}
469+
470+
#[async_trait]
471+
impl AsyncScalarUDFImpl for WasmScalarUdf {
472+
fn ideal_batch_size(&self) -> Option<usize> {
473+
None
474+
}
475+
476+
async fn invoke_async_with_args(
477+
&self,
478+
args: ScalarFunctionArgs,
479+
_option: &ConfigOptions,
480+
) -> DataFusionResult<arrow::array::ArrayRef> {
481+
let args = args.try_into()?;
482+
let mut store_guard = self.store.lock().await;
483+
let return_type = self
484+
.bindings
485+
.datafusion_udf_wasm_udf_types()
486+
.scalar_udf()
487+
.call_invoke_with_args(store_guard.deref_mut(), self.resource, &args)
488+
.await
489+
.context(
490+
"call ScalarUdf::invoke_with_args",
491+
Some(&store_guard.data().stderr.contents()),
492+
)??;
493+
494+
drop(store_guard);
495+
496+
let columnar_value: ColumnarValue = return_type.try_into()?;
497+
match columnar_value {
498+
ColumnarValue::Array(v) => Ok(v),
499+
ColumnarValue::Scalar(v) => v.to_array_of_size(args.number_rows as usize),
500+
}
469501
}
470502
}
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
mod python;
22
mod rust;
3-
mod test_utils;

host/tests/integration_tests/python/argument_forms.rs

Lines changed: 53 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@ use arrow::{
77
array::{Array, Int64Array},
88
datatypes::{DataType, Field},
99
};
10-
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
11-
12-
use crate::integration_tests::{
13-
python::test_utils::python_scalar_udf, test_utils::ColumnarValueExt,
10+
use datafusion_common::config::ConfigOptions;
11+
use datafusion_expr::{
12+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
13+
async_udf::AsyncScalarUDFImpl,
1414
};
1515

16+
use crate::integration_tests::python::test_utils::python_scalar_udf;
17+
1618
#[tokio::test(flavor = "multi_thread")]
1719
async fn test_positional_or_keyword() {
1820
const CODE: &str = "
@@ -32,18 +34,21 @@ def foo(x: int) -> int:
3234
);
3335

3436
let array = udf
35-
.invoke_with_args(ScalarFunctionArgs {
36-
args: vec![ColumnarValue::Array(Arc::new(Int64Array::from_iter([
37-
Some(3),
38-
None,
39-
Some(-10),
40-
])))],
41-
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
42-
number_rows: 3,
43-
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
44-
})
45-
.unwrap()
46-
.unwrap_array();
37+
.invoke_async_with_args(
38+
ScalarFunctionArgs {
39+
args: vec![ColumnarValue::Array(Arc::new(Int64Array::from_iter([
40+
Some(3),
41+
None,
42+
Some(-10),
43+
])))],
44+
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
45+
number_rows: 3,
46+
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
47+
},
48+
&ConfigOptions::default(),
49+
)
50+
.await
51+
.unwrap();
4752
assert_eq!(
4853
array.as_ref(),
4954
&Int64Array::from_iter([Some(4), None, Some(-9)]) as &dyn Array,
@@ -91,18 +96,21 @@ def foo(x: int, /) -> int:
9196
);
9297

9398
let array = udf
94-
.invoke_with_args(ScalarFunctionArgs {
95-
args: vec![ColumnarValue::Array(Arc::new(Int64Array::from_iter([
96-
Some(3),
97-
None,
98-
Some(-10),
99-
])))],
100-
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
101-
number_rows: 3,
102-
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
103-
})
104-
.unwrap()
105-
.unwrap_array();
99+
.invoke_async_with_args(
100+
ScalarFunctionArgs {
101+
args: vec![ColumnarValue::Array(Arc::new(Int64Array::from_iter([
102+
Some(3),
103+
None,
104+
Some(-10),
105+
])))],
106+
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
107+
number_rows: 3,
108+
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
109+
},
110+
&ConfigOptions::default(),
111+
)
112+
.await
113+
.unwrap();
106114
assert_eq!(
107115
array.as_ref(),
108116
&Int64Array::from_iter([Some(4), None, Some(-9)]) as &dyn Array,
@@ -151,20 +159,23 @@ def foo(x: int, /, y: int) -> int:
151159
);
152160

153161
let array = udf
154-
.invoke_with_args(ScalarFunctionArgs {
155-
args: vec![
156-
ColumnarValue::Array(Arc::new(Int64Array::from_iter([Some(3)]))),
157-
ColumnarValue::Array(Arc::new(Int64Array::from_iter([Some(4)]))),
158-
],
159-
arg_fields: vec![
160-
Arc::new(Field::new("a1", DataType::Int64, true)),
161-
Arc::new(Field::new("a2", DataType::Int64, true)),
162-
],
163-
number_rows: 1,
164-
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
165-
})
166-
.unwrap()
167-
.unwrap_array();
162+
.invoke_async_with_args(
163+
ScalarFunctionArgs {
164+
args: vec![
165+
ColumnarValue::Array(Arc::new(Int64Array::from_iter([Some(3)]))),
166+
ColumnarValue::Array(Arc::new(Int64Array::from_iter([Some(4)]))),
167+
],
168+
arg_fields: vec![
169+
Arc::new(Field::new("a1", DataType::Int64, true)),
170+
Arc::new(Field::new("a2", DataType::Int64, true)),
171+
],
172+
number_rows: 1,
173+
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
174+
},
175+
&ConfigOptions::default(),
176+
)
177+
.await
178+
.unwrap();
168179
assert_eq!(
169180
array.as_ref(),
170181
&Int64Array::from_iter([Some(7)]) as &dyn Array,

host/tests/integration_tests/python/examples.rs

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@ use arrow::{
55
datatypes::{DataType, Field},
66
};
77
use datafusion_common::ScalarValue;
8-
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
9-
10-
use crate::integration_tests::{
11-
python::test_utils::python_scalar_udf, test_utils::ColumnarValueExt,
8+
use datafusion_common::config::ConfigOptions;
9+
use datafusion_expr::{
10+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
11+
async_udf::AsyncScalarUDFImpl,
1212
};
1313

14+
use crate::integration_tests::python::test_utils::python_scalar_udf;
15+
1416
#[tokio::test(flavor = "multi_thread")]
1517
async fn test_add_one() {
1618
const CODE: &str = "
@@ -37,33 +39,39 @@ def add_one(x: int) -> int:
3739

3840
// call with array
3941
let array = udf
40-
.invoke_with_args(ScalarFunctionArgs {
41-
args: vec![ColumnarValue::Array(Arc::new(Int64Array::from_iter([
42-
Some(3),
43-
None,
44-
Some(1),
45-
])))],
46-
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
47-
number_rows: 3,
48-
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
49-
})
50-
.unwrap()
51-
.unwrap_array();
42+
.invoke_async_with_args(
43+
ScalarFunctionArgs {
44+
args: vec![ColumnarValue::Array(Arc::new(Int64Array::from_iter([
45+
Some(3),
46+
None,
47+
Some(1),
48+
])))],
49+
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
50+
number_rows: 3,
51+
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
52+
},
53+
&ConfigOptions::default(),
54+
)
55+
.await
56+
.unwrap();
5257
assert_eq!(
5358
array.as_ref(),
5459
&Int64Array::from_iter([Some(4), None, Some(2)]) as &dyn Array,
5560
);
5661

5762
// call with scalar, output will still be an array
5863
let array = udf
59-
.invoke_with_args(ScalarFunctionArgs {
60-
args: vec![ColumnarValue::Scalar(ScalarValue::Int64(Some(3)))],
61-
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
62-
number_rows: 3,
63-
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
64-
})
65-
.unwrap()
66-
.unwrap_array();
64+
.invoke_async_with_args(
65+
ScalarFunctionArgs {
66+
args: vec![ColumnarValue::Scalar(ScalarValue::Int64(Some(3)))],
67+
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
68+
number_rows: 3,
69+
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
70+
},
71+
&ConfigOptions::default(),
72+
)
73+
.await
74+
.unwrap();
6775
assert_eq!(
6876
array.as_ref(),
6977
&Int64Array::from_iter([Some(4), Some(4), Some(4)]) as &dyn Array,

host/tests/integration_tests/python/runtime/dependencies.rs

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@ use arrow::{
44
array::{Array, Int64Array},
55
datatypes::{DataType, Field},
66
};
7-
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
7+
use datafusion_common::config::ConfigOptions;
8+
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, async_udf::AsyncScalarUDFImpl};
89

9-
use crate::integration_tests::{
10-
python::test_utils::python_scalar_udf, test_utils::ColumnarValueExt,
11-
};
10+
use crate::integration_tests::python::test_utils::python_scalar_udf;
1211

1312
#[tokio::test(flavor = "multi_thread")]
1413
async fn call_other_function() {
@@ -25,18 +24,21 @@ def foo(x: int) -> int:
2524

2625
let udf = python_scalar_udf(CODE).await.unwrap();
2726
let array = udf
28-
.invoke_with_args(ScalarFunctionArgs {
29-
args: vec![ColumnarValue::Array(Arc::new(Int64Array::from_iter([
30-
Some(1),
31-
Some(2),
32-
Some(3),
33-
])))],
34-
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
35-
number_rows: 3,
36-
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
37-
})
38-
.unwrap()
39-
.unwrap_array();
27+
.invoke_async_with_args(
28+
ScalarFunctionArgs {
29+
args: vec![ColumnarValue::Array(Arc::new(Int64Array::from_iter([
30+
Some(1),
31+
Some(2),
32+
Some(3),
33+
])))],
34+
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
35+
number_rows: 3,
36+
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
37+
},
38+
&ConfigOptions::default(),
39+
)
40+
.await
41+
.unwrap();
4042
assert_eq!(
4143
array.as_ref(),
4244
&Int64Array::from_iter([Some(12), Some(23), Some(34)]) as &dyn Array,
@@ -59,18 +61,21 @@ def foo(x: int) -> int:
5961

6062
let udf = python_scalar_udf(CODE).await.unwrap();
6163
let array = udf
62-
.invoke_with_args(ScalarFunctionArgs {
63-
args: vec![ColumnarValue::Array(Arc::new(Int64Array::from_iter([
64-
Some(10),
65-
Some(20),
66-
Some(10),
67-
])))],
68-
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
69-
number_rows: 3,
70-
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
71-
})
72-
.unwrap()
73-
.unwrap_array();
64+
.invoke_async_with_args(
65+
ScalarFunctionArgs {
66+
args: vec![ColumnarValue::Array(Arc::new(Int64Array::from_iter([
67+
Some(10),
68+
Some(20),
69+
Some(10),
70+
])))],
71+
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
72+
number_rows: 3,
73+
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
74+
},
75+
&ConfigOptions::default(),
76+
)
77+
.await
78+
.unwrap();
7479
assert_eq!(
7580
array.as_ref(),
7681
&Int64Array::from_iter([Some(11), Some(22), Some(11)]) as &dyn Array,

0 commit comments

Comments
 (0)