Skip to content

Commit 2a7c5b2

Browse files
authored
implement json_object_key (and alias json_keys) (#50)
1 parent 8a758fa commit 2a7c5b2

File tree

3 files changed

+228
-0
lines changed

3 files changed

+228
-0
lines changed

src/json_object_keys.rs

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
use std::any::Any;
2+
use std::sync::Arc;
3+
4+
use datafusion::arrow::array::{ArrayRef, ListArray, ListBuilder, StringBuilder};
5+
use datafusion::arrow::datatypes::{DataType, Field};
6+
use datafusion::common::{Result as DataFusionResult, ScalarValue};
7+
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
8+
use jiter::Peek;
9+
10+
use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath};
11+
use crate::common_macros::make_udf_function;
12+
13+
make_udf_function!(
14+
JsonObjectKeys,
15+
json_object_keys,
16+
json_data path,
17+
r#"Get the keys of a JSON object as an array."#
18+
);
19+
20+
#[derive(Debug)]
21+
pub(super) struct JsonObjectKeys {
22+
signature: Signature,
23+
aliases: [String; 2],
24+
}
25+
26+
impl Default for JsonObjectKeys {
27+
fn default() -> Self {
28+
Self {
29+
signature: Signature::variadic_any(Volatility::Immutable),
30+
aliases: ["json_object_keys".to_string(), "json_keys".to_string()],
31+
}
32+
}
33+
}
34+
35+
impl ScalarUDFImpl for JsonObjectKeys {
36+
fn as_any(&self) -> &dyn Any {
37+
self
38+
}
39+
40+
fn name(&self) -> &str {
41+
self.aliases[0].as_str()
42+
}
43+
44+
fn signature(&self) -> &Signature {
45+
&self.signature
46+
}
47+
48+
fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
49+
return_type_check(
50+
arg_types,
51+
self.name(),
52+
DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
53+
)
54+
}
55+
56+
fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
57+
invoke::<ListArrayWrapper, Vec<String>>(
58+
args,
59+
jiter_json_object_keys,
60+
|w| Ok(Arc::new(w.0) as ArrayRef),
61+
keys_to_scalar,
62+
true,
63+
)
64+
}
65+
66+
fn aliases(&self) -> &[String] {
67+
&self.aliases
68+
}
69+
}
70+
71+
/// Wrapper for a `ListArray` that allows us to implement `FromIterator<Option<Vec<String>>>` as required.
72+
#[derive(Debug)]
73+
struct ListArrayWrapper(ListArray);
74+
75+
impl FromIterator<Option<Vec<String>>> for ListArrayWrapper {
76+
fn from_iter<I: IntoIterator<Item = Option<Vec<String>>>>(iter: I) -> Self {
77+
let values_builder = StringBuilder::new();
78+
let mut builder = ListBuilder::new(values_builder);
79+
for opt_keys in iter {
80+
if let Some(keys) = opt_keys {
81+
for value in keys {
82+
builder.values().append_value(value);
83+
}
84+
builder.append(true);
85+
} else {
86+
builder.append(false);
87+
}
88+
}
89+
Self(builder.finish())
90+
}
91+
}
92+
93+
fn keys_to_scalar(opt_keys: Option<Vec<String>>) -> ScalarValue {
94+
let values_builder = StringBuilder::new();
95+
let mut builder = ListBuilder::new(values_builder);
96+
if let Some(keys) = opt_keys {
97+
for value in keys {
98+
builder.values().append_value(value);
99+
}
100+
builder.append(true);
101+
} else {
102+
builder.append(false);
103+
}
104+
let array = builder.finish();
105+
ScalarValue::List(Arc::new(array))
106+
}
107+
108+
fn jiter_json_object_keys(opt_json: Option<&str>, path: &[JsonPath]) -> Result<Vec<String>, GetError> {
109+
if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) {
110+
match peek {
111+
Peek::Object => {
112+
let mut opt_key = jiter.known_object()?;
113+
114+
let mut keys = Vec::new();
115+
while let Some(key) = opt_key {
116+
keys.push(key.to_string());
117+
jiter.next_skip()?;
118+
opt_key = jiter.next_key()?;
119+
}
120+
Ok(keys)
121+
}
122+
_ => get_err!(),
123+
}
124+
} else {
125+
get_err!()
126+
}
127+
}

src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ mod json_get_int;
1717
mod json_get_json;
1818
mod json_get_str;
1919
mod json_length;
20+
mod json_object_keys;
2021
mod rewrite;
2122

2223
pub use common_union::{JsonUnionEncoder, JsonUnionValue};
@@ -31,6 +32,7 @@ pub mod functions {
3132
pub use crate::json_get_json::json_get_json;
3233
pub use crate::json_get_str::json_get_str;
3334
pub use crate::json_length::json_length;
35+
pub use crate::json_object_keys::json_object_keys;
3436
}
3537

