diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index 1f9a4263f2..52b8eb6a30 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -41,6 +41,8 @@ use datafusion::{
 };
 use datafusion_comet_proto::spark_operator::Operator;
 use datafusion_spark::function::bitwise::bit_get::SparkBitGet;
+use datafusion_spark::function::datetime::date_add::SparkDateAdd;
+use datafusion_spark::function::datetime::date_sub::SparkDateSub;
 use datafusion_spark::function::hash::sha2::SparkSha2;
 use datafusion_spark::function::math::expm1::SparkExpm1;
 use datafusion_spark::function::string::char::CharFunc;
@@ -303,6 +305,8 @@ fn prepare_datafusion_session_context(
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha2::default()));
     session_ctx.register_udf(ScalarUDF::new_from_impl(CharFunc::default()));
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitGet::default()));
+    session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateAdd::default()));
+    session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateSub::default()));
 
     // Must be the last one to override existing functions with the same name
     datafusion_comet_spark_expr::register_all_comet_functions(&mut session_ctx)?;
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index 4bf1cd45d1..4b863927e6 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -19,10 +19,10 @@ use crate::hash_funcs::*;
 use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub};
 use crate::math_funcs::modulo_expr::spark_modulo;
 use crate::{
-    spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
-    spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
-    spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, EvalMode,
-    SparkBitwiseCount, SparkBitwiseNot, SparkDateTrunc, SparkStringSpace,
+    spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
+    spark_hex, spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad,
+    spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkBitwiseNot,
+    SparkDateTrunc, SparkStringSpace,
 };
 use arrow::datatypes::DataType;
 use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -166,14 +166,6 @@ pub fn create_comet_physical_fun_with_eval_mode(
             let func = Arc::new(spark_isnan);
             make_comet_scalar_udf!("isnan", func, without data_type)
         }
-        "date_add" => {
-            let func = Arc::new(spark_date_add);
-            make_comet_scalar_udf!("date_add", func, without data_type)
-        }
-        "date_sub" => {
-            let func = Arc::new(spark_date_sub);
-            make_comet_scalar_udf!("date_sub", func, without data_type)
-        }
         "array_repeat" => {
            let func = Arc::new(spark_array_repeat);
            make_comet_scalar_udf!("array_repeat", func, without data_type)
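For context, the two datafusion-spark UDFs registered above can be exercised directly
through a DataFusion SessionContext. A minimal sketch, not part of the patch, assuming
the datafusion, datafusion-spark, and tokio crates, and that the functions resolve under
the SQL names date_add/date_sub:

use datafusion::logical_expr::ScalarUDF;
use datafusion::prelude::SessionContext;
use datafusion_spark::function::datetime::date_add::SparkDateAdd;
use datafusion_spark::function::datetime::date_sub::SparkDateSub;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();
    // Same registration pattern as prepare_datafusion_session_context above.
    ctx.register_udf(ScalarUDF::new_from_impl(SparkDateAdd::default()));
    ctx.register_udf(ScalarUDF::new_from_impl(SparkDateSub::default()));

    // Once registered, the functions are resolvable by name in SQL.
    let df = ctx.sql("SELECT date_add(DATE '2024-01-31', 1)").await?;
    df.show().await?;
    Ok(())
}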
diff --git a/native/spark-expr/src/datetime_funcs/date_arithmetic.rs b/native/spark-expr/src/datetime_funcs/date_arithmetic.rs
deleted file mode 100644
index 4b4db2eb5c..0000000000
--- a/native/spark-expr/src/datetime_funcs/date_arithmetic.rs
+++ /dev/null
@@ -1,101 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use arrow::array::builder::IntervalDayTimeBuilder;
-use arrow::array::types::{Int16Type, Int32Type, Int8Type};
-use arrow::array::{Array, Datum};
-use arrow::array::{ArrayRef, AsArray};
-use arrow::compute::kernels::numeric::{add, sub};
-use arrow::datatypes::DataType;
-use arrow::datatypes::IntervalDayTime;
-use arrow::error::ArrowError;
-use datafusion::common::{DataFusionError, ScalarValue};
-use datafusion::physical_expr_common::datum;
-use datafusion::physical_plan::ColumnarValue;
-use std::sync::Arc;
-
-macro_rules! scalar_date_arithmetic {
-    ($start:expr, $days:expr, $op:expr) => {{
-        let interval = IntervalDayTime::new(*$days as i32, 0);
-        let interval_cv = ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(interval)));
-        datum::apply($start, &interval_cv, $op)
-    }};
-}
-macro_rules! array_date_arithmetic {
-    ($days:expr, $interval_builder:expr, $intType:ty) => {{
-        for day in $days.as_primitive::<$intType>().into_iter() {
-            if let Some(non_null_day) = day {
-                $interval_builder.append_value(IntervalDayTime::new(non_null_day as i32, 0));
-            } else {
-                $interval_builder.append_null();
-            }
-        }
-    }};
-}
-
-/// Spark-compatible `date_add` and `date_sub` expressions, which assumes days for the second
-/// argument, but we cannot directly add that to a Date32. We generate an IntervalDayTime from the
-/// second argument and use DataFusion's interface to apply Arrow's operators.
-fn spark_date_arithmetic(
-    args: &[ColumnarValue],
-    op: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
-) -> Result<ColumnarValue, DataFusionError> {
-    let start = &args[0];
-    match &args[1] {
-        ColumnarValue::Scalar(ScalarValue::Int8(Some(days))) => {
-            scalar_date_arithmetic!(start, days, op)
-        }
-        ColumnarValue::Scalar(ScalarValue::Int16(Some(days))) => {
-            scalar_date_arithmetic!(start, days, op)
-        }
-        ColumnarValue::Scalar(ScalarValue::Int32(Some(days))) => {
-            scalar_date_arithmetic!(start, days, op)
-        }
-        ColumnarValue::Array(days) => {
-            let mut interval_builder = IntervalDayTimeBuilder::with_capacity(days.len());
-            match days.data_type() {
-                DataType::Int8 => {
-                    array_date_arithmetic!(days, interval_builder, Int8Type)
-                }
-                DataType::Int16 => {
-                    array_date_arithmetic!(days, interval_builder, Int16Type)
-                }
-                DataType::Int32 => {
-                    array_date_arithmetic!(days, interval_builder, Int32Type)
-                }
-                _ => {
-                    return Err(DataFusionError::Internal(format!(
-                        "Unsupported data types {args:?} for date arithmetic.",
-                    )))
-                }
-            }
-            let interval_cv = ColumnarValue::Array(Arc::new(interval_builder.finish()));
-            datum::apply(start, &interval_cv, op)
-        }
-        _ => Err(DataFusionError::Internal(format!(
-            "Unsupported data types {args:?} for date arithmetic.",
-        ))),
-    }
-}
-
-pub fn spark_date_add(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
-    spark_date_arithmetic(args, add)
-}
-
-pub fn spark_date_sub(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
-    spark_date_arithmetic(args, sub)
-}
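The doc comment in the deleted file captures the technique: a day count cannot be added
to a Date32 directly, so it is first wrapped in an IntervalDayTime and pushed through
Arrow's generic arithmetic kernels. A standalone sketch of that idea, not the removed
code itself, assuming only the arrow crate:

use arrow::array::{ArrayRef, Date32Array, IntervalDayTimeArray};
use arrow::compute::kernels::numeric::add;
use arrow::datatypes::IntervalDayTime;
use arrow::error::ArrowError;

// Wrap each day count in a zero-millisecond IntervalDayTime so that Arrow's
// generic `add` kernel can evaluate Date32 + interval element-wise.
fn add_days(dates: &Date32Array, days: &[i32]) -> Result<ArrayRef, ArrowError> {
    let intervals: IntervalDayTimeArray = days
        .iter()
        .map(|d| Some(IntervalDayTime::new(*d, 0)))
        .collect();
    add(dates, &intervals)
}

fn main() -> Result<(), ArrowError> {
    let dates = Date32Array::from(vec![19_000, 19_001]); // days since 1970-01-01
    let shifted = add_days(&dates, &[1, 30])?;
    println!("{shifted:?}");
    Ok(())
}

The removed implementation followed the same shape, with extra dispatch for Int8/Int16/Int32
day arguments and for scalar versus array inputs.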
diff --git a/native/spark-expr/src/datetime_funcs/mod.rs b/native/spark-expr/src/datetime_funcs/mod.rs
index 0ca7bb9401..ef8041e5fe 100644
--- a/native/spark-expr/src/datetime_funcs/mod.rs
+++ b/native/spark-expr/src/datetime_funcs/mod.rs
@@ -15,12 +15,10 @@
 // specific language governing permissions and limitations
 // under the License.
 
-mod date_arithmetic;
 mod date_trunc;
 mod extract_date_part;
 mod timestamp_trunc;
 
-pub use date_arithmetic::{spark_date_add, spark_date_sub};
 pub use date_trunc::SparkDateTrunc;
 pub use extract_date_part::SparkHour;
 pub use extract_date_part::SparkMinute;
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index 7bdc7ff515..932fcbe53d 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -68,10 +68,7 @@ pub use comet_scalar_funcs::{
     create_comet_physical_fun, create_comet_physical_fun_with_eval_mode,
     register_all_comet_functions,
 };
-pub use datetime_funcs::{
-    spark_date_add, spark_date_sub, SparkDateTrunc, SparkHour, SparkMinute, SparkSecond,
-    TimestampTruncExpr,
-};
+pub use datetime_funcs::{SparkDateTrunc, SparkHour, SparkMinute, SparkSecond, TimestampTruncExpr};
 pub use error::{SparkError, SparkResult};
 pub use hash_funcs::*;
 pub use json_funcs::ToJson;
diff --git a/spark/src/main/scala/org/apache/comet/serde/datetime.scala b/spark/src/main/scala/org/apache/comet/serde/datetime.scala
index 8e4c92d707..9473ee30e4 100644
--- a/spark/src/main/scala/org/apache/comet/serde/datetime.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/datetime.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.types.{DateType, IntegerType}
 import org.apache.comet.CometSparkSessionExtensions.withInfo
 import org.apache.comet.serde.CometGetDateField.CometGetDateField
 import org.apache.comet.serde.ExprOuterClass.Expr
-import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, serializeDataType}
+import org.apache.comet.serde.QueryPlanSerde._
 
 private object CometGetDateField extends Enumeration {
   type CometGetDateField = Value
@@ -251,31 +251,9 @@ object CometSecond extends CometExpressionSerde[Second] {
   }
 }
 
-object CometDateAdd extends CometExpressionSerde[DateAdd] {
-  override def convert(
-      expr: DateAdd,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[ExprOuterClass.Expr] = {
-    val leftExpr = exprToProtoInternal(expr.left, inputs, binding)
-    val rightExpr = exprToProtoInternal(expr.right, inputs, binding)
-    val optExpr =
-      scalarFunctionExprToProtoWithReturnType("date_add", DateType, leftExpr, rightExpr)
-    optExprWithInfo(optExpr, expr, expr.left, expr.right)
-  }
-}
+object CometDateAdd extends CometScalarFunction[DateAdd]("date_add")
 
-object CometDateSub extends CometExpressionSerde[DateSub] {
-  override def convert(
-      expr: DateSub,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[ExprOuterClass.Expr] = {
-    val leftExpr = exprToProtoInternal(expr.left, inputs, binding)
-    val rightExpr = exprToProtoInternal(expr.right, inputs, binding)
-    val optExpr =
-      scalarFunctionExprToProtoWithReturnType("date_sub", DateType, leftExpr, rightExpr)
-    optExprWithInfo(optExpr, expr, expr.left, expr.right)
-  }
-}
+object CometDateSub extends CometScalarFunction[DateSub]("date_sub")
 
 object CometTruncDate extends CometExpressionSerde[TruncDate] {
   override def convert(
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index daf0e45cc8..07663ea91f 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -252,7 +252,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         } else {
           assert(sparkErr.get.getMessage.contains("integer overflow"))
         }
-        assert(cometErr.get.getMessage.contains("`NaiveDate + TimeDelta` overflowed"))
+        assert(cometErr.get.getMessage.contains("attempt to add with overflow"))
       }
     }
   }
@@ -296,10 +296,11 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
           checkSparkMaybeThrows(sql(s"SELECT _20 - ${Int.MaxValue} FROM tbl"))
        if (isSpark40Plus) {
          assert(sparkErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
+         assert(cometErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
        } else {
          assert(sparkErr.get.getMessage.contains("integer overflow"))
+         assert(cometErr.get.getMessage.contains("integer overflow"))
        }
-       assert(cometErr.get.getMessage.contains("`NaiveDate - TimeDelta` overflowed"))
      }
    }
  }
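The updated assertions track where each overflow message appears to originate. A minimal
sketch, not part of the patch, assuming the chrono crate and a build with overflow checks
enabled (e.g. a debug build):

use chrono::{NaiveDate, TimeDelta};

fn main() {
    // The removed kernel went through chrono, whose date arithmetic panics with
    // "`NaiveDate + TimeDelta` overflowed" -- the substring the old test expected.
    let chrono_panic = std::panic::catch_unwind(|| NaiveDate::MAX + TimeDelta::days(1));
    assert!(chrono_panic.is_err());

    // Plain checked integer arithmetic panics with "attempt to add with overflow",
    // the substring the updated test now expects from the datafusion-spark path.
    let days = std::hint::black_box(i32::MAX);
    let rust_panic = std::panic::catch_unwind(|| days + 1);
    assert!(rust_panic.is_err());
}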