Skip to content

Commit a99b13a

Browse files
authored
Merge pull request #53 from influxdata/crepererum/colvar_helpers
test: avoid boilerplate when working with `ColumnarValue`
2 parents 4248b47 + e29d685 commit a99b13a

File tree

4 files changed

+50
-13
lines changed

4 files changed

+50
-13
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
mod python;
22
mod rust;
3+
mod test_utils;

host/tests/integration_tests/python.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ use std::sync::Arc;
22

33
use arrow::datatypes::{DataType, Field};
44
use datafusion_common::{ScalarValue, assert_contains};
5-
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
5+
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
66
use datafusion_udf_wasm_host::{WasmComponentPrecompiled, WasmScalarUdf};
77

8+
use crate::integration_tests::test_utils::ColumnarValueExt;
9+
810
#[tokio::test(flavor = "multi_thread")]
911
async fn test() {
1012
let data = tokio::fs::read(format!(
@@ -28,17 +30,15 @@ async fn test() {
2830

2931
assert_eq!(udf.return_type(&[]).unwrap(), DataType::Utf8,);
3032

31-
let ColumnarValue::Scalar(scalar) = udf
33+
let scalar = udf
3234
.invoke_with_args(ScalarFunctionArgs {
3335
args: vec![],
3436
arg_fields: vec![],
3537
number_rows: 3,
3638
return_field: Arc::new(Field::new("r", DataType::Utf8, true)),
3739
})
3840
.unwrap()
39-
else {
40-
panic!("should be a scalar");
41-
};
41+
.unwrap_scalar();
4242
let ScalarValue::Utf8(s) = scalar else {
4343
panic!("scalar should be UTF8");
4444
};

host/tests/integration_tests/rust.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ use datafusion_common::ScalarValue;
88
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
99
use datafusion_udf_wasm_host::{WasmComponentPrecompiled, WasmScalarUdf};
1010

11+
use crate::integration_tests::test_utils::ColumnarValueExt;
12+
1113
#[tokio::test(flavor = "multi_thread")]
1214
async fn test_add_one() {
1315
let data = tokio::fs::read(format!(
@@ -38,7 +40,7 @@ async fn test_add_one() {
3840
@"Error during planning: add_one expects exactly one argument",
3941
);
4042

41-
let ColumnarValue::Array(array) = udf
43+
let array = udf
4244
.invoke_with_args(ScalarFunctionArgs {
4345
args: vec![ColumnarValue::Array(Arc::new(Int32Array::from_iter([
4446
Some(3),
@@ -50,23 +52,19 @@ async fn test_add_one() {
5052
return_field: Arc::new(Field::new("r", DataType::Int32, true)),
5153
})
5254
.unwrap()
53-
else {
54-
panic!("should be an array")
55-
};
55+
.unwrap_array();
5656
assert_eq!(
5757
array.as_ref(),
5858
&Int32Array::from_iter([Some(4), None, Some(2)]) as &dyn Array,
5959
);
60-
let ColumnarValue::Scalar(scalar) = udf
60+
let scalar = udf
6161
.invoke_with_args(ScalarFunctionArgs {
6262
args: vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(3)))],
6363
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int32, true))],
6464
number_rows: 3,
6565
return_field: Arc::new(Field::new("r", DataType::Int32, true)),
6666
})
6767
.unwrap()
68-
else {
69-
panic!("should be a scalar")
70-
};
68+
.unwrap_scalar();
7169
assert_eq!(scalar, ScalarValue::Int32(Some(4)));
7270
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
use arrow::array::ArrayRef;
2+
use datafusion_common::ScalarValue;
3+
use datafusion_expr::ColumnarValue;
4+
5+
/// Extension trait for [`ColumnarValue`] for easier testing.
6+
pub(crate) trait ColumnarValueExt {
7+
/// Extracts [`ColumnarValue::Array`] variant.
8+
///
9+
/// # Panic
10+
/// Panics if this is not an array.
11+
#[track_caller]
12+
fn unwrap_array(self) -> ArrayRef;
13+
14+
/// Extracts [`ColumnarValue::Scalar`] variant.
15+
///
16+
/// # Panic
17+
/// Panics if this is not an scalar.
18+
#[track_caller]
19+
fn unwrap_scalar(self) -> ScalarValue;
20+
}
21+
22+
impl ColumnarValueExt for ColumnarValue {
23+
#[track_caller]
24+
fn unwrap_array(self) -> ArrayRef {
25+
match self {
26+
Self::Array(array) => array,
27+
Self::Scalar(_) => panic!("expected an array but got a scalar"),
28+
}
29+
}
30+
31+
#[track_caller]
32+
fn unwrap_scalar(self) -> ScalarValue {
33+
match self {
34+
Self::Array(_) => panic!("expected a scalar but got an array"),
35+
Self::Scalar(scalar_value) => scalar_value,
36+
}
37+
}
38+
}

0 commit comments

Comments
 (0)