Skip to content

Commit b84667e

Browse files
author
Fantayeneh Asres Gizaw
committed
Add json_extract_scalar
1 parent a1e99e4 commit b84667e

File tree

6 files changed

+187
-63
lines changed

6 files changed

+187
-63
lines changed

src/common.rs

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use std::str::Utf8Error;
22
use std::sync::Arc;
33

4+
use crate::common_union::{
5+
is_json_union, json_from_union_scalar, nested_json_array, nested_json_array_ref, TYPE_ID_NULL,
6+
};
47
use datafusion::arrow::array::{
58
downcast_array, AnyDictionaryArray, Array, ArrayAccessor, ArrayRef, AsArray, DictionaryArray, LargeStringArray,
69
PrimitiveArray, PrimitiveBuilder, RunArray, StringArray, StringViewArray,
@@ -11,10 +14,8 @@ use datafusion::arrow::datatypes::{ArrowNativeType, DataType, Int64Type, UInt64T
1114
use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue};
1215
use datafusion::logical_expr::ColumnarValue;
1316
use jiter::{Jiter, JiterError, Peek};
14-
15-
use crate::common_union::{
16-
is_json_union, json_from_union_scalar, nested_json_array, nested_json_array_ref, TYPE_ID_NULL,
17-
};
17+
use jsonpath_rust::parser::model::{Segment, Selector};
18+
use jsonpath_rust::parser::parse_json_path;
1819

1920
/// General implementation of `ScalarUDFImpl::return_type`.
2021
///
@@ -140,6 +141,22 @@ impl<'s> JsonPathArgs<'s> {
140141
}
141142
}
142143

144+
pub(crate) fn parse_jsonpath(path: &str) -> Vec<JsonPath<'static>> {
145+
let segments = parse_json_path(path).map(|it| it.segments).unwrap_or(Vec::new());
146+
147+
segments
148+
.into_iter()
149+
.map(|segment| match segment {
150+
Segment::Selector(s) => match s {
151+
Selector::Name(name) => JsonPath::Key(Box::leak(name.into_boxed_str())),
152+
Selector::Index(idx) => JsonPath::Index(idx as usize),
153+
_ => JsonPath::None,
154+
},
155+
_ => JsonPath::None,
156+
})
157+
.collect::<Vec<_>>()
158+
}
159+
143160
pub trait InvokeResult {
144161
type Item;
145162
type Builder;
@@ -585,3 +602,21 @@ fn mask_dictionary_keys(keys: &PrimitiveArray<Int64Type>, type_ids: &[i8]) -> Pr
585602
}
586603
PrimitiveArray::new(keys.values().clone(), Some(null_mask.into()))
587604
}
605+
606+
#[cfg(test)]
607+
mod tests {
608+
use super::*;
609+
use rstest::rstest;
610+
611+
// Test cases for parse_jsonpath
612+
#[rstest]
613+
#[case("$.a.aa", vec![JsonPath::Key("a"), JsonPath::Key("aa")])]
614+
#[case("$.a.ab[0].ac", vec![JsonPath::Key("a"), JsonPath::Key("ab"), JsonPath::Index(0), JsonPath::Key("ac")])]
615+
#[case("$.a.ab[1].ad", vec![JsonPath::Key("a"), JsonPath::Key("ab"), JsonPath::Index(1), JsonPath::Key("ad")])]
616+
#[case(r#"$.a["a b"].ad"#, vec![JsonPath::Key("a"), JsonPath::Key("\"a b\""), JsonPath::Key("ad")])]
617+
#[tokio::test]
618+
async fn test_parse_jsonpath(#[case] path: &str, #[case] expected: Vec<JsonPath<'static>>) {
619+
let result = parse_jsonpath(path);
620+
assert_eq!(result, expected);
621+
}
622+
}

src/json_extract.rs

Lines changed: 11 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
use std::any::Any;
1+
use crate::common::{invoke, parse_jsonpath, return_type_check};
2+
use crate::common_macros::make_udf_function;
3+
use crate::json_get_json::jiter_json_get_json;
24
use datafusion::arrow::array::StringArray;
35
use datafusion::arrow::datatypes::{DataType, DataType::Utf8};
46
use datafusion::common::{exec_err, Result as DataFusionResult, ScalarValue};
57
use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
6-
use jsonpath_rust::parser::model::{Segment, Selector};
7-
use jsonpath_rust::parser::parse_json_path;
8-
use crate::common::{invoke, return_type_check, JsonPath};
9-
use crate::common_macros::make_udf_function;
10-
use crate::json_get_json::jiter_json_get_json;
8+
use std::any::Any;
119

