Skip to content

Commit 65a8bdc

Browse files
authored
fixup benchmarks (#93)
1 parent 6e9b678 commit 65a8bdc

File tree

3 files changed

+93
-8
lines changed

3 files changed

+93
-8
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ repos:
2020
pass_filenames: false
2121
- id: clippy
2222
name: Clippy
23-
entry: cargo clippy -- -D warnings
23+
entry: cargo clippy --all-targets -- -D warnings
2424
types: [rust]
2525
language: system
2626
pass_filenames: false

benches/main.rs

Lines changed: 91 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
use std::sync::Arc;
2+
13
use codspeed_criterion_compat::{criterion_group, criterion_main, Bencher, Criterion};
24

3-
use datafusion::arrow::datatypes::DataType;
5+
use datafusion::arrow::array::{StringArray, StringViewArray};
6+
use datafusion::arrow::datatypes::{DataType, Field};
47
use datafusion::logical_expr::ColumnarValue;
58
use datafusion::{common::ScalarValue, logical_expr::ScalarFunctionArgs};
69
use datafusion_functions_json::udfs::{json_contains_udf, json_get_str_udf};
@@ -15,18 +18,27 @@ fn bench_json_contains(b: &mut Bencher) {
1518
ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))),
1619
];
1720

21+
let arg_fields = vec![
22+
Arc::new(Field::new("arg0", DataType::Utf8, false)),
23+
Arc::new(Field::new("arg1", DataType::Utf8, false)),
24+
Arc::new(Field::new("arg2", DataType::Utf8, false)),
25+
];
26+
27+
let return_field = Arc::new(Field::new("json_contains", DataType::Boolean, false));
28+
1829
b.iter(|| {
1930
json_contains
2031
.invoke_with_args(ScalarFunctionArgs {
2132
args: args.clone(),
2233
number_rows: 1,
23-
return_type: &DataType::Boolean,
34+
arg_fields: arg_fields.clone(),
35+
return_field: return_field.clone(),
2436
})
2537
.unwrap()
2638
});
2739
}
2840

29-
fn bench_json_get_str(b: &mut Bencher) {
41+
fn bench_json_get_str_scalar(b: &mut Bencher) {
3042
let json_get_str = json_get_str_udf();
3143
let args = &[
3244
ColumnarValue::Scalar(ScalarValue::Utf8(Some(
@@ -36,20 +48,93 @@ fn bench_json_get_str(b: &mut Bencher) {
3648
ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))),
3749
];
3850

51+
let arg_fields = vec![
52+
Arc::new(Field::new("arg0", DataType::Utf8, false)),
53+
Arc::new(Field::new("arg1", DataType::Utf8, false)),
54+
Arc::new(Field::new("arg2", DataType::Utf8, false)),
55+
];
56+
57+
let return_field = Arc::new(Field::new("json_get_str", DataType::Utf8, false));
58+
3959
b.iter(|| {
4060
json_get_str
4161
.invoke_with_args(ScalarFunctionArgs {
4262
args: args.to_vec(),
63+
arg_fields: arg_fields.clone(),
4364
number_rows: 1,
44-
return_type: &DataType::Utf8,
65+
return_field: return_field.clone(),
4566
})
46-
.unwrap()
67+
.unwrap();
68+
});
69+
}
70+
71+
fn bench_json_get_str_array(b: &mut Bencher) {
72+
let json_get_str = json_get_str_udf();
73+
let args = &[
74+
ColumnarValue::Array(Arc::new(StringArray::from_iter_values(vec![
75+
r#"{"a": {"aa": "x", "ab": "y"}, "b": []}"#.to_string(),
76+
r#"{"a": {"aa": "x2", "ab": "y2"}, "b": []}"#.to_string(),
77+
]))),
78+
ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string()))),
79+
ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))),
80+
];
81+
82+
let arg_fields = vec![
83+
Arc::new(Field::new("arg0", DataType::Utf8, false)),
84+
Arc::new(Field::new("arg1", DataType::Utf8, false)),
85+
Arc::new(Field::new("arg2", DataType::Utf8, false)),
86+
];
87+
88+
let return_field = Arc::new(Field::new("json_get_str", DataType::Utf8, false));
89+
90+
b.iter(|| {
91+
json_get_str
92+
.invoke_with_args(ScalarFunctionArgs {
93+
args: args.to_vec(),
94+
arg_fields: arg_fields.clone(),
95+
number_rows: 1,
96+
return_field: return_field.clone(),
97+
})
98+
.unwrap();
99+
});
100+
}
101+
102+
fn bench_json_get_str_view_array(b: &mut Bencher) {
103+
let json_get_str = json_get_str_udf();
104+
let args = &[
105+
ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(vec![
106+
r#"{"a": {"aa": "x", "ab": "y"}, "b": []}"#.to_string(),
107+
r#"{"a": {"aa": "x2", "ab": "y2"}, "b": []}"#.to_string(),
108+
]))),
109+
ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string()))),
110+
ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))),
111+
];
112+
113+
let arg_fields = vec![
114+
Arc::new(Field::new("arg0", DataType::Utf8View, false)),
115+
Arc::new(Field::new("arg1", DataType::Utf8, false)),
116+
Arc::new(Field::new("arg2", DataType::Utf8, false)),
117+
];
118+
119+
let return_field = Arc::new(Field::new("json_get_str", DataType::Utf8, false));
120+
121+
b.iter(|| {
122+
json_get_str
123+
.invoke_with_args(ScalarFunctionArgs {
124+
args: args.to_vec(),
125+
arg_fields: arg_fields.clone(),
126+
number_rows: 1,
127+
return_field: return_field.clone(),
128+
})
129+
.unwrap();
47130
});
48131
}
49132

50133
fn criterion_benchmark(c: &mut Criterion) {
51134
c.bench_function("json_contains", bench_json_contains);
52-
c.bench_function("json_get_str", bench_json_get_str);
135+
c.bench_function("json_get_str_scalar", bench_json_get_str_scalar);
136+
c.bench_function("json_get_str_array", bench_json_get_str_array);
137+
c.bench_function("json_get_str_view_array", bench_json_get_str_view_array);
53138
}
54139

55140
criterion_group!(benches, criterion_benchmark);

tests/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ async fn test_json_get_array_nested_objects() {
121121

122122
#[tokio::test]
123123
async fn test_json_get_array_nested_arrays() {
124-
let sql = r#"select json_get_array('[[1, 2], [3, 4]]')"#;
124+
let sql = r"select json_get_array('[[1, 2], [3, 4]]')";
125125
let batches = run_query(sql).await.unwrap();
126126
let (value_type, value_repr) = display_val(batches).await;
127127
assert!(matches!(value_type, DataType::List(_)));

0 commit comments

Comments
 (0)