diff --git a/README.md b/README.md index 0c2a984..61e8713 100644 --- a/README.md +++ b/README.md @@ -62,6 +62,7 @@ SELECT id, json_col->'a' as json_col_a FROM test_table * [x] `json_get_float(json: str, *keys: str | int) -> float` - Get a float value from a JSON string by its "path" * [x] `json_get_bool(json: str, *keys: str | int) -> bool` - Get a boolean value from a JSON string by its "path" * [x] `json_get_json(json: str, *keys: str | int) -> str` - Get a nested raw JSON string from a JSON string by its "path" +* [x] `json_get_array(json: str, *keys: str | int) -> array` - Get an arrow array from a JSON string by its "path" * [x] `json_as_text(json: str, *keys: str | int) -> str` - Get any value from a JSON string by its "path", represented as a string (used for the `->>` operator) * [x] `json_length(json: str, *keys: str | int) -> int` - get the length of a JSON string or array diff --git a/src/json_get_array.rs b/src/json_get_array.rs new file mode 100644 index 0000000..0508a51 --- /dev/null +++ b/src/json_get_array.rs @@ -0,0 +1,134 @@ +use std::any::Any; +use std::sync::Arc; + +use datafusion::arrow::array::{ArrayRef, ListBuilder, StringBuilder}; +use datafusion::arrow::datatypes::DataType; +use datafusion::common::{Result as DataFusionResult, ScalarValue}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; +use jiter::Peek; + +use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; +use crate::common_macros::make_udf_function; + +make_udf_function!( + JsonGetArray, + json_get_array, + json_data path, + r#"Get an arrow array from a JSON string by its "path""# +); + +#[derive(Debug)] +pub(super) struct JsonGetArray { + signature: Signature, + aliases: [String; 1], +} + +impl Default for JsonGetArray { + fn default() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + aliases: ["json_get_array".to_string()], + } + } +} + +impl ScalarUDFImpl for JsonGetArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.aliases[0].as_str() + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { + return_type_check( + arg_types, + self.name(), + DataType::List(Arc::new(datafusion::arrow::datatypes::Field::new( + "item", + DataType::Utf8, + true, + ))), + ) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, jiter_json_get_array) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +#[derive(Debug)] +struct BuildArrayList; + +impl InvokeResult for BuildArrayList { + type Item = Vec; + + type Builder = ListBuilder; + + const ACCEPT_DICT_RETURN: bool = false; + + fn builder(capacity: usize) -> Self::Builder { + let values_builder = StringBuilder::new(); + ListBuilder::with_capacity(values_builder, capacity) + } + + fn append_value(builder: &mut Self::Builder, value: Option) { + builder.append_option(value.map(|v| v.into_iter().map(Some))); + } + + fn finish(mut builder: Self::Builder) -> DataFusionResult { + Ok(Arc::new(builder.finish())) + } + + fn scalar(value: Option) -> ScalarValue { + let mut builder = ListBuilder::new(StringBuilder::new()); + + if let Some(array_items) = value { + for item in array_items { + builder.values().append_value(item); + } + + builder.append(true); + } else { + builder.append(false); + } + let array = builder.finish(); + ScalarValue::List(Arc::new(array)) + } +} + +fn jiter_json_get_array(opt_json: Option<&str>, path: &[JsonPath]) -> Result, GetError> { + if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) { + match peek { + Peek::Array => { + let mut peek_opt = jiter.known_array()?; + let mut array_items: Vec = Vec::new(); + + while let Some(element_peek) = peek_opt { + // Get the raw JSON slice for each array element + let start = jiter.current_index(); + jiter.known_skip(element_peek)?; + let slice = jiter.slice_to_current(start); + let element_str = std::str::from_utf8(slice)?.to_string(); + + array_items.push(element_str); + peek_opt = jiter.array_step()?; + } + + Ok(array_items) + } + _ => get_err!(), + } + } else { + get_err!() + } +} diff --git a/src/lib.rs b/src/lib.rs index cb0f25a..884fce5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,7 @@ mod common_union; mod json_as_text; mod json_contains; mod json_get; +mod json_get_array; mod json_get_bool; mod json_get_float; mod json_get_int; @@ -26,6 +27,7 @@ pub mod functions { pub use crate::json_as_text::json_as_text; pub use crate::json_contains::json_contains; pub use crate::json_get::json_get; + pub use crate::json_get_array::json_get_array; pub use crate::json_get_bool::json_get_bool; pub use crate::json_get_float::json_get_float; pub use crate::json_get_int::json_get_int; @@ -39,6 +41,7 @@ pub mod udfs { pub use crate::json_as_text::json_as_text_udf; pub use crate::json_contains::json_contains_udf; pub use crate::json_get::json_get_udf; + pub use crate::json_get_array::json_get_array_udf; pub use crate::json_get_bool::json_get_bool_udf; pub use crate::json_get_float::json_get_float_udf; pub use crate::json_get_int::json_get_int_udf; @@ -64,6 +67,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { json_get_float::json_get_float_udf(), json_get_int::json_get_int_udf(), json_get_json::json_get_json_udf(), + json_get_array::json_get_array_udf(), json_as_text::json_as_text_udf(), json_get_str::json_get_str_udf(), json_contains::json_contains_udf(), diff --git a/tests/main.rs b/tests/main.rs index f591385..780b5fe 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -83,7 +83,7 @@ async fn test_json_get_union() { } #[tokio::test] -async fn test_json_get_array() { +async fn test_json_get_array_elem() { let sql = "select json_get('[1, 2, 3]', 2)"; let batches = run_query(sql).await.unwrap(); let (value_type, value_repr) = display_val(batches).await; @@ -91,6 +91,69 @@ async fn test_json_get_array() { assert_eq!(value_repr, "{int=3}"); } +#[tokio::test] +async fn test_json_get_array_basic_numbers() { + let sql = "select json_get_array('[1, 2, 3]')"; + let batches = run_query(sql).await.unwrap(); + let (value_type, value_repr) = display_val(batches).await; + assert!(matches!(value_type, DataType::List(_))); + assert_eq!(value_repr, "[1, 2, 3]"); +} + +#[tokio::test] +async fn test_json_get_array_mixed_types() { + let sql = r#"select json_get_array('["hello", 42, true, null, 3.14]')"#; + let batches = run_query(sql).await.unwrap(); + let (value_type, value_repr) = display_val(batches).await; + assert!(matches!(value_type, DataType::List(_))); + assert_eq!(value_repr, r#"["hello", 42, true, null, 3.14]"#); +} + +#[tokio::test] +async fn test_json_get_array_nested_objects() { + let sql = r#"select json_get_array('[{"name": "John"}, {"age": 30}]')"#; + let batches = run_query(sql).await.unwrap(); + let (value_type, value_repr) = display_val(batches).await; + assert!(matches!(value_type, DataType::List(_))); + assert_eq!(value_repr, r#"[{"name": "John"}, {"age": 30}]"#); +} + +#[tokio::test] +async fn test_json_get_array_nested_arrays() { + let sql = r#"select json_get_array('[[1, 2], [3, 4]]')"#; + let batches = run_query(sql).await.unwrap(); + let (value_type, value_repr) = display_val(batches).await; + assert!(matches!(value_type, DataType::List(_))); + assert_eq!(value_repr, "[[1, 2], [3, 4]]"); +} + +#[tokio::test] +async fn test_json_get_array_empty() { + let sql = "select json_get_array('[]')"; + let batches = run_query(sql).await.unwrap(); + let (value_type, value_repr) = display_val(batches).await; + assert!(matches!(value_type, DataType::List(_))); + assert_eq!(value_repr, "[]"); +} + +#[tokio::test] +async fn test_json_get_array_invalid_json() { + let sql = "select json_get_array('')"; + let batches = run_query(sql).await.unwrap(); + let (value_type, value_repr) = display_val(batches).await; + assert!(matches!(value_type, DataType::List(_))); + assert_eq!(value_repr, ""); +} + +#[tokio::test] +async fn test_json_get_array_with_path() { + let sql = r#"select json_get_array('{"items": [1, 2, 3]}', 'items')"#; + let batches = run_query(sql).await.unwrap(); + let (value_type, value_repr) = display_val(batches).await; + assert!(matches!(value_type, DataType::List(_))); + assert_eq!(value_repr, "[1, 2, 3]"); +} + #[tokio::test] async fn test_json_get_equals() { let e = run_query(r"select name, json_get(json_data, 'foo')='abc' from test")