1210
make_udf_function!(
1311
JsonExtract,
@@ -65,55 +63,20 @@ impl ScalarUDFImpl for JsonExtract {
6563

6664
let path_str = match path_arg {
6765
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s,
68-
_ => return exec_err!("'{}' expects a valid JSONPath string (e.g., '$.key[0]') as second argument", self.name()),
66+
_ => {
67+
return exec_err!(
68+
"'{}' expects a valid JSONPath string (e.g., '$.key[0]') as second argument",
69+
self.name()
70+
)
71+
}
6972
};
7073

7174
let path = parse_jsonpath(path_str);
7275

73-
invoke::<StringArray>(&[json_arg.clone()], |json, _| {
74-
jiter_json_get_json(json, &path)
75-
})
76+
invoke::<StringArray>(&[json_arg.clone()], |json, _| jiter_json_get_json(json, &path))
7677
}
7778

7879
fn aliases(&self) -> &[String] {
7980
&self.aliases
8081
}
8182
}
82-
83-
fn parse_jsonpath(path: &str) -> Vec<JsonPath<'static>> {
84-
let segments = parse_json_path(path)
85-
.map(|it| it.segments)
86-
.unwrap_or(Vec::new());
87-
88-
segments.into_iter().map(|segment| {
89-
match segment {
90-
Segment::Selector(s) => match s {
91-
Selector::Name(name) => JsonPath::Key(Box::leak(name.into_boxed_str())),
92-
Selector::Index(idx) => JsonPath::Index(idx as usize),
93-
_ => JsonPath::None,
94-
},
95-
_ => JsonPath::None,
96-
}
97-
}).collect::<Vec<_>>()
98-
}
99-
100-
#[cfg(test)]
101-
mod tests {
102-
use rstest::rstest;
103-
use super::*;
104-
105-
// Test cases for parse_jsonpath
106-
#[rstest]
107-
#[case("$.a.aa", vec![JsonPath::Key("a"), JsonPath::Key("aa")])]
108-
#[case("$.a.ab[0].ac", vec![JsonPath::Key("a"), JsonPath::Key("ab"), JsonPath::Index(0), JsonPath::Key("ac")])]
109-
#[case("$.a.ab[1].ad", vec![JsonPath::Key("a"), JsonPath::Key("ab"), JsonPath::Index(1), JsonPath::Key("ad")])]
110-
#[case(r#"$.a["a b"].ad"#, vec![JsonPath::Key("a"), JsonPath::Key("\"a b\""), JsonPath::Key("ad")])]
111-
#[tokio::test]
112-
async fn test_parse_jsonpath(
113-
#[case] path: &str,
114-
#[case] expected: Vec<JsonPath<'static>>,
115-
) {
116-
let result = parse_jsonpath(path);
117-
assert_eq!(result, expected);
118-
}
119-
}

src/json_extract_scalar.rs

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
use std::any::Any;
2+
3+
use datafusion::arrow::datatypes::DataType;
4+
use datafusion::common::{exec_err, Result as DataFusionResult};
5+
use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
6+
use datafusion::scalar::ScalarValue;
7+
8+
use crate::common::parse_jsonpath;
9+
use crate::common::{invoke, return_type_check};
10+
use crate::common_macros::make_udf_function;
11+
use crate::common_union::JsonUnion;
12+
use crate::json_get::jiter_json_get_union;
13+
14+
make_udf_function!(
15+
JsonExtractScalar,
16+
json_extract_scalar,
17+
json_data path,
18+
r#"Get a value from a JSON string by its "path""#
19+
);
20+
21+
#[derive(Debug)]
22+
pub(super) struct JsonExtractScalar {
23+
signature: Signature,
24+
aliases: [String; 1],
25+
}
26+
27+
impl Default for JsonExtractScalar {
28+
fn default() -> Self {
29+
Self {
30+
signature: Signature::variadic_any(Volatility::Immutable),
31+
aliases: ["json_extract_scalar".to_string()],
32+
}
33+
}
34+
}
35+
36+
impl ScalarUDFImpl for JsonExtractScalar {
37+
fn as_any(&self) -> &dyn Any {
38+
self
39+
}
40+
41+
fn name(&self) -> &str {
42+
self.aliases[0].as_str()
43+
}
44+
45+
fn signature(&self) -> &Signature {
46+
&self.signature
47+
}
48+
49+
fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
50+
return_type_check(arg_types, self.name(), JsonUnion::data_type())
51+
}
52+
53+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
54+
if args.args.len() != 2 {
55+
return exec_err!(
56+
"'{}' expects exactly 2 arguments (JSON data, path), got {}",
57+
self.name(),
58+
args.args.len()
59+
);
60+
}
61+
62+
let json_arg = &args.args[0];
63+
let path_arg = &args.args[1];
64+
65+
let path_str = match path_arg {
66+
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s,
67+
_ => {
68+
return exec_err!(
69+
"'{}' expects a valid JSONPath string (e.g., '$.key[0]') as second argument",
70+
self.name()
71+
)
72+
}
73+
};
74+
75+
let path = parse_jsonpath(path_str);
76+
77+
invoke::<JsonUnion>(&[json_arg.clone()], |json, _| jiter_json_get_union(json, &path))
78+
}
79+
80+
fn aliases(&self) -> &[String] {
81+
&self.aliases
82+
}
83+
}

