diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml
index e7651ec5fa..e3b0e40566 100644
--- a/.github/workflows/pr_build_linux.yml
+++ b/.github/workflows/pr_build_linux.yml
@@ -161,6 +161,7 @@ jobs:
             org.apache.comet.CometStringExpressionSuite
             org.apache.comet.CometBitwiseExpressionSuite
             org.apache.comet.CometMapExpressionSuite
+            org.apache.comet.CometJsonExpressionSuite
             org.apache.comet.expressions.conditional.CometIfSuite
             org.apache.comet.expressions.conditional.CometCoalesceSuite
             org.apache.comet.expressions.conditional.CometCaseWhenSuite
diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml
index 8fd0aab78d..02d31c4f57 100644
--- a/.github/workflows/pr_build_macos.yml
+++ b/.github/workflows/pr_build_macos.yml
@@ -126,6 +126,7 @@ jobs:
             org.apache.comet.CometStringExpressionSuite
             org.apache.comet.CometBitwiseExpressionSuite
             org.apache.comet.CometMapExpressionSuite
+            org.apache.comet.CometJsonExpressionSuite
             org.apache.comet.expressions.conditional.CometIfSuite
             org.apache.comet.expressions.conditional.CometCoalesceSuite
             org.apache.comet.expressions.conditional.CometCaseWhenSuite
diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md
index db7d2ce32b..13a9c752e3 100644
--- a/docs/source/user-guide/latest/configs.md
+++ b/docs/source/user-guide/latest/configs.md
@@ -264,6 +264,7 @@ These settings can be used to determine which parts of the plan are accelerated
 | `spark.comet.expression.IsNaN.enabled` | Enable Comet acceleration for `IsNaN` | true |
 | `spark.comet.expression.IsNotNull.enabled` | Enable Comet acceleration for `IsNotNull` | true |
 | `spark.comet.expression.IsNull.enabled` | Enable Comet acceleration for `IsNull` | true |
+| `spark.comet.expression.JsonToStructs.enabled` | Enable Comet acceleration for `JsonToStructs` | true |
 | `spark.comet.expression.KnownFloatingPointNormalized.enabled` | Enable Comet acceleration for `KnownFloatingPointNormalized` | true |
 | `spark.comet.expression.Length.enabled` | Enable Comet acceleration for `Length` | true |
 | `spark.comet.expression.LessThan.enabled` | Enable Comet acceleration for `LessThan` | true |
diff --git a/native/Cargo.lock b/native/Cargo.lock
index acdd279760..c28be6c54f 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -1882,6 +1882,7 @@ dependencies = [
  "num",
  "rand 0.9.2",
  "regex",
+ "serde_json",
  "thiserror 2.0.17",
  "tokio",
  "twox-hash",
diff --git a/native/core/src/execution/expressions/strings.rs b/native/core/src/execution/expressions/strings.rs
index 5f4300eb1e..7219395963 100644
--- a/native/core/src/execution/expressions/strings.rs
+++ b/native/core/src/execution/expressions/strings.rs
@@ -25,12 +25,13 @@ use datafusion::common::ScalarValue;
 use datafusion::physical_expr::expressions::{LikeExpr, Literal};
 use datafusion::physical_expr::PhysicalExpr;
 use datafusion_comet_proto::spark_expression::Expr;
-use datafusion_comet_spark_expr::{RLike, SubstringExpr};
+use datafusion_comet_spark_expr::{FromJson, RLike, SubstringExpr};
 
 use crate::execution::{
     expressions::extract_expr,
     operators::ExecutionError,
     planner::{expression_registry::ExpressionBuilder, PhysicalPlanner},
+    serde::to_arrow_datatype,
 };
 
 /// Builder for Substring expressions
@@ -98,3 +99,28 @@
         }
     }
 }
+
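+/// Builder for FromJson expressions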
+pub struct FromJsonBuilder;
+
+impl ExpressionBuilder for FromJsonBuilder {
+    fn build(
+        &self,
+        spark_expr: &Expr,
+        input_schema: SchemaRef,
+        planner: &PhysicalPlanner,
+    ) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
+        let expr = extract_expr!(spark_expr, FromJson);
+        let child = planner.create_expr(
+            expr.child.as_ref().ok_or_else(|| {
+                ExecutionError::GeneralError("FromJson missing child".to_string())
+            })?,
+            input_schema,
+        )?;
+        let schema =
+            to_arrow_datatype(expr.schema.as_ref().ok_or_else(|| {
+                ExecutionError::GeneralError("FromJson missing schema".to_string())
+            })?);
+        Ok(Arc::new(FromJson::new(child, schema, &expr.timezone)))
+    }
+}
diff --git a/native/core/src/execution/planner/expression_registry.rs b/native/core/src/execution/planner/expression_registry.rs
index 3321f61182..e85fbe5104 100644
--- a/native/core/src/execution/planner/expression_registry.rs
+++ b/native/core/src/execution/planner/expression_registry.rs
@@ -94,6 +94,7 @@ pub enum ExpressionType {
     CreateNamedStruct,
     GetStructField,
     ToJson,
+    FromJson,
     ToPrettyString,
     ListExtract,
     GetArrayStructFields,
@@ -281,6 +282,8 @@ impl ExpressionRegistry {
             .insert(ExpressionType::Like, Box::new(LikeBuilder));
         self.builders
             .insert(ExpressionType::Rlike, Box::new(RlikeBuilder));
+        self.builders
+            .insert(ExpressionType::FromJson, Box::new(FromJsonBuilder));
     }
 
     /// Extract expression type from Spark protobuf expression
@@ -336,6 +339,7 @@ impl ExpressionRegistry {
             Some(ExprStruct::CreateNamedStruct(_)) => Ok(ExpressionType::CreateNamedStruct),
             Some(ExprStruct::GetStructField(_)) => Ok(ExpressionType::GetStructField),
             Some(ExprStruct::ToJson(_)) => Ok(ExpressionType::ToJson),
+            Some(ExprStruct::FromJson(_)) => Ok(ExpressionType::FromJson),
             Some(ExprStruct::ToPrettyString(_)) => Ok(ExpressionType::ToPrettyString),
             Some(ExprStruct::ListExtract(_)) => Ok(ExpressionType::ListExtract),
             Some(ExprStruct::GetArrayStructFields(_)) => Ok(ExpressionType::GetArrayStructFields),
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index a7736f561a..1c453b6336 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -85,6 +85,7 @@ message Expr {
     Rand randn = 62;
     EmptyExpr spark_partition_id = 63;
     EmptyExpr monotonically_increasing_id = 64;
+    FromJson from_json = 89;
   }
 }
 
@@ -268,6 +269,12 @@ message ToJson {
   bool ignore_null_fields = 6;
 }
 
+message FromJson {
+  Expr child = 1;
+  DataType schema = 2;
+  string timezone = 3;
+}
+
 enum BinaryOutputStyle {
   UTF8 = 0;
   BASIC = 1;
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index b3a46fd917..c973a5b37b 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -33,6 +33,7 @@ datafusion = { workspace = true }
 chrono-tz = { workspace = true }
 num = { workspace = true }
 regex = { workspace = true }
+serde_json = "1.0"
 thiserror = { workspace = true }
 futures = { workspace = true }
 twox-hash = "2.1.2"
diff --git a/native/spark-expr/src/json_funcs/from_json.rs b/native/spark-expr/src/json_funcs/from_json.rs
new file mode 100644
index 0000000000..ebcc84b8ff
--- /dev/null
+++ b/native/spark-expr/src/json_funcs/from_json.rs
@@ -0,0 +1,647 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{
+    Array, ArrayRef, BooleanBuilder, Float32Builder, Float64Builder, Int32Builder, Int64Builder,
+    RecordBatch, StringBuilder, StructArray,
+};
+use arrow::datatypes::{DataType, Field, Schema};
+use datafusion::common::Result;
+use datafusion::physical_expr::PhysicalExpr;
+use datafusion::physical_plan::ColumnarValue;
+use std::any::Any;
+use std::fmt::{Debug, Display, Formatter};
+use std::sync::Arc;
+
+/// from_json function - parses JSON strings into structured types
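+///
+/// Only PERMISSIVE parse mode is supported. As an illustrative sketch (the
+/// schema `STRUCT<a INT, b STRING>` here is just an example):
+///
+/// * `{"a":1,"b":"x"}` parses to `{a: 1, b: "x"}`
+/// * `{"a":1}` parses to `{a: 1, b: null}` (missing field becomes null)
+/// * `oops` parses to `{a: null, b: null}` (parse error: non-null struct, null fields)
+/// * a NULL input row yields a NULL struct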
+#[derive(Debug, Eq)]
+pub struct FromJson {
+    /// The JSON string input expression
+    expr: Arc<dyn PhysicalExpr>,
+    /// Target schema for parsing
+    schema: DataType,
+    /// Timezone for timestamp parsing (future use)
+    timezone: String,
+}
+
+impl PartialEq for FromJson {
+    fn eq(&self, other: &Self) -> bool {
+        self.expr.eq(&other.expr) && self.schema == other.schema && self.timezone == other.timezone
+    }
+}
+
+impl std::hash::Hash for FromJson {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.expr.hash(state);
+        // Note: DataType doesn't implement Hash, so we hash its debug representation
+        format!("{:?}", self.schema).hash(state);
+        self.timezone.hash(state);
+    }
+}
+
+impl FromJson {
+    pub fn new(expr: Arc<dyn PhysicalExpr>, schema: DataType, timezone: &str) -> Self {
+        Self {
+            expr,
+            schema,
+            timezone: timezone.to_owned(),
+        }
+    }
+}
+
+impl Display for FromJson {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "from_json({}, schema={:?}, timezone={})",
+            self.expr, self.schema, self.timezone
+        )
+    }
+}
+
+impl PartialEq<dyn Any> for FromJson {
+    fn eq(&self, other: &dyn Any) -> bool {
+        if let Some(other) = other.downcast_ref::<Self>() {
+            self.expr.eq(&other.expr)
+                && self.schema == other.schema
+                && self.timezone == other.timezone
+        } else {
+            false
+        }
+    }
+}
+
+impl PhysicalExpr for FromJson {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
+        unimplemented!()
+    }
+
+    fn data_type(&self, _: &Schema) -> Result<DataType> {
+        Ok(self.schema.clone())
+    }
+
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        // Always nullable - parse errors return null in PERMISSIVE mode
+        Ok(true)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
+        let input = self.expr.evaluate(batch)?.into_array(batch.num_rows())?;
+        Ok(ColumnarValue::Array(json_string_to_struct(
+            &input,
+            &self.schema,
+        )?))
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        vec![&self.expr]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        assert!(children.len() == 1);
+        Ok(Arc::new(Self::new(
+            Arc::clone(&children[0]),
+            self.schema.clone(),
+            &self.timezone,
+        )))
+    }
+}
+
+/// Parse JSON string array into struct array
+fn json_string_to_struct(arr: &Arc<dyn Array>, schema: &DataType) -> Result<ArrayRef> {
+    use arrow::array::StringArray;
+    use arrow::buffer::NullBuffer;
+
+    let string_array = arr.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
+        datafusion::common::DataFusionError::Execution("from_json expects string input".to_string())
+    })?;
+
+    let DataType::Struct(fields) = schema else {
+        return Err(datafusion::common::DataFusionError::Execution(
+            "from_json requires struct schema".to_string(),
+        ));
+    };
+
+    let num_rows = string_array.len();
+    let mut field_builders = create_field_builders(fields, num_rows)?;
+    let mut struct_nulls = vec![true; num_rows];
+    for (row_idx, struct_null) in struct_nulls.iter_mut().enumerate() {
+        if string_array.is_null(row_idx) {
+            // Null input -> null struct
+            *struct_null = false;
+            append_null_to_all_builders(&mut field_builders);
+        } else {
+            let json_str = string_array.value(row_idx);
+
+            // Parse JSON (PERMISSIVE mode: return null fields on error)
+            match serde_json::from_str::<serde_json::Value>(json_str) {
+                Ok(json_value) => {
+                    if let serde_json::Value::Object(obj) = json_value {
+                        // Struct is not null, extract each field
+                        *struct_null = true;
+                        for (field, builder) in fields.iter().zip(field_builders.iter_mut()) {
+                            let field_value = obj.get(field.name());
+                            append_field_value(builder, field, field_value)?;
+                        }
+                    } else {
+                        // Not an object -> struct with null fields
+                        *struct_null = true;
+                        append_null_to_all_builders(&mut field_builders);
+                    }
+                }
+                Err(_) => {
+                    // Parse error -> struct with null fields (PERMISSIVE mode)
+                    *struct_null = true;
+                    append_null_to_all_builders(&mut field_builders);
+                }
+            }
+        }
+    }
+
+    let arrays: Vec<ArrayRef> = field_builders
+        .into_iter()
+        .map(finish_builder)
+        .collect::<Result<Vec<_>>>()?;
+    let null_buffer = NullBuffer::from(struct_nulls);
+    Ok(Arc::new(StructArray::new(
+        fields.clone(),
+        arrays,
+        Some(null_buffer),
+    )))
+}
+
+/// Builder enum for different data types
+enum FieldBuilder {
+    Int32(Int32Builder),
+    Int64(Int64Builder),
+    Float32(Float32Builder),
+    Float64(Float64Builder),
+    Boolean(BooleanBuilder),
+    String(StringBuilder),
+    Struct {
+        fields: arrow::datatypes::Fields,
+        builders: Vec<FieldBuilder>,
+        null_buffer: Vec<bool>,
+    },
+}
+
+fn create_field_builders(
+    fields: &arrow::datatypes::Fields,
+    capacity: usize,
+) -> Result<Vec<FieldBuilder>> {
+    fields
+        .iter()
+        .map(|field| match field.data_type() {
+            DataType::Int32 => Ok(FieldBuilder::Int32(Int32Builder::with_capacity(capacity))),
+            DataType::Int64 => Ok(FieldBuilder::Int64(Int64Builder::with_capacity(capacity))),
+            DataType::Float32 => Ok(FieldBuilder::Float32(Float32Builder::with_capacity(
+                capacity,
+            ))),
+            DataType::Float64 => Ok(FieldBuilder::Float64(Float64Builder::with_capacity(
+                capacity,
+            ))),
+            DataType::Boolean => Ok(FieldBuilder::Boolean(BooleanBuilder::with_capacity(
+                capacity,
+            ))),
+            DataType::Utf8 => Ok(FieldBuilder::String(StringBuilder::with_capacity(
+                capacity,
+                capacity * 16,
+            ))),
+            DataType::Struct(nested_fields) => {
+                let nested_builders = create_field_builders(nested_fields, capacity)?;
+                Ok(FieldBuilder::Struct {
+                    fields: nested_fields.clone(),
+                    builders: nested_builders,
+                    null_buffer: Vec::with_capacity(capacity),
+                })
+            }
+            dt => Err(datafusion::common::DataFusionError::Execution(format!(
+                "Unsupported field type in from_json: {:?}",
+                dt
+            ))),
+        })
+        .collect()
+}
+
+fn append_null_to_all_builders(builders: &mut [FieldBuilder]) {
+    for builder in builders {
+        match builder {
+            FieldBuilder::Int32(b) => b.append_null(),
+            FieldBuilder::Int64(b) => b.append_null(),
+            FieldBuilder::Float32(b) => b.append_null(),
+            FieldBuilder::Float64(b) => b.append_null(),
+            FieldBuilder::Boolean(b) => b.append_null(),
+            FieldBuilder::String(b) => b.append_null(),
+            FieldBuilder::Struct {
+                builders: nested_builders,
+                null_buffer,
+                ..
+            } => {
+                // Append null to nested struct
+                null_buffer.push(false);
+                append_null_to_all_builders(nested_builders);
+            }
+        }
+    }
+}
+
+fn append_field_value(
+    builder: &mut FieldBuilder,
+    field: &Field,
+    json_value: Option<&serde_json::Value>,
+) -> Result<()> {
+    use serde_json::Value;
+
+    let value = match json_value {
+        Some(Value::Null) | None => {
+            // Missing field or explicit null -> append null
+            match builder {
+                FieldBuilder::Int32(b) => b.append_null(),
+                FieldBuilder::Int64(b) => b.append_null(),
+                FieldBuilder::Float32(b) => b.append_null(),
+                FieldBuilder::Float64(b) => b.append_null(),
+                FieldBuilder::Boolean(b) => b.append_null(),
+                FieldBuilder::String(b) => b.append_null(),
+                FieldBuilder::Struct {
+                    builders: nested_builders,
+                    null_buffer,
+                    ..
+                } => {
+                    null_buffer.push(false);
+                    append_null_to_all_builders(nested_builders);
+                }
+            }
+            return Ok(());
+        }
+        Some(v) => v,
+    };
+
+    match (builder, field.data_type()) {
+        (FieldBuilder::Int32(b), DataType::Int32) => {
+            if let Some(i) = value.as_i64() {
+                if i >= i32::MIN as i64 && i <= i32::MAX as i64 {
+                    b.append_value(i as i32);
+                } else {
+                    b.append_null(); // Overflow
+                }
+            } else {
+                b.append_null(); // Type mismatch
+            }
+        }
+        (FieldBuilder::Int64(b), DataType::Int64) => {
+            if let Some(i) = value.as_i64() {
+                b.append_value(i);
+            } else {
+                b.append_null();
+            }
+        }
+        (FieldBuilder::Float32(b), DataType::Float32) => {
+            if let Some(f) = value.as_f64() {
+                b.append_value(f as f32);
+            } else {
+                b.append_null();
+            }
+        }
+        (FieldBuilder::Float64(b), DataType::Float64) => {
+            if let Some(f) = value.as_f64() {
+                b.append_value(f);
+            } else {
+                b.append_null();
+            }
+        }
+        (FieldBuilder::Boolean(b), DataType::Boolean) => {
+            if let Some(bool_val) = value.as_bool() {
+                b.append_value(bool_val);
+            } else {
+                b.append_null();
+            }
+        }
+        (FieldBuilder::String(b), DataType::Utf8) => {
+            if let Some(s) = value.as_str() {
+                b.append_value(s);
+            } else {
+                // Stringify non-string values
+                b.append_value(value.to_string());
+            }
+        }
+        (
+            FieldBuilder::Struct {
+                fields: nested_fields,
+                builders: nested_builders,
+                null_buffer,
+            },
+            DataType::Struct(_),
+        ) => {
+            // Handle nested struct
+            if let Some(obj) = value.as_object() {
+                // Non-null nested struct
+                null_buffer.push(true);
+                for (nested_field, nested_builder) in
+                    nested_fields.iter().zip(nested_builders.iter_mut())
+                {
+                    let nested_value = obj.get(nested_field.name());
+                    append_field_value(nested_builder, nested_field, nested_value)?;
+                }
+            } else {
+                // Not an object -> null nested struct
+                null_buffer.push(false);
+                append_null_to_all_builders(nested_builders);
+            }
+        }
+        _ => {
+            return Err(datafusion::common::DataFusionError::Execution(
+                "Type mismatch in from_json".to_string(),
+            ));
+        }
+    }
+
+    Ok(())
+}
+
+fn finish_builder(builder: FieldBuilder) -> Result<ArrayRef> {
+    Ok(match builder {
+        FieldBuilder::Int32(mut b) => Arc::new(b.finish()),
+        FieldBuilder::Int64(mut b) => Arc::new(b.finish()),
+        FieldBuilder::Float32(mut b) => Arc::new(b.finish()),
+        FieldBuilder::Float64(mut b) => Arc::new(b.finish()),
+        FieldBuilder::Boolean(mut b) => Arc::new(b.finish()),
+        FieldBuilder::String(mut b) => Arc::new(b.finish()),
+        FieldBuilder::Struct {
+            fields,
+            builders,
+            null_buffer,
+        } => {
+            let nested_arrays: Vec<ArrayRef> = builders
+                .into_iter()
+                .map(finish_builder)
+                .collect::<Result<Vec<_>>>()?;
+            let null_buf = arrow::buffer::NullBuffer::from(null_buffer);
+            Arc::new(StructArray::new(fields, nested_arrays, Some(null_buf)))
+        }
+    })
+}
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::{Int32Array, StringArray};
+    use arrow::datatypes::Fields;
+
+    #[test]
+    fn test_simple_struct() -> Result<()> {
+        let schema = DataType::Struct(Fields::from(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Utf8, true),
+        ]));
+
+        let input: Arc<dyn Array> = Arc::new(StringArray::from(vec![
+            Some(r#"{"a": 123, "b": "hello"}"#),
+            Some(r#"{"a": 456}"#),
+            Some(r#"invalid json"#),
+            None,
+        ]));
+
+        let result = json_string_to_struct(&input, &schema)?;
+        let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap();
+
+        assert_eq!(struct_array.len(), 4);
+
+        // First row
+        let a_array = struct_array
+            .column(0)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+        assert_eq!(a_array.value(0), 123);
+        let b_array = struct_array
+            .column(1)
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .unwrap();
+        assert_eq!(b_array.value(0), "hello");
+
+        // Second row (missing field b)
+        assert_eq!(a_array.value(1), 456);
+        assert!(b_array.is_null(1));
+
+        // Third row (parse error -> struct NOT null, all fields null)
+        assert!(!struct_array.is_null(2), "Struct should not be null");
+        assert!(a_array.is_null(2));
+        assert!(b_array.is_null(2));
+
+        // Fourth row (null input -> struct IS null)
+        assert!(struct_array.is_null(3), "Struct itself should be null");
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_all_primitive_types() -> Result<()> {
+        let schema = DataType::Struct(Fields::from(vec![
+            Field::new("i32", DataType::Int32, true),
+            Field::new("i64", DataType::Int64, true),
+            Field::new("f32", DataType::Float32, true),
+            Field::new("f64", DataType::Float64, true),
+            Field::new("bool", DataType::Boolean, true),
+            Field::new("str", DataType::Utf8, true),
+        ]));
+
+        let input: Arc<dyn Array> = Arc::new(StringArray::from(vec![Some(
+            r#"{"i32":123,"i64":9999999999,"f32":1.5,"f64":2.5,"bool":true,"str":"test"}"#,
+        )]));
+
+        let result = json_string_to_struct(&input, &schema)?;
+        let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap();
+
+        assert_eq!(struct_array.len(), 1);
+
+        // Verify all types
+        let i32_array = struct_array
+            .column(0)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+        assert_eq!(i32_array.value(0), 123);
+
+        let i64_array = struct_array
+            .column(1)
+            .as_any()
+            .downcast_ref::<arrow::array::Int64Array>()
+            .unwrap();
+        assert_eq!(i64_array.value(0), 9999999999);
+
+        let f32_array = struct_array
+            .column(2)
+            .as_any()
+            .downcast_ref::<arrow::array::Float32Array>()
+            .unwrap();
+        assert_eq!(f32_array.value(0), 1.5);
+
+        let f64_array = struct_array
+            .column(3)
+            .as_any()
+            .downcast_ref::<arrow::array::Float64Array>()
+            .unwrap();
+        assert_eq!(f64_array.value(0), 2.5);
+
+        let bool_array = struct_array
+            .column(4)
+            .as_any()
+            .downcast_ref::<arrow::array::BooleanArray>()
+            .unwrap();
+        assert!(bool_array.value(0));
+
+        let str_array = struct_array
+            .column(5)
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .unwrap();
+        assert_eq!(str_array.value(0), "test");
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_empty_and_null_json() -> Result<()> {
+        let schema = DataType::Struct(Fields::from(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Utf8, true),
+        ]));
+
+        let input: Arc<dyn Array> = Arc::new(StringArray::from(vec![
+            Some(r#"{}"#),   // Empty object
+            Some(r#"null"#), // JSON null
+            Some(r#"[]"#),   // Array (not object)
+            Some(r#"123"#),  // Number (not object)
+        ]));
+
+        let result = json_string_to_struct(&input, &schema)?;
+        let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap();
+
+        assert_eq!(struct_array.len(), 4);
+
+        let a_array = struct_array
+            .column(0)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+        let b_array = struct_array
+            .column(1)
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .unwrap();
+
+        // All rows should have non-null structs with null field values
+        for i in 0..4 {
+            assert!(
+                !struct_array.is_null(i),
+                "Row {} struct should not be null",
+                i
+            );
+            assert!(a_array.is_null(i), "Row {} field a should be null", i);
+            assert!(b_array.is_null(i), "Row {} field b should be null", i);
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_nested_struct() -> Result<()> {
+        let schema = DataType::Struct(Fields::from(vec![
+            Field::new(
+                "outer",
+                DataType::Struct(Fields::from(vec![
+                    Field::new("inner_a", DataType::Int32, true),
+                    Field::new("inner_b", DataType::Utf8, true),
+                ])),
+                true,
+            ),
+            Field::new("top_level", DataType::Int32, true),
+        ]));
+
+        let input: Arc<dyn Array> = Arc::new(StringArray::from(vec![
+            Some(r#"{"outer":{"inner_a":123,"inner_b":"hello"},"top_level":999}"#),
+            Some(r#"{"outer":{"inner_a":456},"top_level":888}"#), // Missing nested field
+            Some(r#"{"outer":null,"top_level":777}"#),            // Null nested struct
+            Some(r#"{"top_level":666}"#),                         // Missing nested struct
+        ]));
+
+        let result = json_string_to_struct(&input, &schema)?;
+        let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap();
+
+        assert_eq!(struct_array.len(), 4);
+
+        // Check outer struct
+        let outer_array = struct_array
+            .column(0)
+            .as_any()
+            .downcast_ref::<StructArray>()
+            .unwrap();
+        let top_level_array = struct_array
+            .column(1)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+
+        // Row 0: Valid nested struct
+        assert!(!outer_array.is_null(0), "Nested struct should not be null");
+        let inner_a_array = outer_array
+            .column(0)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+        let inner_b_array = outer_array
+            .column(1)
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .unwrap();
+        assert_eq!(inner_a_array.value(0), 123);
+        assert_eq!(inner_b_array.value(0), "hello");
+        assert_eq!(top_level_array.value(0), 999);
+
+        // Row 1: Missing nested field
+        assert!(!outer_array.is_null(1));
+        assert_eq!(inner_a_array.value(1), 456);
+        assert!(inner_b_array.is_null(1));
+        assert_eq!(top_level_array.value(1), 888);
+
+        // Row 2: Null nested struct
+        assert!(outer_array.is_null(2), "Nested struct should be null");
+        assert_eq!(top_level_array.value(2), 777);
+
+        // Row 3: Missing nested struct
+        assert!(outer_array.is_null(3), "Nested struct should be null");
+        assert_eq!(top_level_array.value(3), 666);
+
+        Ok(())
+    }
+}
diff --git a/native/spark-expr/src/json_funcs/mod.rs b/native/spark-expr/src/json_funcs/mod.rs
index de3037590d..9f025070d7 100644
--- a/native/spark-expr/src/json_funcs/mod.rs
+++ b/native/spark-expr/src/json_funcs/mod.rs
@@ -15,6 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
+mod from_json; mod to_json; +pub use from_json::FromJson; pub use to_json::ToJson; diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 2903061d60..96e727ae55 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -71,7 +71,7 @@ pub use comet_scalar_funcs::{ pub use datetime_funcs::{SparkDateTrunc, SparkHour, SparkMinute, SparkSecond, TimestampTruncExpr}; pub use error::{SparkError, SparkResult}; pub use hash_funcs::*; -pub use json_funcs::ToJson; +pub use json_funcs::{FromJson, ToJson}; pub use math_funcs::{ create_modulo_expr, create_negate_expr, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_make_decimal, spark_round, spark_unhex, diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 54df2f1688..83917d33fc 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -131,6 +131,7 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[CreateNamedStruct] -> CometCreateNamedStruct, classOf[GetArrayStructFields] -> CometGetArrayStructFields, classOf[GetStructField] -> CometGetStructField, + classOf[JsonToStructs] -> CometJsonToStructs, classOf[StructsToJson] -> CometStructsToJson) private val hashExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( diff --git a/spark/src/main/scala/org/apache/comet/serde/structs.scala b/spark/src/main/scala/org/apache/comet/serde/structs.scala index 208b2e1262..55e031d346 100644 --- a/spark/src/main/scala/org/apache/comet/serde/structs.scala +++ b/spark/src/main/scala/org/apache/comet/serde/structs.scala @@ -21,11 +21,11 @@ package org.apache.comet.serde import scala.jdk.CollectionConverters._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, GetArrayStructFields, GetStructField, StructsToJson} +import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, GetArrayStructFields, GetStructField, JsonToStructs, StructsToJson} import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, MapType, StructType} import org.apache.comet.CometSparkSessionExtensions.withInfo -import org.apache.comet.serde.QueryPlanSerde.exprToProtoInternal +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, serializeDataType} object CometCreateNamedStruct extends CometExpressionSerde[CreateNamedStruct] { override def convert( @@ -167,3 +167,67 @@ object CometStructsToJson extends CometExpressionSerde[StructsToJson] { } } } + +object CometJsonToStructs extends CometExpressionSerde[JsonToStructs] { + + override def getSupportLevel(expr: JsonToStructs): SupportLevel = { + // this feature is partially implemented and not comprehensively tested yet + Incompatible() + } + + override def convert( + expr: JsonToStructs, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + + if (expr.schema == null) { + withInfo(expr, "from_json requires explicit schema") + return None + } + + def isSupportedType(dt: DataType): Boolean = { + dt match { + case StructType(fields) => + fields.nonEmpty && fields.forall(f => isSupportedType(f.dataType)) + case DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType | + DataTypes.DoubleType | DataTypes.BooleanType | DataTypes.StringType => + true + case _ => false + } + } + + val schemaType = expr.schema + if (!isSupportedType(schemaType)) { + withInfo(expr, 
"from_json: Unsupported schema type") + return None + } + + val options = expr.options + if (options.nonEmpty) { + val mode = options.getOrElse("mode", "PERMISSIVE") + if (mode != "PERMISSIVE") { + withInfo(expr, s"from_json: Only PERMISSIVE mode supported, got: $mode") + return None + } + val knownOptions = Set("mode") + val unknownOpts = options.keySet -- knownOptions + if (unknownOpts.nonEmpty) { + withInfo(expr, s"from_json: Ignoring unsupported options: ${unknownOpts.mkString(", ")}") + } + } + + // Convert child expression and schema to protobuf + for { + childProto <- exprToProtoInternal(expr.child, inputs, binding) + schemaProto <- serializeDataType(schemaType) + } yield { + val fromJson = ExprOuterClass.FromJson + .newBuilder() + .setChild(childProto) + .setSchema(schemaProto) + .setTimezone(expr.timeZoneId.getOrElse("UTC")) + .build() + ExprOuterClass.Expr.newBuilder().setFromJson(fromJson).build() + } + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometJsonExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometJsonExpressionSuite.scala new file mode 100644 index 0000000000..38f5765268 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometJsonExpressionSuite.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.comet
+
+import org.scalactic.source.Position
+import org.scalatest.Tag
+
+import org.apache.spark.sql.CometTestBase
+import org.apache.spark.sql.catalyst.expressions.JsonToStructs
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+
+class CometJsonExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
+
+  override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
+      pos: Position): Unit = {
+    super.test(testName, testTags: _*) {
+      withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[JsonToStructs]) -> "true") {
+        testFun
+      }
+    }
+  }
+
+  test("from_json - basic primitives") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withParquetTable(
+        (0 until 100).map(i => {
+          val json = s"""{"a":$i,"b":"str_$i"}"""
+          (i, json)
+        }),
+        "tbl",
+        withDictionary = dictionaryEnabled) {
+
+        val schema = "a INT, b STRING"
+        checkSparkAnswerAndOperator(s"SELECT from_json(_2, '$schema') FROM tbl")
+        checkSparkAnswerAndOperator(s"SELECT from_json(_2, '$schema').a FROM tbl")
+        checkSparkAnswerAndOperator(s"SELECT from_json(_2, '$schema').b FROM tbl")
+      }
+    }
+  }
+
+  test("from_json - null and error handling") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withParquetTable(
+        Seq(
+          (1, """{"a":123,"b":"test"}"""), // Valid JSON
+          (2, """{"a":456}"""), // Missing field b
+          (3, """{"a":null}"""), // Explicit null
+          (4, """invalid json"""), // Parse error
+          (5, """{}"""), // Empty object
+          (6, """null"""), // JSON null
+          (7, null) // Null input
+        ),
+        "tbl",
+        withDictionary = dictionaryEnabled) {
+
+        val schema = "a INT, b STRING"
+        checkSparkAnswerAndOperator(
+          s"SELECT _1, from_json(_2, '$schema') as parsed FROM tbl ORDER BY _1")
+      }
+    }
+  }
+
+  test("from_json - all primitive types") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withParquetTable(
+        (0 until 50).map(i => {
+          val sign = if (i % 2 == 0) 1 else -1
+          val json =
+            s"""{"i32":${sign * i},"i64":${sign * i * 1000000000L},"f32":${sign * i * 1.5},"f64":${sign * i * 2.5},"bool":${i % 2 == 0},"str":"value_$i"}"""
+          (i, json)
+        }),
+        "tbl",
+        withDictionary = dictionaryEnabled) {
+
+        val schema = "i32 INT, i64 BIGINT, f32 FLOAT, f64 DOUBLE, bool BOOLEAN, str STRING"
+        checkSparkAnswerAndOperator(s"SELECT from_json(_2, '$schema') FROM tbl")
+        checkSparkAnswerAndOperator(s"SELECT from_json(_2, '$schema').i32 FROM tbl")
+        checkSparkAnswerAndOperator(s"SELECT from_json(_2, '$schema').str FROM tbl")
+      }
+    }
+  }
+
+  test("from_json - null input produces null struct") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withParquetTable(
+        Seq(
+          (1, """{"a":1,"b":"x"}"""), // Valid JSON to establish column type
+          (2, null) // Null input
+        ),
+        "tbl",
+        withDictionary = dictionaryEnabled) {
+
+        val schema = "a INT, b STRING"
+        // Verify that null input produces a NULL struct (not a struct with null fields)
+        checkSparkAnswerAndOperator(
+          s"SELECT _1, from_json(_2, '$schema') IS NULL as struct_is_null FROM tbl WHERE _1 = 2")
+        // Field access on null struct should return null
+        checkSparkAnswerAndOperator(
+          s"SELECT _1, from_json(_2, '$schema').a FROM tbl WHERE _1 = 2")
+      }
+    }
+  }
+
+  test("from_json - nested struct") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withParquetTable(
+        (0 until 50).map(i => {
+          val json = s"""{"outer":{"inner_a":$i,"inner_b":"nested_$i"},"top_level":${i * 10}}"""
+          (i, json)
+        }),
+        "tbl",
+        withDictionary = dictionaryEnabled) {
+
+        val schema = "outer STRUCT<inner_a: INT, inner_b: STRING>, top_level INT"
+        checkSparkAnswerAndOperator(s"SELECT from_json(_2, '$schema') FROM tbl")
+        checkSparkAnswerAndOperator(s"SELECT from_json(_2, '$schema').outer FROM tbl")
+        checkSparkAnswerAndOperator(s"SELECT from_json(_2, '$schema').outer.inner_a FROM tbl")
+        checkSparkAnswerAndOperator(s"SELECT from_json(_2, '$schema').top_level FROM tbl")
+      }
+    }
+  }
+
+  test("from_json - valid json with incompatible schema") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withParquetTable(
+        Seq(
+          (1, """{"a":"not_a_number","b":"test"}"""), // String where INT expected
+          (2, """{"a":123,"b":456}"""), // Number where STRING expected
+          (3, """{"a":{"nested":"value"},"b":"test"}"""), // Object where INT expected
+          (4, """{"a":[1,2,3],"b":"test"}"""), // Array where INT expected
+          (5, """{"a":123.456,"b":"test"}"""), // Float where INT expected
+          (6, """{"a":true,"b":"test"}"""), // Boolean where INT expected
+          (7, """{"a":123,"b":null}""") // Null value for STRING field
+        ),
+        "tbl",
+        withDictionary = dictionaryEnabled) {
+
+        val schema = "a INT, b STRING"
+        // When types don't match, Spark typically returns null for that field
+        checkSparkAnswerAndOperator(
+          s"SELECT _1, from_json(_2, '$schema') as parsed FROM tbl ORDER BY _1")
+        checkSparkAnswerAndOperator(s"SELECT _1, from_json(_2, '$schema').a FROM tbl ORDER BY _1")
+        checkSparkAnswerAndOperator(s"SELECT _1, from_json(_2, '$schema').b FROM tbl ORDER BY _1")
+      }
+    }
+  }
+}
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometJsonExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometJsonExpressionBenchmark.scala
new file mode 100644
index 0000000000..e8bd00bd9c
--- /dev/null
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometJsonExpressionBenchmark.scala
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.benchmark
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.catalyst.expressions.JsonToStructs
+
+import org.apache.comet.CometConf
+
+/**
+ * Configuration for a JSON expression benchmark.
+ * @param name
+ *   Name for the benchmark
+ * @param schema
+ *   Target schema for from_json
+ * @param query
+ *   SQL query to benchmark
+ * @param extraCometConfigs
+ *   Additional Comet configurations for the scan+exec case
+ */
+case class JsonExprConfig(
+    name: String,
+    schema: String,
+    query: String,
+    extraCometConfigs: Map[String, String] = Map.empty)
+
+// spotless:off
+/**
+ * Benchmark to measure performance of Comet JSON expressions. To run this benchmark:
+ * `SPARK_GENERATE_BENCHMARK_FILES=1 make benchmark-org.apache.spark.sql.benchmark.CometJsonExpressionBenchmark`
+ * Results will be written to "spark/benchmarks/CometJsonExpressionBenchmark-**results.txt".
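+ * Note: the Comet (Scan, Exec) case flips the JsonToStructs allow-incompatible flag,
+ * since from_json support is still marked Incompatible on the Scala side.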
+ */
+// spotless:on
+object CometJsonExpressionBenchmark extends CometBenchmarkBase {
+
+  /**
+   * Generic method to run a JSON expression benchmark with the given configuration.
+   */
+  def runJsonExprBenchmark(config: JsonExprConfig, values: Int): Unit = {
+    val benchmark = new Benchmark(config.name, values, output = output)
+
+    withTempPath { dir =>
+      withTempTable("parquetV1Table") {
+        // Generate data with specified JSON patterns
+        val jsonData = config.name match {
+          case "from_json - simple primitives" =>
+            spark.sql(s"""
+              SELECT
+                concat('{"a":', CAST(value AS STRING), ',"b":"str_', CAST(value AS STRING), '"}') AS json_str
+              FROM $tbl
+            """)
+
+          case "from_json - all primitive types" =>
+            spark.sql(s"""
+              SELECT
+                concat(
+                  '{"i32":', CAST(value % 1000 AS STRING),
+                  ',"i64":', CAST(value * 1000000000L AS STRING),
+                  ',"f32":', CAST(value * 1.5 AS STRING),
+                  ',"f64":', CAST(value * 2.5 AS STRING),
+                  ',"bool":', CASE WHEN value % 2 = 0 THEN 'true' ELSE 'false' END,
+                  ',"str":"value_', CAST(value AS STRING), '"}'
+                ) AS json_str
+              FROM $tbl
+            """)
+
+          case "from_json - with nulls" =>
+            spark.sql(s"""
+              SELECT
+                CASE
+                  WHEN value % 10 = 0 THEN NULL
+                  WHEN value % 5 = 0 THEN '{"a":null,"b":"test"}'
+                  WHEN value % 3 = 0 THEN '{"a":123}'
+                  ELSE concat('{"a":', CAST(value AS STRING), ',"b":"str_', CAST(value AS STRING), '"}')
+                END AS json_str
+              FROM $tbl
+            """)
+
+          case "from_json - nested struct" =>
+            spark.sql(s"""
+              SELECT
+                concat(
+                  '{"outer":{"inner_a":', CAST(value AS STRING),
+                  ',"inner_b":"nested_', CAST(value AS STRING), '"}}') AS json_str
+              FROM $tbl
+            """)
+
+          case "from_json - field access" =>
+            spark.sql(s"""
+              SELECT
+                concat('{"a":', CAST(value AS STRING), ',"b":"str_', CAST(value AS STRING), '"}') AS json_str
+              FROM $tbl
+            """)
+
+          case _ =>
+            spark.sql(s"""
+              SELECT
+                concat('{"a":', CAST(value AS STRING), ',"b":"str_', CAST(value AS STRING), '"}') AS json_str
+              FROM $tbl
+            """)
+        }
+
+        prepareTable(dir, jsonData)
+
+        benchmark.addCase("SQL Parquet - Spark") { _ =>
+          spark.sql(config.query).noop()
+        }
+
+        benchmark.addCase("SQL Parquet - Comet (Scan)") { _ =>
+          withSQLConf(CometConf.COMET_ENABLED.key -> "true") {
+            spark.sql(config.query).noop()
+          }
+        }
+
+        benchmark.addCase("SQL Parquet - Comet (Scan, Exec)") { _ =>
+          val baseConfigs =
+            Map(
+              CometConf.COMET_ENABLED.key -> "true",
+              CometConf.COMET_EXEC_ENABLED.key -> "true",
+              CometConf.getExprAllowIncompatConfigKey(classOf[JsonToStructs]) -> "true",
+              "spark.sql.optimizer.constantFolding.enabled" -> "false")
+          val allConfigs = baseConfigs ++ config.extraCometConfigs
+
+          withSQLConf(allConfigs.toSeq: _*) {
+            spark.sql(config.query).noop()
+          }
+        }
+
+        benchmark.run()
+      }
+    }
+  }
+
+  // Configuration for all JSON expression benchmarks
+  private val jsonExpressions = List(
+    JsonExprConfig(
+      "from_json - simple primitives",
+      "a INT, b STRING",
+      "SELECT from_json(json_str, 'a INT, b STRING') FROM parquetV1Table"),
+    JsonExprConfig(
+      "from_json - all primitive types",
+      "i32 INT, i64 BIGINT, f32 FLOAT, f64 DOUBLE, bool BOOLEAN, str STRING",
+      "SELECT from_json(json_str, 'i32 INT, i64 BIGINT, f32 FLOAT, f64 DOUBLE, bool BOOLEAN, str STRING') FROM parquetV1Table"),
+    JsonExprConfig(
+      "from_json - with nulls",
+      "a INT, b STRING",
+      "SELECT from_json(json_str, 'a INT, b STRING') FROM parquetV1Table"),
+    JsonExprConfig(
+      "from_json - nested struct",
+      "outer STRUCT<inner_a: INT, inner_b: STRING>",
+      "SELECT from_json(json_str, 'outer STRUCT<inner_a: INT, inner_b: STRING>') FROM parquetV1Table"),
+    JsonExprConfig(
+      "from_json - field access",
+      "a INT, b STRING",
+      "SELECT from_json(json_str, 'a INT, b STRING').a FROM parquetV1Table"))
+
+  override def runCometBenchmark(mainArgs: Array[String]): Unit = {
+    val values = 1024 * 1024
+
+    jsonExpressions.foreach { config =>
+      runBenchmarkWithTable(config.name, values) { v =>
+        runJsonExprBenchmark(config, v)
+      }
+    }
+  }
+}