Skip to content

Commit b964ff4

Browse files
committed
fix optimizations
1 parent d5c84c2 commit b964ff4

File tree

3 files changed

+45
-11
lines changed

3 files changed

+45
-11
lines changed

src/common.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,14 @@ pub(crate) enum Sortedness {
612612
Recursive,
613613
}
614614

615+
impl Sortedness {
616+
pub(crate) fn iter() -> impl Iterator<Item = Self> {
617+
[Sortedness::Unspecified, Sortedness::TopLevel, Sortedness::Recursive]
618+
.iter()
619+
.copied()
620+
}
621+
}
622+
615623
impl Sortedness {
616624
pub(crate) fn function_name_suffix(self) -> &'static str {
617625
match self {

src/rewrite.rs

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::sync::Arc;
2+
use std::sync::LazyLock;
23

34
use datafusion::arrow::datatypes::DataType;
45
use datafusion::common::config::ConfigOptions;
@@ -11,8 +12,12 @@ use datafusion::logical_expr::expr_rewriter::FunctionRewrite;
1112
use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawBinaryExpr};
1213
use datafusion::logical_expr::sqlparser::ast::BinaryOperator;
1314
use datafusion::logical_expr::ScalarUDF;
15+
use datafusion::logical_expr::ScalarUDFImpl;
1416
use datafusion::scalar::ScalarValue;
1517

18+
use crate::common::Sortedness;
19+
use crate::json_get::JsonGet;
20+
1621
/// Field-less type that carries this module's `FunctionRewrite` implementation.
#[derive(Debug)]
1722
pub(crate) struct JsonFunctionRewriter;
1823

@@ -31,11 +36,14 @@ impl FunctionRewrite for JsonFunctionRewriter {
3136
}
3237
}
3338

39+
static JSON_GET_FUNC_NAMES: LazyLock<Vec<String>> =
40+
LazyLock::new(|| Sortedness::iter().map(|s| JsonGet::new(s).name().to_string()).collect());
41+
3442
/// This replaces `json_get(foo, bar)::int` with `json_get_int(foo, bar)` so the JSON function can take care of
3543
/// extracting the right value type from JSON without the need to materialize the JSON union.
3644
fn optimise_json_get_cast(cast: &Cast) -> Option<Transformed<Expr>> {
3745
let scalar_func = extract_scalar_function(&cast.expr)?;
38-
if scalar_func.func.name() != "json_get" {
46+
if !JSON_GET_FUNC_NAMES.contains(&scalar_func.func.name().to_owned()) {
3947
return None;
4048
}
4149
let func = match &cast.data_type {
@@ -53,18 +61,24 @@ fn optimise_json_get_cast(cast: &Cast) -> Option<Transformed<Expr>> {
5361
})))
5462
}
5563

64+
static JSON_FUNCTION_NAMES: LazyLock<Vec<String>> = LazyLock::new(|| {
65+
Sortedness::iter()
66+
.flat_map(|s| {
67+
[
68+
crate::json_get::JsonGet::new(s).name().to_string(),
69+
crate::json_get_bool::JsonGetBool::new(s).name().to_string(),
70+
crate::json_get_float::JsonGetFloat::new(s).name().to_string(),
71+
crate::json_get_int::JsonGetInt::new(s).name().to_string(),
72+
crate::json_get_str::JsonGetStr::new(s).name().to_string(),
73+
crate::json_as_text::JsonAsText::new(s).name().to_string(),
74+
]
75+
})
76+
.collect()
77+
});
78+
5679
// Replace nested JSON functions e.g. `json_get(json_get(col, 'foo'), 'bar')` with `json_get(col, 'foo', 'bar')`
5780
fn unnest_json_calls(func: &ScalarFunction) -> Option<Transformed<Expr>> {
58-
if !matches!(
59-
func.func.name(),
60-
"json_get"
61-
| "json_get_bool"
62-
| "json_get_float"
63-
| "json_get_int"
64-
| "json_get_json"
65-
| "json_get_str"
66-
| "json_as_text"
67-
) {
81+
if !JSON_FUNCTION_NAMES.contains(&func.func.name().to_owned()) {
6882
return None;
6983
}
7084
let mut outer_args_iter = func.args.iter();

tests/main.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,10 @@ async fn test_json_get_cast_int() {
346346
let batches = run_query(sql).await.unwrap();
347347
assert_eq!(display_val(batches).await, (DataType::Int64, "42".to_string()));
348348

349+
let sql = r#"select json_get_top_level_sorted('{"foo": 42}', 'foo')::int"#;
350+
let batches = run_query(sql).await.unwrap();
351+
assert_eq!(display_val(batches).await, (DataType::Int64, "42".to_string()));
352+
349353
// floats not allowed
350354
let sql = r#"select json_get('{"foo": 4.2}', 'foo')::int"#;
351355
let batches = run_query(sql).await.unwrap();
@@ -400,13 +404,21 @@ async fn test_json_get_cast_float() {
400404
let sql = r#"select json_get('{"foo": 4.2e2}', 'foo')::float"#;
401405
let batches = run_query(sql).await.unwrap();
402406
assert_eq!(display_val(batches).await, (DataType::Float64, "420.0".to_string()));
407+
408+
let sql = r#"select json_get_top_level_sorted('{"foo": 4.2e2}', 'foo')::float"#;
409+
let batches = run_query(sql).await.unwrap();
410+
assert_eq!(display_val(batches).await, (DataType::Float64, "420.0".to_string()));
403411
}
404412

405413
#[tokio::test]
406414
async fn test_json_get_cast_numeric() {
407415
let sql = r#"select json_get('{"foo": 4.2e2}', 'foo')::numeric"#;
408416
let batches = run_query(sql).await.unwrap();
409417
assert_eq!(display_val(batches).await, (DataType::Float64, "420.0".to_string()));
418+
419+
let sql = r#"select json_get_top_level_sorted('{"foo": 4.2e2}', 'foo')::numeric"#;
420+
let batches = run_query(sql).await.unwrap();
421+
assert_eq!(display_val(batches).await, (DataType::Float64, "420.0".to_string()));
410422
}
411423

412424
#[tokio::test]

0 commit comments

Comments
 (0)