src/lib.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ mod common_macros;
1010
mod common_union;
1111
mod json_as_text;
1212
mod json_contains;
13+
mod json_extract;
14+
mod json_extract_scalar;
1315
mod json_get;
1416
mod json_get_bool;
1517
mod json_get_float;
@@ -19,15 +21,15 @@ mod json_get_str;
1921
mod json_length;
2022
mod json_object_keys;
2123
mod rewrite;
22-
mod json_extract;
2324

2425
pub use common_union::{JsonUnionEncoder, JsonUnionValue};
2526

2627
pub mod functions {
2728
pub use crate::json_as_text::json_as_text;
2829
pub use crate::json_contains::json_contains;
29-
pub use crate::json_get::json_get;
3030
pub use crate::json_extract::json_extract;
31+
pub use crate::json_extract_scalar::json_extract_scalar;
32+
pub use crate::json_get::json_get;
3133
pub use crate::json_get_bool::json_get_bool;
3234
pub use crate::json_get_float::json_get_float;
3335
pub use crate::json_get_int::json_get_int;
@@ -63,6 +65,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
6365
let functions: Vec<Arc<ScalarUDF>> = vec![
6466
json_get::json_get_udf(),
6567
json_extract::json_extract_udf(),
68+
json_extract_scalar::json_extract_scalar_udf(),
6669
json_get_bool::json_get_bool_udf(),
6770
json_get_float::json_get_float_udf(),
6871
json_get_int::json_get_int_udf(),

tests/json_extract_scalar_test.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
use crate::utils::{display_val, run_query};
2+
use rstest::{fixture, rstest};
3+
4+
mod utils;
5+
6+
#[fixture]
7+
fn json_data() -> String {
8+
let json = r#"
9+
{
10+
"store": {
11+
"book name": "My Favorite Books",
12+
"book": [
13+
{"title": "1984", "author": "George Orwell"},
14+
{"title": "Pride and Prejudice", "author": "Jane Austen"}
15+
]
16+
}
17+
}
18+
"#;
19+
json.to_string()
20+
}
21+
22+
#[rstest]
23+
#[case("$.store.book[0].author", "{str=George Orwell}")]
24+
#[tokio::test]
25+
async fn test_json_extract_scalar(json_data: String, #[case] path: &str, #[case] expected: &str) {
26+
let result = json_extract_scalar(&json_data, path).await;
27+
assert_eq!(result, expected.to_string());
28+
}
29+
30+
#[rstest]
31+
#[case("[1, 2, 3]", "$[2]", "{int=3}")]
32+
#[case("[1, 2, 3]", "$[3]", "{null=}")]
33+
#[tokio::test]
34+
async fn test_json_extract_scalar_simple(#[case] json: String, #[case] path: &str, #[case] expected: &str) {
35+
let result = json_extract_scalar(&json, path).await;
36+
assert_eq!(result, expected.to_string());
37+
}
38+
39+
async fn json_extract_scalar(json: &str, path: &str) -> String {
40+
let sql = format!("select json_extract_scalar('{}', '{}')", json, path);
41+
let batches = run_query(sql.as_str()).await.unwrap();
42+
display_val(batches).await.1
43+
}

tests/json_extract_test.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
use rstest::{fixture, rstest};
21
use crate::utils::{display_val, run_query};
2+
use rstest::{fixture, rstest};
33

44
mod utils;
55

@@ -10,13 +10,12 @@ fn json_data() -> String {
1010
}
1111

1212
#[rstest]
13-
#[case("$.a.ab", "[{\"ac\": \"Dune\", \"ca\": \"Frank Herbert\"},{\"ad\": \"Foundation\", \"da\": \"Isaac Asimov\"}]")]
13+
#[case(
14+
"$.a.ab",
15+
"[{\"ac\": \"Dune\", \"ca\": \"Frank Herbert\"},{\"ad\": \"Foundation\", \"da\": \"Isaac Asimov\"}]"
16+
)]
1417
#[tokio::test]
15-
async fn test_json_paths(
16-
json_data: String,
17-
#[case] path: &str,
18-
#[case] expected: &str,
19-
) {
18+
async fn test_json_paths(json_data: String, #[case] path: &str, #[case] expected: &str) {
2019
let result = json_extract(&json_data, path).await;
2120
assert_eq!(result, expected.to_string());
2221
}
@@ -29,10 +28,8 @@ async fn test_invalid_json_path(json_data: String) {
2928
assert_eq!(result, "".to_string());
3029
}
3130

32-
3331
async fn json_extract(json: &str, path: &str) -> String {
3432
let sql = format!("select json_extract('{}', '{}')", json, path);
3533
let batches = run_query(sql.as_str()).await.unwrap();
3634
display_val(batches).await.1
3735
}
38-

0 commit comments

Comments
 (0)