3638
pub mod udfs {
@@ -43,6 +45,7 @@ pub mod udfs {
4345
pub use crate::json_get_json::json_get_json_udf;
4446
pub use crate::json_get_str::json_get_str_udf;
4547
pub use crate::json_length::json_length_udf;
48+
pub use crate::json_object_keys::json_object_keys_udf;
4649
}
4750

4851
/// Register all JSON UDFs, and [`rewrite::JsonFunctionRewriter`] with the provided [`FunctionRegistry`].
@@ -65,6 +68,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
6568
json_get_str::json_get_str_udf(),
6669
json_contains::json_contains_udf(),
6770
json_length::json_length_udf(),
71+
json_object_keys::json_object_keys_udf(),
6872
];
6973
functions.into_iter().try_for_each(|udf| {
7074
let existing_udf = registry.register_udf(udf)?;

tests/main.rs

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1432,3 +1432,100 @@ async fn test_dict_filter_contains() {
14321432

14331433
assert_batches_eq!(expected, &batches);
14341434
}
1435+
1436+
#[tokio::test]
1437+
async fn test_json_object_keys() {
1438+
let expected = [
1439+
"+----------------------------------+",
1440+
"| json_object_keys(test.json_data) |",
1441+
"+----------------------------------+",
1442+
"| [foo] |",
1443+
"| [foo] |",
1444+
"| [foo] |",
1445+
"| [foo] |",
1446+
"| [bar] |",
1447+
"| |",
1448+
"| |",
1449+
"+----------------------------------+",
1450+
];
1451+
1452+
let sql = "select json_object_keys(json_data) from test";
1453+
let batches = run_query(sql).await.unwrap();
1454+
assert_batches_eq!(expected, &batches);
1455+
1456+
let sql = "select json_object_keys(json_data) from test";
1457+
let batches = run_query_dict(sql).await.unwrap();
1458+
assert_batches_eq!(expected, &batches);
1459+
1460+
let sql = "select json_object_keys(json_data) from test";
1461+
let batches = run_query_large(sql).await.unwrap();
1462+
assert_batches_eq!(expected, &batches);
1463+
}
1464+
1465+
#[tokio::test]
1466+
async fn test_json_object_keys_many() {
1467+
let expected = [
1468+
"+-----------------------+",
1469+
"| v |",
1470+
"+-----------------------+",
1471+
"| [foo, bar, spam, ham] |",
1472+
"+-----------------------+",
1473+
];
1474+
1475+
let sql = r#"select json_object_keys('{"foo": 1, "bar": 2.2, "spam": true, "ham": []}') as v"#;
1476+
let batches = run_query(sql).await.unwrap();
1477+
assert_batches_eq!(expected, &batches);
1478+
}
1479+
1480+
#[tokio::test]
1481+
async fn test_json_object_keys_nested() {
1482+
let json = r#"'{"foo": [{"bar": {"spam": true, "ham": []}}]}'"#;
1483+
1484+
let sql = format!("select json_object_keys({json}) as v");
1485+
let batches = run_query(&sql).await.unwrap();
1486+
#[rustfmt::skip]
1487+
let expected = [
1488+
"+-------+",
1489+
"| v |",
1490+
"+-------+",
1491+
"| [foo] |",
1492+
"+-------+",
1493+
];
1494+
assert_batches_eq!(expected, &batches);
1495+
1496+
let sql = format!("select json_object_keys({json}, 'foo') as v");
1497+
let batches = run_query(&sql).await.unwrap();
1498+
#[rustfmt::skip]
1499+
let expected = [
1500+
"+---+",
1501+
"| v |",
1502+
"+---+",
1503+
"| |",
1504+
"+---+",
1505+
];
1506+
assert_batches_eq!(expected, &batches);
1507+
1508+
let sql = format!("select json_object_keys({json}, 'foo', 0) as v");
1509+
let batches = run_query(&sql).await.unwrap();
1510+
#[rustfmt::skip]
1511+
let expected = [
1512+
"+-------+",
1513+
"| v |",
1514+
"+-------+",
1515+
"| [bar] |",
1516+
"+-------+",
1517+
];
1518+
assert_batches_eq!(expected, &batches);
1519+
1520+
let sql = format!("select json_object_keys({json}, 'foo', 0, 'bar') as v");
1521+
let batches = run_query(&sql).await.unwrap();
1522+
#[rustfmt::skip]
1523+
let expected = [
1524+
"+-------------+",
1525+
"| v |",
1526+
"+-------------+",
1527+
"| [spam, ham] |",
1528+
"+-------------+",
1529+
];
1530+
assert_batches_eq!(expected, &batches);
1531+
}

0 commit comments

Comments
 (0)