diff --git a/src/rewrite.rs b/src/rewrite.rs index 58fc7d4..72637a4 100644 --- a/src/rewrite.rs +++ b/src/rewrite.rs @@ -1,12 +1,17 @@ +use std::sync::Arc; + use datafusion::arrow::datatypes::DataType; use datafusion::common::config::ConfigOptions; use datafusion::common::tree_node::Transformed; +use datafusion::common::Column; use datafusion::common::DFSchema; use datafusion::common::Result; use datafusion::logical_expr::expr::{Alias, Cast, Expr, ScalarFunction}; use datafusion::logical_expr::expr_rewriter::FunctionRewrite; use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawBinaryExpr}; use datafusion::logical_expr::sqlparser::ast::BinaryOperator; +use datafusion::logical_expr::ScalarUDF; +use datafusion::scalar::ScalarValue; #[derive(Debug)] pub(crate) struct JsonFunctionRewriter; @@ -93,27 +98,91 @@ fn extract_scalar_function(expr: &Expr) -> Option<&ScalarFunction> { } } +#[derive(Debug, Clone, Copy)] +enum JsonOperator { + Arrow, + LongArrow, + Question, +} + +impl TryFrom<&BinaryOperator> for JsonOperator { + type Error = (); + + fn try_from(op: &BinaryOperator) -> Result<Self, Self::Error> { + match op { + BinaryOperator::Arrow => Ok(JsonOperator::Arrow), + BinaryOperator::LongArrow => Ok(JsonOperator::LongArrow), + BinaryOperator::Question => Ok(JsonOperator::Question), + _ => Err(()), + } + } +} + +impl From<JsonOperator> for Arc<ScalarUDF> { + fn from(op: JsonOperator) -> Arc<ScalarUDF> { + match op { + JsonOperator::Arrow => crate::udfs::json_get_udf(), + JsonOperator::LongArrow => crate::udfs::json_as_text_udf(), + JsonOperator::Question => crate::udfs::json_contains_udf(), + } + } +} + +impl std::fmt::Display for JsonOperator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + JsonOperator::Arrow => write!(f, "->"), + JsonOperator::LongArrow => write!(f, "->>"), + JsonOperator::Question => write!(f, "?"), + } + } +} + +/// Convert an Expr to a String representation for use in alias names.
+fn expr_to_sql_repr(expr: &Expr) -> String { + match expr { + Expr::Column(Column { name, relation }) => relation + .as_ref() + .map_or_else(|| name.clone(), |r| format!("{r}.{name}")), + Expr::Alias(alias) => alias.name.clone(), + Expr::Literal(scalar) => match scalar { + ScalarValue::Utf8(Some(v)) | ScalarValue::Utf8View(Some(v)) | ScalarValue::LargeUtf8(Some(v)) => { + format!("'{v}'") + } + ScalarValue::UInt8(Some(v)) => v.to_string(), + ScalarValue::UInt16(Some(v)) => v.to_string(), + ScalarValue::UInt32(Some(v)) => v.to_string(), + ScalarValue::UInt64(Some(v)) => v.to_string(), + ScalarValue::Int8(Some(v)) => v.to_string(), + ScalarValue::Int16(Some(v)) => v.to_string(), + ScalarValue::Int32(Some(v)) => v.to_string(), + ScalarValue::Int64(Some(v)) => v.to_string(), + _ => scalar.to_string(), + }, + Expr::Cast(cast) => expr_to_sql_repr(&cast.expr), + _ => expr.to_string(), + } +} + /// Implement a custom SQL planner to replace postgres JSON operators with custom UDFs #[derive(Debug, Default)] pub struct JsonExprPlanner; impl ExprPlanner for JsonExprPlanner { fn plan_binary_op(&self, expr: RawBinaryExpr, _schema: &DFSchema) -> Result<PlannerResult<RawBinaryExpr>> { - let (func, op_display) = match &expr.op { - BinaryOperator::Arrow => (crate::json_get::json_get_udf(), "->"), - BinaryOperator::LongArrow => (crate::json_as_text::json_as_text_udf(), "->>"), - BinaryOperator::Question => (crate::json_contains::json_contains_udf(), "?"), - _ => return Ok(PlannerResult::Original(expr)), - }; - let alias_name = match &expr.left { - Expr::Alias(alias) => format!("{} {} {}", alias.name, op_display, expr.right), - left_expr => format!("{} {} {}", left_expr, op_display, expr.right), + let Ok(op) = JsonOperator::try_from(&expr.op) else { + return Ok(PlannerResult::Original(expr)); }; + let left_repr = expr_to_sql_repr(&expr.left); + let right_repr = expr_to_sql_repr(&expr.right); + + let alias_name = format!("{left_repr} {op} {right_repr}"); + // we put the alias in so that default column titles are
`foo -> bar` instead of `json_get(foo, bar)` Ok(PlannerResult::Planned(Expr::Alias(Alias::new( Expr::ScalarFunction(ScalarFunction { - func, + func: op.into(), args: vec![expr.left, expr.right], }), None::<&str>, diff --git a/tests/main.rs b/tests/main.rs index 0019e1e..12c75f3 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -738,17 +738,17 @@ async fn test_arrow() { let batches = run_query("select name, json_data->'foo' from test").await.unwrap(); let expected = [ - "+------------------+-------------------------------+", - "| name | test.json_data -> Utf8(\"foo\") |", - "+------------------+-------------------------------+", - "| object_foo | {str=abc} |", - "| object_foo_array | {array=[1]} |", - "| object_foo_obj | {object={}} |", - "| object_foo_null | {null=} |", - "| object_bar | {null=} |", - "| list_foo | {null=} |", - "| invalid_json | {null=} |", - "+------------------+-------------------------------+", + "+------------------+-------------------------+", + "| name | test.json_data -> 'foo' |", + "+------------------+-------------------------+", + "| object_foo | {str=abc} |", + "| object_foo_array | {array=[1]} |", + "| object_foo_obj | {object={}} |", + "| object_foo_null | {null=} |", + "| object_bar | {null=} |", + "| list_foo | {null=} |", + "| invalid_json | {null=} |", + "+------------------+-------------------------+", ]; assert_batches_eq!(expected, &batches); } @@ -758,7 +758,7 @@ async fn test_plan_arrow() { let lines = logical_plan(r"explain select json_data->'foo' from test").await; let expected = [ - "Projection: json_get(test.json_data, Utf8(\"foo\")) AS test.json_data -> Utf8(\"foo\")", + "Projection: json_get(test.json_data, Utf8(\"foo\")) AS test.json_data -> 'foo'", " TableScan: test projection=[json_data]", ]; @@ -770,17 +770,17 @@ async fn test_long_arrow() { let batches = run_query("select name, json_data->>'foo' from test").await.unwrap(); let expected = [ - "+------------------+--------------------------------+", - "| name | 
test.json_data ->> Utf8(\"foo\") |", - "+------------------+--------------------------------+", - "| object_foo | abc |", - "| object_foo_array | [1] |", - "| object_foo_obj | {} |", - "| object_foo_null | |", - "| object_bar | |", - "| list_foo | |", - "| invalid_json | |", - "+------------------+--------------------------------+", + "+------------------+--------------------------+", + "| name | test.json_data ->> 'foo' |", + "+------------------+--------------------------+", + "| object_foo | abc |", + "| object_foo_array | [1] |", + "| object_foo_obj | {} |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+--------------------------+", ]; assert_batches_eq!(expected, &batches); } @@ -790,7 +790,7 @@ async fn test_plan_long_arrow() { let lines = logical_plan(r"explain select json_data->>'foo' from test").await; let expected = [ - "Projection: json_as_text(test.json_data, Utf8(\"foo\")) AS test.json_data ->> Utf8(\"foo\")", + "Projection: json_as_text(test.json_data, Utf8(\"foo\")) AS test.json_data ->> 'foo'", " TableScan: test projection=[json_data]", ]; @@ -804,32 +804,49 @@ async fn test_long_arrow_eq_str() { .unwrap(); let expected = [ - "+------------------+----------------------------------------------+", - "| name | test.json_data ->> Utf8(\"foo\") = Utf8(\"abc\") |", - "+------------------+----------------------------------------------+", - "| object_foo | true |", - "| object_foo_array | false |", - "| object_foo_obj | false |", - "| object_foo_null | |", - "| object_bar | |", - "| list_foo | |", - "| invalid_json | |", - "+------------------+----------------------------------------------+", + "+------------------+----------------------------------------+", + "| name | test.json_data ->> 'foo' = Utf8(\"abc\") |", + "+------------------+----------------------------------------+", + "| object_foo | true |", + "| object_foo_array | false |", + "| object_foo_obj | false |", + "| 
object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+----------------------------------------+", ]; assert_batches_eq!(expected, &batches); } +/// Test column name / alias creation with a cast in the needle / key +#[tokio::test] +async fn test_arrow_cast_key_text() { + let sql = r#"select ('{"foo": 42}'->>('foo'::text))"#; + let batches = run_query(sql).await.unwrap(); + + let expected = [ + "+-------------------------+", + "| '{\"foo\": 42}' ->> 'foo' |", + "+-------------------------+", + "| 42 |", + "+-------------------------+", + ]; + + assert_batches_eq!(expected, &batches); +} + #[tokio::test] async fn test_arrow_cast_int() { let sql = r#"select ('{"foo": 42}'->'foo')::int"#; let batches = run_query(sql).await.unwrap(); let expected = [ - "+------------------------------------+", - "| Utf8(\"{\"foo\": 42}\") -> Utf8(\"foo\") |", - "+------------------------------------+", - "| 42 |", - "+------------------------------------+", + "+------------------------+", + "| '{\"foo\": 42}' -> 'foo' |", + "+------------------------+", + "| 42 |", + "+------------------------+", ]; assert_batches_eq!(expected, &batches); @@ -841,7 +858,7 @@ async fn test_plan_arrow_cast_int() { let lines = logical_plan(r"explain select (json_data->'foo')::int from test").await; let expected = [ - "Projection: json_get_int(test.json_data, Utf8(\"foo\")) AS test.json_data -> Utf8(\"foo\")", + "Projection: json_get_int(test.json_data, Utf8(\"foo\")) AS test.json_data -> 'foo'", " TableScan: test projection=[json_data]", ]; @@ -853,17 +870,17 @@ async fn test_arrow_double_nested() { let batches = run_query("select name, json_data->'foo'->0 from test").await.unwrap(); let expected = [ - "+------------------+-------------------------------------------+", - "| name | test.json_data -> Utf8(\"foo\") -> Int64(0) |", - "+------------------+-------------------------------------------+", - "| object_foo | {null=} |", - "| object_foo_array | 
{int=1} |", - "| object_foo_obj | {null=} |", - "| object_foo_null | {null=} |", - "| object_bar | {null=} |", - "| list_foo | {null=} |", - "| invalid_json | {null=} |", - "+------------------+-------------------------------------------+", + "+------------------+------------------------------+", + "| name | test.json_data -> 'foo' -> 0 |", + "+------------------+------------------------------+", + "| object_foo | {null=} |", + "| object_foo_array | {int=1} |", + "| object_foo_obj | {null=} |", + "| object_foo_null | {null=} |", + "| object_bar | {null=} |", + "| list_foo | {null=} |", + "| invalid_json | {null=} |", + "+------------------+------------------------------+", ]; assert_batches_eq!(expected, &batches); } @@ -873,7 +890,7 @@ async fn test_plan_arrow_double_nested() { let lines = logical_plan(r"explain select json_data->'foo'->0 from test").await; let expected = [ - "Projection: json_get(test.json_data, Utf8(\"foo\"), Int64(0)) AS test.json_data -> Utf8(\"foo\") -> Int64(0)", + "Projection: json_get(test.json_data, Utf8(\"foo\"), Int64(0)) AS test.json_data -> 'foo' -> 0", " TableScan: test projection=[json_data]", ]; @@ -885,17 +902,17 @@ async fn test_double_arrow_double_nested() { let batches = run_query("select name, json_data->>'foo'->>0 from test").await.unwrap(); let expected = [ - "+------------------+---------------------------------------------+", - "| name | test.json_data ->> Utf8(\"foo\") ->> Int64(0) |", - "+------------------+---------------------------------------------+", - "| object_foo | |", - "| object_foo_array | 1 |", - "| object_foo_obj | |", - "| object_foo_null | |", - "| object_bar | |", - "| list_foo | |", - "| invalid_json | |", - "+------------------+---------------------------------------------+", + "+------------------+--------------------------------+", + "| name | test.json_data ->> 'foo' ->> 0 |", + "+------------------+--------------------------------+", + "| object_foo | |", + "| object_foo_array | 1 |", + "| 
object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+--------------------------------+", ]; assert_batches_eq!(expected, &batches); } @@ -905,7 +922,7 @@ async fn test_plan_double_arrow_double_nested() { let lines = logical_plan(r"explain select json_data->>'foo'->>0 from test").await; let expected = [ - "Projection: json_as_text(test.json_data, Utf8(\"foo\"), Int64(0)) AS test.json_data ->> Utf8(\"foo\") ->> Int64(0)", + "Projection: json_as_text(test.json_data, Utf8(\"foo\"), Int64(0)) AS test.json_data ->> 'foo' ->> 0", " TableScan: test projection=[json_data]", ]; @@ -919,17 +936,17 @@ async fn test_arrow_double_nested_cast() { .unwrap(); let expected = [ - "+------------------+-------------------------------------------+", - "| name | test.json_data -> Utf8(\"foo\") -> Int64(0) |", - "+------------------+-------------------------------------------+", - "| object_foo | |", - "| object_foo_array | 1 |", - "| object_foo_obj | |", - "| object_foo_null | |", - "| object_bar | |", - "| list_foo | |", - "| invalid_json | |", - "+------------------+-------------------------------------------+", + "+------------------+------------------------------+", + "| name | test.json_data -> 'foo' -> 0 |", + "+------------------+------------------------------+", + "| object_foo | |", + "| object_foo_array | 1 |", + "| object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+------------------------------+", ]; assert_batches_eq!(expected, &batches); } @@ -939,7 +956,7 @@ async fn test_plan_arrow_double_nested_cast() { let lines = logical_plan(r"explain select (json_data->'foo'->0)::int from test").await; let expected = [ - "Projection: json_get_int(test.json_data, Utf8(\"foo\"), Int64(0)) AS test.json_data -> Utf8(\"foo\") -> Int64(0)", + "Projection: json_get_int(test.json_data, Utf8(\"foo\"), Int64(0)) AS 
test.json_data -> 'foo' -> 0", " TableScan: test projection=[json_data]", ]; @@ -953,17 +970,17 @@ async fn test_double_arrow_double_nested_cast() { .unwrap(); let expected = [ - "+------------------+---------------------------------------------+", - "| name | test.json_data ->> Utf8(\"foo\") ->> Int64(0) |", - "+------------------+---------------------------------------------+", - "| object_foo | |", - "| object_foo_array | 1 |", - "| object_foo_obj | |", - "| object_foo_null | |", - "| object_bar | |", - "| list_foo | |", - "| invalid_json | |", - "+------------------+---------------------------------------------+", + "+------------------+--------------------------------+", + "| name | test.json_data ->> 'foo' ->> 0 |", + "+------------------+--------------------------------+", + "| object_foo | |", + "| object_foo_array | 1 |", + "| object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+--------------------------------+", ]; assert_batches_eq!(expected, &batches); } @@ -974,7 +991,7 @@ async fn test_plan_double_arrow_double_nested_cast() { // NB: json_as_text(..)::int is NOT the same as `json_get_int(..)`, hence the cast is not rewritten let expected = [ - "Projection: CAST(json_as_text(test.json_data, Utf8(\"foo\"), Int64(0)) AS test.json_data ->> Utf8(\"foo\") ->> Int64(0) AS Int32)", + "Projection: CAST(json_as_text(test.json_data, Utf8(\"foo\"), Int64(0)) AS test.json_data ->> 'foo' ->> 0 AS Int32)", " TableScan: test projection=[json_data]", ]; @@ -1033,17 +1050,17 @@ async fn test_lexical_precedence_correct() { #[tokio::test] async fn test_question_mark_contains() { let expected = [ - "+------------------+------------------------------+", - "| name | test.json_data ? 
Utf8(\"foo\") |", - "+------------------+------------------------------+", - "| object_foo | true |", - "| object_foo_array | true |", - "| object_foo_obj | true |", - "| object_foo_null | true |", - "| object_bar | false |", - "| list_foo | false |", - "| invalid_json | false |", - "+------------------+------------------------------+", + "+------------------+------------------------+", + "| name | test.json_data ? 'foo' |", + "+------------------+------------------------+", + "| object_foo | true |", + "| object_foo_array | true |", + "| object_foo_obj | true |", + "| object_foo_null | true |", + "| object_bar | false |", + "| list_foo | false |", + "| invalid_json | false |", + "+------------------+------------------------+", ]; let batches = run_query("select name, json_data ? 'foo' from test").await.unwrap(); @@ -1136,17 +1153,17 @@ async fn test_arrow_union_is_null() { .unwrap(); let expected = [ - "+------------------+---------------------------------------+", - "| name | test.json_data -> Utf8(\"foo\") IS NULL |", - "+------------------+---------------------------------------+", - "| object_foo | false |", - "| object_foo_array | false |", - "| object_foo_obj | false |", - "| object_foo_null | true |", - "| object_bar | true |", - "| list_foo | true |", - "| invalid_json | true |", - "+------------------+---------------------------------------+", + "+------------------+---------------------------------+", + "| name | test.json_data -> 'foo' IS NULL |", + "+------------------+---------------------------------+", + "| object_foo | false |", + "| object_foo_array | false |", + "| object_foo_obj | false |", + "| object_foo_null | true |", + "| object_bar | true |", + "| list_foo | true |", + "| invalid_json | true |", + "+------------------+---------------------------------+", ]; assert_batches_eq!(expected, &batches); } @@ -1158,17 +1175,17 @@ async fn test_arrow_union_is_null_dict_encoded() { .unwrap(); let expected = [ - 
"+------------------+---------------------------------------+", - "| name | test.json_data -> Utf8(\"foo\") IS NULL |", - "+------------------+---------------------------------------+", - "| object_foo | false |", - "| object_foo_array | false |", - "| object_foo_obj | false |", - "| object_foo_null | true |", - "| object_bar | true |", - "| list_foo | true |", - "| invalid_json | true |", - "+------------------+---------------------------------------+", + "+------------------+---------------------------------+", + "| name | test.json_data -> 'foo' IS NULL |", + "+------------------+---------------------------------+", + "| object_foo | false |", + "| object_foo_array | false |", + "| object_foo_obj | false |", + "| object_foo_null | true |", + "| object_bar | true |", + "| list_foo | true |", + "| invalid_json | true |", + "+------------------+---------------------------------+", ]; assert_batches_eq!(expected, &batches); } @@ -1180,17 +1197,17 @@ async fn test_arrow_union_is_not_null() { .unwrap(); let expected = [ - "+------------------+-------------------------------------------+", - "| name | test.json_data -> Utf8(\"foo\") IS NOT NULL |", - "+------------------+-------------------------------------------+", - "| object_foo | true |", - "| object_foo_array | true |", - "| object_foo_obj | true |", - "| object_foo_null | false |", - "| object_bar | false |", - "| list_foo | false |", - "| invalid_json | false |", - "+------------------+-------------------------------------------+", + "+------------------+-------------------------------------+", + "| name | test.json_data -> 'foo' IS NOT NULL |", + "+------------------+-------------------------------------+", + "| object_foo | true |", + "| object_foo_array | true |", + "| object_foo_obj | true |", + "| object_foo_null | false |", + "| object_bar | false |", + "| list_foo | false |", + "| invalid_json | false |", + "+------------------+-------------------------------------+", ]; assert_batches_eq!(expected, 
&batches); } @@ -1202,17 +1219,17 @@ async fn test_arrow_union_is_not_null_dict_encoded() { .unwrap(); let expected = [ - "+------------------+-------------------------------------------+", - "| name | test.json_data -> Utf8(\"foo\") IS NOT NULL |", - "+------------------+-------------------------------------------+", - "| object_foo | true |", - "| object_foo_array | true |", - "| object_foo_obj | true |", - "| object_foo_null | false |", - "| object_bar | false |", - "| list_foo | false |", - "| invalid_json | false |", - "+------------------+-------------------------------------------+", + "+------------------+-------------------------------------+", + "| name | test.json_data -> 'foo' IS NOT NULL |", + "+------------------+-------------------------------------+", + "| object_foo | true |", + "| object_foo_array | true |", + "| object_foo_obj | true |", + "| object_foo_null | false |", + "| object_bar | false |", + "| list_foo | false |", + "| invalid_json | false |", + "+------------------+-------------------------------------+", ]; assert_batches_eq!(expected, &batches); } @@ -1243,14 +1260,14 @@ async fn test_long_arrow_cast() { let batches = run_query("select (json_data->>'foo')::int from other").await.unwrap(); let expected = [ - "+---------------------------------+", - "| other.json_data ->> Utf8(\"foo\") |", - "+---------------------------------+", - "| 42 |", - "| 42 |", - "| |", - "| |", - "+---------------------------------+", + "+---------------------------+", + "| other.json_data ->> 'foo' |", + "+---------------------------+", + "| 42 |", + "| 42 |", + "| |", + "| |", + "+---------------------------+", ]; assert_batches_eq!(expected, &batches); }