11use std:: sync:: Arc ;
2+ use std:: sync:: LazyLock ;
23
34use datafusion:: arrow:: datatypes:: DataType ;
45use datafusion:: common:: config:: ConfigOptions ;
@@ -11,8 +12,12 @@ use datafusion::logical_expr::expr_rewriter::FunctionRewrite;
1112use datafusion:: logical_expr:: planner:: { ExprPlanner , PlannerResult , RawBinaryExpr } ;
1213use datafusion:: logical_expr:: sqlparser:: ast:: BinaryOperator ;
1314use datafusion:: logical_expr:: ScalarUDF ;
15+ use datafusion:: logical_expr:: ScalarUDFImpl ;
1416use datafusion:: scalar:: ScalarValue ;
1517
18+ use crate :: common:: Sortedness ;
19+ use crate :: json_get:: JsonGet ;
20+
1621#[ derive( Debug ) ]
1722pub ( crate ) struct JsonFunctionRewriter ;
1823
@@ -31,11 +36,14 @@ impl FunctionRewrite for JsonFunctionRewriter {
3136 }
3237}
3338
39+ static JSON_GET_FUNC_NAMES : LazyLock < Vec < String > > =
40+ LazyLock :: new ( || Sortedness :: iter ( ) . map ( |s| JsonGet :: new ( s) . name ( ) . to_string ( ) ) . collect ( ) ) ;
41+
3442/// This replaces `get_json(foo, bar)::int` with `json_get_int(foo, bar)` so the JSON function can take care of
3543/// extracting the right value type from JSON without the need to materialize the JSON union.
3644fn optimise_json_get_cast ( cast : & Cast ) -> Option < Transformed < Expr > > {
3745 let scalar_func = extract_scalar_function ( & cast. expr ) ?;
38- if scalar_func. func . name ( ) != "json_get" {
46+ if ! JSON_GET_FUNC_NAMES . contains ( & scalar_func. func . name ( ) . to_owned ( ) ) {
3947 return None ;
4048 }
4149 let func = match & cast. data_type {
@@ -53,18 +61,24 @@ fn optimise_json_get_cast(cast: &Cast) -> Option<Transformed<Expr>> {
5361 } ) ) )
5462}
5563
64+ static JSON_FUNCTION_NAMES : LazyLock < Vec < String > > = LazyLock :: new ( || {
65+ Sortedness :: iter ( )
66+ . flat_map ( |s| {
67+ [
68+ crate :: json_get:: JsonGet :: new ( s) . name ( ) . to_string ( ) ,
69+ crate :: json_get_bool:: JsonGetBool :: new ( s) . name ( ) . to_string ( ) ,
70+ crate :: json_get_float:: JsonGetFloat :: new ( s) . name ( ) . to_string ( ) ,
71+ crate :: json_get_int:: JsonGetInt :: new ( s) . name ( ) . to_string ( ) ,
72+ crate :: json_get_str:: JsonGetStr :: new ( s) . name ( ) . to_string ( ) ,
73+ crate :: json_as_text:: JsonAsText :: new ( s) . name ( ) . to_string ( ) ,
74+ ]
75+ } )
76+ . collect ( )
77+ } ) ;
78+
5679// Replace nested JSON functions e.g. `json_get(json_get(col, 'foo'), 'bar')` with `json_get(col, 'foo', 'bar')`
5780fn unnest_json_calls ( func : & ScalarFunction ) -> Option < Transformed < Expr > > {
58- if !matches ! (
59- func. func. name( ) ,
60- "json_get"
61- | "json_get_bool"
62- | "json_get_float"
63- | "json_get_int"
64- | "json_get_json"
65- | "json_get_str"
66- | "json_as_text"
67- ) {
81+ if !JSON_FUNCTION_NAMES . contains ( & func. func . name ( ) . to_owned ( ) ) {
6882 return None ;
6983 }
7084 let mut outer_args_iter = func. args . iter ( ) ;
0 commit comments