diff --git a/Cargo.toml b/Cargo.toml index 0a7a4e7..ea411da 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ license = "Apache-2.0" keywords = ["datafusion", "JSON", "SQL"] categories = ["database-implementations", "parsing"] repository = "https://github.com/datafusion-contrib/datafusion-functions-json/" -rust-version = "1.82.0" +rust-version = "1.85.1" [dependencies] datafusion = { version = "48", default-features = false } diff --git a/tests/main.rs b/tests/main.rs index e3b9986..35a24a4 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -9,7 +9,9 @@ use datafusion::common::ScalarValue; use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion::prelude::SessionContext; use datafusion_functions_json::udfs::json_get_str_udf; -use utils::{create_context, display_val, logical_plan, run_query, run_query_dict, run_query_large, run_query_params}; +use utils::{create_context, display_val, logical_plan, run_query, run_query_params}; + +use crate::utils::{for_all_json_datatypes, run_query_datatype}; mod utils; @@ -29,10 +31,13 @@ async fn test_json_contains() { "+------------------+-------------------------------------------+", ]; - let batches = run_query("select name, json_contains(json_data, 'foo') from test") - .await - .unwrap(); - assert_batches_eq!(expected, &batches); + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype("select name, json_contains(json_data, 'foo') from test", dt) + .await + .unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] @@ -451,9 +456,12 @@ async fn test_json_contains_large() { "+----------+", ]; - let batches = run_query_large("select count(*) from test where json_contains(json_data, 'foo')") - .await - .unwrap(); + let batches = run_query_datatype( + "select count(*) from test where json_contains(json_data, 'foo')", + &DataType::LargeUtf8, + ) + .await + .unwrap(); assert_batches_eq!(expected, &batches); } @@ -467,9 +475,12 @@ async fn test_json_contains_large_vec() { "+----------+", ]; - let batches = run_query_large("select count(*) from test where json_contains(json_data, name)") - .await - .unwrap(); + let batches = run_query_datatype( + "select count(*) from test where json_contains(json_data, name)", + &DataType::LargeUtf8, + ) + .await + .unwrap(); assert_batches_eq!(expected, &batches); } @@ -483,9 +494,12 @@ async fn test_json_contains_large_both() { "+----------+", ]; - let batches = run_query_large("select count(*) from test where json_contains(json_data, json_data)") - .await - .unwrap(); + let batches = run_query_datatype( + "select count(*) from test where json_contains(json_data, json_data)", + &DataType::LargeUtf8, + ) + .await + .unwrap(); assert_batches_eq!(expected, &batches); } @@ -501,8 +515,12 @@ async fn test_json_contains_large_params() { let sql = "select count(*) from test where json_contains(json_data, 'foo')"; let params = vec![ScalarValue::LargeUtf8(Some("foo".to_string()))]; - let batches = run_query_params(sql, false, params).await.unwrap(); - assert_batches_eq!(expected, &batches); + + for_all_json_datatypes(async |dt| { + let batches = run_query_params(sql, dt, params.clone()).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] @@ -517,14 +535,17 @@ async fn test_json_contains_large_both_params() { let sql = "select count(*) from test where json_contains(json_data, 'foo')"; let params = vec![ScalarValue::LargeUtf8(Some("foo".to_string()))]; - let batches = run_query_params(sql, true, params).await.unwrap(); - assert_batches_eq!(expected, &batches); + + for_all_json_datatypes(async |dt| { + let batches = run_query_params(sql, dt, params.clone()).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] async fn test_json_length_vec() { let sql = r"select name, json_len(json_data) as len from test"; - let batches = run_query(sql).await.unwrap(); let expected = [ "+------------------+-----+", @@ -539,10 +560,12 @@ async fn test_json_length_vec() { "| invalid_json | |", "+------------------+-----+", ]; - assert_batches_eq!(expected, &batches); - let batches = run_query_large(sql).await.unwrap(); - assert_batches_eq!(expected, &batches); + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] @@ -621,6 +644,7 @@ fn test_json_get_large_utf8() { #[tokio::test] async fn test_json_get_union_scalar() { + let sql = r#"select json_get(json_get('{"x": {"y": 1}}', 'x'), 'y') as v"#; let expected = [ "+---------+", "| v |", @@ -629,14 +653,16 @@ async fn test_json_get_union_scalar() { "+---------+", ]; - let batches = run_query(r#"select json_get(json_get('{"x": {"y": 1}}', 'x'), 'y') as v"#) - .await - .unwrap(); - assert_batches_eq!(expected, &batches); + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] async fn test_json_get_nested_collapsed() { + let sql = "select name, json_get(json_get(json_data, 'foo'), 0) as v from test"; let expected = [ "+------------------+---------+", "| name | v |", @@ -651,10 +677,30 @@ async fn test_json_get_nested_collapsed() { "+------------------+---------+", ]; - let batches = run_query("select name, json_get(json_get(json_data, 'foo'), 0) v from test") - .await - .unwrap(); - assert_batches_eq!(expected, &batches); + let expected_dict = [ + "+------------------+---------+", + "| name | v |", + "+------------------+---------+", + "| object_foo | |", + "| object_foo_array | {int=1} |", + "| object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+---------+", + ]; + + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + + if matches!(dt, DataType::Dictionary(_, _)) { + assert_batches_eq!(expected_dict, &batches); + } else { + assert_batches_eq!(expected, &batches); + } + }) + .await; } #[tokio::test] @@ -678,8 +724,30 @@ async fn test_json_get_cte() { "+------------------+---------+", ]; - let batches = run_query(sql).await.unwrap(); - assert_batches_eq!(expected, &batches); + let expected_dict = [ + "+------------------+---------+", + "| name | v |", + "+------------------+---------+", + "| object_foo | |", + "| object_foo_array | {int=1} |", + "| object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+---------+", + ]; + + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + + if matches!(dt, DataType::Dictionary(_, _)) { + assert_batches_eq!(expected_dict, &batches); + } else { + assert_batches_eq!(expected, &batches); + } + }) + .await; } #[tokio::test] @@ -719,8 +787,30 @@ async fn test_json_get_unnest() { "+------------------+---------+", ]; - let batches = run_query(sql).await.unwrap(); - assert_batches_eq!(expected, &batches); + let expected_dict = [ + "+------------------+---------+", + "| name | v |", + "+------------------+---------+", + "| object_foo | |", + "| object_foo_array | {int=1} |", + "| object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+---------+", + ]; + + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + + if matches!(dt, DataType::Dictionary(_, _)) { + assert_batches_eq!(expected_dict, &batches); + } else { + assert_batches_eq!(expected, &batches); + } + }) + .await; } #[tokio::test] @@ -753,8 +843,11 @@ async fn test_json_get_int_unnest() { "+------------------+---+", ]; - let batches = run_query(sql).await.unwrap(); - assert_batches_eq!(expected, &batches); + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] @@ -792,8 +885,11 @@ async fn test_json_get_union_array_nested() { "+-------------+", ]; - let batches = run_query(sql).await.unwrap(); - assert_batches_eq!(expected, &batches); + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] @@ -823,13 +919,16 @@ async fn test_json_get_union_array_skip_double_nested() { "+--------------------------+---+", ]; - let batches = run_query(sql).await.unwrap(); - assert_batches_eq!(expected, &batches); + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] async fn test_arrow() { - let batches = run_query("select name, json_data->'foo' from test").await.unwrap(); + let sql = "select name, json_data->'foo' from test"; let expected = [ "+------------------+-------------------------+", @@ -844,7 +943,30 @@ async fn test_arrow() { "| invalid_json | {null=} |", "+------------------+-------------------------+", ]; - assert_batches_eq!(expected, &batches); + + let expected_dict = [ + "+------------------+-------------------------+", + "| name | test.json_data -> 'foo' |", + "+------------------+-------------------------+", + "| object_foo | {str=abc} |", + "| object_foo_array | {array=[1]} |", + "| object_foo_obj | {object={}} |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+-------------------------+", + ]; + + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + if matches!(dt, DataType::Dictionary(_, _)) { + assert_batches_eq!(expected_dict, &batches); + } else { + assert_batches_eq!(expected, &batches); + } + }) + .await; } #[tokio::test] @@ -861,7 +983,7 @@ async fn test_plan_arrow() { #[tokio::test] async fn test_long_arrow() { - let batches = run_query("select name, json_data->>'foo' from test").await.unwrap(); + let sql = "select name, json_data->>'foo' from test"; let expected = [ "+------------------+--------------------------+", @@ -876,7 +998,12 @@ async fn test_long_arrow() { "| invalid_json | |", "+------------------+--------------------------+", ]; - assert_batches_eq!(expected, &batches); + + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] @@ -893,9 +1020,7 @@ async fn test_plan_long_arrow() { #[tokio::test] async fn test_long_arrow_eq_str() { - let batches = run_query(r"select name, (json_data->>'foo')='abc' from test") - .await - .unwrap(); + let sql = r"select name, (json_data->>'foo')='abc' from test"; let expected = [ "+------------------+----------------------------------------+", @@ -910,14 +1035,18 @@ async fn test_long_arrow_eq_str() { "| invalid_json | |", "+------------------+----------------------------------------+", ]; - assert_batches_eq!(expected, &batches); + + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } /// Test column name / alias creation with a cast in the needle / key #[tokio::test] async fn test_arrow_cast_key_text() { let sql = r#"select ('{"foo": 42}'->>('foo'::text))"#; - let batches = run_query(sql).await.unwrap(); let expected = [ "+-------------------------+", @@ -927,13 +1056,16 @@ async fn test_arrow_cast_key_text() { "+-------------------------+", ]; - assert_batches_eq!(expected, &batches); + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] async fn test_arrow_cast_int() { let sql = r#"select ('{"foo": 42}'->'foo')::int"#; - let batches = run_query(sql).await.unwrap(); let expected = [ "+------------------------+", @@ -942,9 +1074,13 @@ async fn test_arrow_cast_int() { "| 42 |", "+------------------------+", ]; - assert_batches_eq!(expected, &batches); - assert_eq!(display_val(batches).await, (DataType::Int64, "42".to_string())); + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + assert_eq!(display_val(batches).await, (DataType::Int64, "42".to_string())); + }) + .await; } #[tokio::test] @@ -961,7 +1097,7 @@ async fn test_plan_arrow_cast_int() { #[tokio::test] async fn test_arrow_double_nested() { - let batches = run_query("select name, json_data->'foo'->0 from test").await.unwrap(); + let sql = "select name, json_data->'foo'->0 from test"; let expected = [ "+------------------+------------------------------+", @@ -976,7 +1112,30 @@ async fn test_arrow_double_nested() { "| invalid_json | {null=} |", "+------------------+------------------------------+", ]; - assert_batches_eq!(expected, &batches); + + let expected_dict = [ + "+------------------+------------------------------+", + "| name | test.json_data -> 'foo' -> 0 |", + "+------------------+------------------------------+", + "| object_foo | |", + "| object_foo_array | {int=1} |", + "| object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+------------------------------+", + ]; + + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + if matches!(dt, DataType::Dictionary(_, _)) { + assert_batches_eq!(expected_dict, &batches); + } else { + assert_batches_eq!(expected, &batches); + } + }) + .await; } #[tokio::test] @@ -993,8 +1152,7 @@ async fn test_plan_arrow_double_nested() { #[tokio::test] async fn test_double_arrow_double_nested() { - let batches = run_query("select name, json_data->>'foo'->>0 from test").await.unwrap(); - + let sql = "select name, json_data->>'foo'->>0 from test"; let expected = [ "+------------------+--------------------------------+", "| name | test.json_data ->> 'foo' ->> 0 |", @@ -1008,7 +1166,12 @@ async fn test_double_arrow_double_nested() { "| invalid_json | |", "+------------------+--------------------------------+", ]; - assert_batches_eq!(expected, &batches); + + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] @@ -1025,10 +1188,7 @@ async fn test_plan_double_arrow_double_nested() { #[tokio::test] async fn test_arrow_double_nested_cast() { - let batches = run_query("select name, (json_data->'foo'->0)::int from test") - .await - .unwrap(); - + let sql = "select name, (json_data->'foo'->0)::int from test"; let expected = [ "+------------------+------------------------------+", "| name | test.json_data -> 'foo' -> 0 |", @@ -1042,7 +1202,12 @@ async fn test_arrow_double_nested_cast() { "| invalid_json | |", "+------------------+------------------------------+", ]; - assert_batches_eq!(expected, &batches); + + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] @@ -1059,10 +1224,7 @@ async fn test_plan_arrow_double_nested_cast() { #[tokio::test] async fn test_double_arrow_double_nested_cast() { - let batches = run_query("select name, (json_data->>'foo'->>0)::int from test") - .await - .unwrap(); - + let sql = "select name, (json_data->>'foo'->>0)::int from test"; let expected = [ "+------------------+--------------------------------+", "| name | test.json_data ->> 'foo' ->> 0 |", @@ -1076,7 +1238,12 @@ async fn test_double_arrow_double_nested_cast() { "| invalid_json | |", "+------------------+--------------------------------+", ]; - assert_batches_eq!(expected, &batches); + + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] @@ -1094,6 +1261,7 @@ async fn test_plan_double_arrow_double_nested_cast() { #[tokio::test] async fn test_arrow_nested_columns() { + let sql = "select json_data->str_key1->str_key2 v from more_nested"; let expected = [ "+-------------+", "| v |", @@ -1104,13 +1272,16 @@ async fn test_arrow_nested_columns() { "+-------------+", ]; - let sql = "select json_data->str_key1->str_key2 v from more_nested"; - let batches = run_query(sql).await.unwrap(); - assert_batches_eq!(expected, &batches); + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] async fn test_arrow_nested_double_columns() { + let sql = "select json_data->str_key1->str_key2->int_key v from more_nested"; let expected = [ "+---------+", "| v |", @@ -1121,9 +1292,11 @@ async fn test_arrow_nested_double_columns() { "+---------+", ]; - let sql = "select json_data->str_key1->str_key2->int_key v from more_nested"; - let batches = run_query(sql).await.unwrap(); - assert_batches_eq!(expected, &batches); + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] @@ -1143,6 +1316,7 @@ async fn test_lexical_precedence_correct() { #[tokio::test] async fn test_question_mark_contains() { + let sql = "select name, json_data ? 'foo' from test"; let expected = [ "+------------------+------------------------+", "| name | test.json_data ? 'foo' |", @@ -1157,15 +1331,16 @@ async fn test_question_mark_contains() { "+------------------+------------------------+", ]; - let batches = run_query("select name, json_data ? 'foo' from test").await.unwrap(); - assert_batches_eq!(expected, &batches); + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] async fn test_arrow_filter() { - let batches = run_query("select name from test where (json_data->>'foo') = 'abc'") - .await - .unwrap(); + let sql = "select name from test where (json_data->>'foo') = 'abc'"; let expected = [ "+------------+", @@ -1174,15 +1349,17 @@ async fn test_arrow_filter() { "| object_foo |", "+------------+", ]; - assert_batches_eq!(expected, &batches); + + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] async fn test_question_filter() { - let batches = run_query("select name from test where json_data ? 'foo'") - .await - .unwrap(); - + let sql = "select name from test where json_data ? 'foo'"; let expected = [ "+------------------+", "| name |", @@ -1193,14 +1370,17 @@ async fn test_question_filter() { "| object_foo_null |", "+------------------+", ]; - assert_batches_eq!(expected, &batches); + + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] async fn test_json_get_union_is_null() { - let batches = run_query("select name, json_get(json_data, 'foo') is null from test") - .await - .unwrap(); + let sql = "select name, json_get(json_data, 'foo') is null from test"; let expected = [ "+------------------+----------------------------------------------+", @@ -1215,14 +1395,17 @@ async fn test_json_get_union_is_null() { "| invalid_json | true |", "+------------------+----------------------------------------------+", ]; - assert_batches_eq!(expected, &batches); + + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] async fn test_json_get_union_is_not_null() { - let batches = run_query("select name, json_get(json_data, 'foo') is not null from test") - .await - .unwrap(); + let sql = "select name, json_get(json_data, 'foo') is not null from test"; let expected = [ "+------------------+--------------------------------------------------+", @@ -1237,15 +1420,17 @@ async fn test_json_get_union_is_not_null() { "| invalid_json | false |", "+------------------+--------------------------------------------------+", ]; - assert_batches_eq!(expected, &batches); + + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] async fn test_arrow_union_is_null() { - let batches = run_query("select name, (json_data->'foo') is null from test") - .await - .unwrap(); - + let sql = "select name, (json_data->'foo') is null from test"; let expected = [ "+------------------+---------------------------------+", "| name | test.json_data -> 'foo' IS NULL |", @@ -1259,37 +1444,17 @@ async fn test_arrow_union_is_null() { "| invalid_json | true |", "+------------------+---------------------------------+", ]; - assert_batches_eq!(expected, &batches); -} - -#[tokio::test] -async fn test_arrow_union_is_null_dict_encoded() { - let batches = run_query_dict("select name, (json_data->'foo') is null from test") - .await - .unwrap(); - let expected = [ - "+------------------+---------------------------------+", - "| name | test.json_data -> 'foo' IS NULL |", - "+------------------+---------------------------------+", - "| object_foo | false |", - "| object_foo_array | false |", - "| object_foo_obj | false |", - "| object_foo_null | true |", - "| object_bar | true |", - "| list_foo | true |", - "| invalid_json | true |", - "+------------------+---------------------------------+", - ]; - assert_batches_eq!(expected, &batches); + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] async fn test_arrow_union_is_not_null() { - let batches = run_query("select name, (json_data->'foo') is not null from test") - .await - .unwrap(); - + let sql = "select name, (json_data->'foo') is not null from test"; let expected = [ "+------------------+-------------------------------------+", "| name | test.json_data -> 'foo' IS NOT NULL |", @@ -1303,41 +1468,20 @@ async fn test_arrow_union_is_not_null() { "| invalid_json | false |", "+------------------+-------------------------------------+", ]; - assert_batches_eq!(expected, &batches); -} -#[tokio::test] -async fn test_arrow_union_is_not_null_dict_encoded() { - let batches = run_query_dict("select name, (json_data->'foo') is not null from test") - .await - .unwrap(); - - let expected = [ - "+------------------+-------------------------------------+", - "| name | test.json_data -> 'foo' IS NOT NULL |", - "+------------------+-------------------------------------+", - "| object_foo | true |", - "| object_foo_array | true |", - "| object_foo_obj | true |", - "| object_foo_null | false |", - "| object_bar | false |", - "| list_foo | false |", - "| invalid_json | false |", - "+------------------+-------------------------------------+", - ]; - assert_batches_eq!(expected, &batches); + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] async fn test_arrow_scalar_union_is_null() { - let batches = run_query( - r#" + let sql = r#" select ('{"x": 1}'->'foo') is null as not_contains, ('{"foo": 1}'->'foo') is null as contains_num, - ('{"foo": null}'->'foo') is null as contains_null"#, - ) - .await - .unwrap(); + ('{"foo": null}'->'foo') is null as contains_null"#; let expected = [ "+--------------+--------------+---------------+", @@ -1346,12 +1490,17 @@ async fn test_arrow_scalar_union_is_null() { "| true | false | true |", "+--------------+--------------+---------------+", ]; - assert_batches_eq!(expected, &batches); + + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] async fn test_long_arrow_cast() { - let batches = run_query("select (json_data->>'foo')::int from other").await.unwrap(); + let sql = "select (json_data->>'foo')::int from other"; let expected = [ "+---------------------------+", @@ -1363,7 +1512,12 @@ async fn test_long_arrow_cast() { "| |", "+---------------------------+", ]; - assert_batches_eq!(expected, &batches); + + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] @@ -1387,8 +1541,11 @@ async fn test_dict_haystack() { "+-----------------------+", ]; - let batches = run_query(sql).await.unwrap(); - assert_batches_eq!(expected, &batches); + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } fn check_for_null_dictionary_values(array: &dyn Array) { @@ -1467,8 +1624,11 @@ async fn test_dict_haystack_filter() { "+-------------------------+", ]; - let batches = run_query(sql).await.unwrap(); - assert_batches_eq!(expected, &batches); + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] @@ -1485,8 +1645,11 @@ async fn test_dict_haystack_needle() { "+-------------+", ]; - let batches = run_query(sql).await.unwrap(); - assert_batches_eq!(expected, &batches); + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] @@ -1504,8 +1667,11 @@ async fn test_dict_length() { "+---+", ]; - let batches = run_query(sql).await.unwrap(); - assert_batches_eq!(expected, &batches); + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] @@ -1522,8 +1688,11 @@ async fn test_dict_contains() { "+-------+", ]; - let batches = run_query(sql).await.unwrap(); - assert_batches_eq!(expected, &batches); + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] @@ -1538,8 +1707,11 @@ async fn test_dict_contains_where() { "+----------+", ]; - let batches = run_query(sql).await.unwrap(); - assert_batches_eq!(expected, &batches); + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] @@ -1557,8 +1729,11 @@ async fn test_dict_get_int() { "+---+", ]; - let batches = run_query(sql).await.unwrap(); - assert_batches_eq!(expected, &batches); + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } async fn build_dict_schema() -> SessionContext { @@ -1693,16 +1868,11 @@ async fn test_json_object_keys() { ]; let sql = "select json_object_keys(json_data) from test"; - let batches = run_query(sql).await.unwrap(); - assert_batches_eq!(expected, &batches); - - let sql = "select json_object_keys(json_data) from test"; - let batches = run_query_dict(sql).await.unwrap(); - assert_batches_eq!(expected, &batches); - - let sql = "select json_object_keys(json_data) from test"; - let batches = run_query_large(sql).await.unwrap(); - assert_batches_eq!(expected, &batches); + for_all_json_datatypes(async |dt| { + let batches = run_query_datatype(sql, dt).await.unwrap(); + assert_batches_eq!(expected, &batches); + }) + .await; } #[tokio::test] diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index 541d223..9c0b1e2 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -1,5 +1,5 @@ #![allow(dead_code)] -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use datafusion::arrow::array::{ ArrayRef, DictionaryArray, Int32Array, Int64Array, StringViewArray, UInt32Array, UInt64Array, UInt8Array, @@ -20,8 +20,13 @@ pub async fn create_context() -> Result { Ok(ctx) } +pub static DICT_TYPE: LazyLock = + LazyLock::new(|| DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8))); +pub static LARGE_DICT_TYPE: LazyLock = + LazyLock::new(|| DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::LargeUtf8))); + #[expect(clippy::too_many_lines)] -async fn create_test_table(large_utf8: bool, dict_encoded: bool) -> Result { +async fn create_test_table(json_data_type: &DataType) -> Result { let ctx = create_context().await?; let test_data = [ @@ -33,28 +38,41 @@ async fn create_test_table(large_utf8: bool, dict_encoded: bool) -> Result>(); - let (mut json_data_type, mut json_array): (DataType, ArrayRef) = if large_utf8 { - (DataType::LargeUtf8, Arc::new(LargeStringArray::from(json_values))) - } else { - (DataType::Utf8, Arc::new(StringArray::from(json_values))) - }; + let json_values = test_data.iter().map(|(_, json)| *json); - if dict_encoded { - json_data_type = DataType::Dictionary(DataType::Int32.into(), json_data_type.into()); - json_array = Arc::new(DictionaryArray::::new( - Int32Array::from_iter_values(0..(i32::try_from(json_array.len()).expect("fits in a i32"))), - json_array, - )); - } + let json_array = match json_data_type { + DataType::Utf8 => Arc::new(StringArray::from_iter_values(json_values)) as ArrayRef, + DataType::LargeUtf8 => Arc::new(LargeStringArray::from_iter_values(json_values)), + DataType::Utf8View => Arc::new(StringViewArray::from_iter_values(json_values)), + DataType::Dictionary(key_type, _) if key_type.as_ref() != &DataType::Int32 => { + panic!("Only Int32 dictionary encoding is supported for JSON data in these tests") + } + DataType::Dictionary(key_type, child) + if key_type.as_ref() == &DataType::Int32 && child.as_ref() == &DataType::Utf8 => + { + Arc::new(DictionaryArray::::new( + Int32Array::from_iter_values(0..(i32::try_from(json_values.len()).expect("fits in a i32"))), + Arc::new(StringArray::from_iter_values(json_values)), + )) + } + DataType::Dictionary(key_type, child) + if key_type.as_ref() == &DataType::Int32 && child.as_ref() == &DataType::LargeUtf8 => + { + Arc::new(DictionaryArray::::new( + Int32Array::from_iter_values(0..(i32::try_from(json_values.len()).expect("fits in a i32"))), + Arc::new(LargeStringArray::from_iter_values(json_values)), + )) + } + _ => panic!("Unsupported JSON data type: {json_data_type}"), + }; let test_batch = RecordBatch::try_new( Arc::new(Schema::new(vec![ - Field::new("name", DataType::Utf8, false), - Field::new("json_data", json_data_type, false), + Field::new("name", DataType::Utf8View, false), + Field::new("json_data", json_data_type.clone(), false), ])), vec![ - Arc::new(StringArray::from( + Arc::new(StringViewArray::from( test_data.iter().map(|(name, _)| *name).collect::>(), )), json_array, @@ -220,29 +238,35 @@ async fn create_test_table(large_utf8: bool, dict_encoded: bool) -> Result Result> { - let ctx = create_test_table(false, false).await?; - ctx.sql(sql).await?.collect().await + run_query_datatype(sql, &DataType::Utf8View).await } -pub async fn run_query_large(sql: &str) -> Result> { - let ctx = create_test_table(true, false).await?; - ctx.sql(sql).await?.collect().await -} - -pub async fn run_query_dict(sql: &str) -> Result> { - let ctx = create_test_table(false, true).await?; +pub async fn run_query_datatype(sql: &str, json_data_type: &DataType) -> Result> { + let ctx = create_test_table(json_data_type).await?; ctx.sql(sql).await?.collect().await } pub async fn run_query_params( sql: &str, - large_utf8: bool, + json_data_type: &DataType, query_values: impl Into, ) -> Result> { - let ctx = create_test_table(large_utf8, false).await?; + let ctx = create_test_table(json_data_type).await?; ctx.sql(sql).await?.with_param_values(query_values)?.collect().await } +pub async fn for_all_json_datatypes(f: impl AsyncFn(&DataType)) { + for dt in [ + &DataType::Utf8, + &DataType::LargeUtf8, + &DataType::Utf8View, + &DICT_TYPE, + &LARGE_DICT_TYPE, + ] { + f(dt).await; + } +} + pub async fn display_val(batch: Vec) -> (DataType, String) { assert_eq!(batch.len(), 1); let batch = batch.first().unwrap();