Update the benchmarks for the field-based ScalarFunctionArgs API and lint all targets.

* .pre-commit-config.yaml: run clippy with --all-targets so benches and tests are linted.
* benches/main.rs: pass arg_fields / return_field instead of return_type, rename the
  scalar json_get_str benchmark, and add StringArray and StringViewArray benchmarks.
  The array benchmarks declare number_rows as 2, matching their two-element inputs.
* tests/main.rs: drop the unneeded '#' delimiters from a raw string literal.

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5a847f9..af16aa6 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -20,7 +20,7 @@ repos:
         pass_filenames: false
       - id: clippy
         name: Clippy
-        entry: cargo clippy -- -D warnings
+        entry: cargo clippy --all-targets -- -D warnings
         types: [rust]
         language: system
         pass_filenames: false
diff --git a/benches/main.rs b/benches/main.rs
index 5c3eac5..77d8eae 100644
--- a/benches/main.rs
+++ b/benches/main.rs
@@ -1,6 +1,9 @@
+use std::sync::Arc;
+
 use codspeed_criterion_compat::{criterion_group, criterion_main, Bencher, Criterion};
-use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::array::{StringArray, StringViewArray};
+use datafusion::arrow::datatypes::{DataType, Field};
 use datafusion::logical_expr::ColumnarValue;
 use datafusion::{common::ScalarValue, logical_expr::ScalarFunctionArgs};
 use datafusion_functions_json::udfs::{json_contains_udf, json_get_str_udf};
 
@@ -15,18 +18,27 @@ fn bench_json_contains(b: &mut Bencher) {
         ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))),
     ];
 
+    let arg_fields = vec![
+        Arc::new(Field::new("arg0", DataType::Utf8, false)),
+        Arc::new(Field::new("arg1", DataType::Utf8, false)),
+        Arc::new(Field::new("arg2", DataType::Utf8, false)),
+    ];
+
+    let return_field = Arc::new(Field::new("json_contains", DataType::Boolean, false));
+
     b.iter(|| {
         json_contains
             .invoke_with_args(ScalarFunctionArgs {
                 args: args.clone(),
                 number_rows: 1,
-                return_type: &DataType::Boolean,
+                arg_fields: arg_fields.clone(),
+                return_field: return_field.clone(),
             })
             .unwrap()
     });
 }
 
-fn bench_json_get_str(b: &mut Bencher) {
+fn bench_json_get_str_scalar(b: &mut Bencher) {
     let json_get_str = json_get_str_udf();
     let args = &[
         ColumnarValue::Scalar(ScalarValue::Utf8(Some(
@@ -36,20 +48,95 @@ fn bench_json_get_str(b: &mut Bencher) {
         ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))),
     ];
 
+    let arg_fields = vec![
+        Arc::new(Field::new("arg0", DataType::Utf8, false)),
+        Arc::new(Field::new("arg1", DataType::Utf8, false)),
+        Arc::new(Field::new("arg2", DataType::Utf8, false)),
+    ];
+
+    let return_field = Arc::new(Field::new("json_get_str", DataType::Utf8, false));
+
     b.iter(|| {
         json_get_str
             .invoke_with_args(ScalarFunctionArgs {
                 args: args.to_vec(),
+                arg_fields: arg_fields.clone(),
                 number_rows: 1,
-                return_type: &DataType::Utf8,
+                return_field: return_field.clone(),
             })
-            .unwrap()
+            .unwrap();
+    });
+}
+
+/// Benchmarks `json_get_str` with a two-row `StringArray` as the JSON argument.
+fn bench_json_get_str_array(b: &mut Bencher) {
+    let json_get_str = json_get_str_udf();
+    let args = &[
+        ColumnarValue::Array(Arc::new(StringArray::from_iter_values(vec![
+            r#"{"a": {"aa": "x", "ab": "y"}, "b": []}"#.to_string(),
+            r#"{"a": {"aa": "x2", "ab": "y2"}, "b": []}"#.to_string(),
+        ]))),
+        ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string()))),
+        ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))),
+    ];
+
+    let arg_fields = vec![
+        Arc::new(Field::new("arg0", DataType::Utf8, false)),
+        Arc::new(Field::new("arg1", DataType::Utf8, false)),
+        Arc::new(Field::new("arg2", DataType::Utf8, false)),
+    ];
+
+    let return_field = Arc::new(Field::new("json_get_str", DataType::Utf8, false));
+
+    b.iter(|| {
+        json_get_str
+            .invoke_with_args(ScalarFunctionArgs {
+                args: args.to_vec(),
+                arg_fields: arg_fields.clone(),
+                number_rows: 2, // one per element of the input array
+                return_field: return_field.clone(),
+            })
+            .unwrap();
+    });
+}
+
+/// Benchmarks `json_get_str` with a two-row `StringViewArray` (`Utf8View`) as the JSON argument.
+fn bench_json_get_str_view_array(b: &mut Bencher) {
+    let json_get_str = json_get_str_udf();
+    let args = &[
+        ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(vec![
+            r#"{"a": {"aa": "x", "ab": "y"}, "b": []}"#.to_string(),
+            r#"{"a": {"aa": "x2", "ab": "y2"}, "b": []}"#.to_string(),
+        ]))),
+        ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string()))),
+        ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))),
+    ];
+
+    let arg_fields = vec![
+        Arc::new(Field::new("arg0", DataType::Utf8View, false)),
+        Arc::new(Field::new("arg1", DataType::Utf8, false)),
+        Arc::new(Field::new("arg2", DataType::Utf8, false)),
+    ];
+
+    let return_field = Arc::new(Field::new("json_get_str", DataType::Utf8, false));
+
+    b.iter(|| {
+        json_get_str
+            .invoke_with_args(ScalarFunctionArgs {
+                args: args.to_vec(),
+                arg_fields: arg_fields.clone(),
+                number_rows: 2, // one per element of the input array
+                return_field: return_field.clone(),
+            })
+            .unwrap();
     });
 }
 
 fn criterion_benchmark(c: &mut Criterion) {
     c.bench_function("json_contains", bench_json_contains);
-    c.bench_function("json_get_str", bench_json_get_str);
+    c.bench_function("json_get_str_scalar", bench_json_get_str_scalar);
+    c.bench_function("json_get_str_array", bench_json_get_str_array);
+    c.bench_function("json_get_str_view_array", bench_json_get_str_view_array);
 }
 
 criterion_group!(benches, criterion_benchmark);
diff --git a/tests/main.rs b/tests/main.rs
index 827d779..e3b9986 100644
--- a/tests/main.rs
+++ b/tests/main.rs
@@ -121,7 +121,7 @@ async fn test_json_get_array_nested_objects() {
 
 #[tokio::test]
 async fn test_json_get_array_nested_arrays() {
-    let sql = r#"select json_get_array('[[1, 2], [3, 4]]')"#;
+    let sql = r"select json_get_array('[[1, 2], [3, 4]]')";
     let batches = run_query(sql).await.unwrap();
     let (value_type, value_repr) = display_val(batches).await;
     assert!(matches!(value_type, DataType::List(_)));