Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions native/core/src/execution/jni_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ use datafusion::{
};
use datafusion_comet_proto::spark_operator::Operator;
use datafusion_spark::function::bitwise::bit_get::SparkBitGet;
use datafusion_spark::function::datetime::date_add::SparkDateAdd;
use datafusion_spark::function::datetime::date_sub::SparkDateSub;
use datafusion_spark::function::hash::sha2::SparkSha2;
use datafusion_spark::function::math::expm1::SparkExpm1;
use datafusion_spark::function::string::char::CharFunc;
Expand Down Expand Up @@ -303,6 +305,8 @@ fn prepare_datafusion_session_context(
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha2::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(CharFunc::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitGet::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateAdd::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateSub::default()));

// Must be the last one to override existing functions with the same name
datafusion_comet_spark_expr::register_all_comet_functions(&mut session_ctx)?;
Expand Down
16 changes: 4 additions & 12 deletions native/spark-expr/src/comet_scalar_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ use crate::hash_funcs::*;
use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub};
use crate::math_funcs::modulo_expr::spark_modulo;
use crate::{
spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, EvalMode,
SparkBitwiseCount, SparkBitwiseNot, SparkDateTrunc, SparkStringSpace,
spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
spark_hex, spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad,
spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkBitwiseNot,
SparkDateTrunc, SparkStringSpace,
};
use arrow::datatypes::DataType;
use datafusion::common::{DataFusionError, Result as DataFusionResult};
Expand Down Expand Up @@ -166,14 +166,6 @@ pub fn create_comet_physical_fun_with_eval_mode(
let func = Arc::new(spark_isnan);
make_comet_scalar_udf!("isnan", func, without data_type)
}
"date_add" => {
let func = Arc::new(spark_date_add);
make_comet_scalar_udf!("date_add", func, without data_type)
}
"date_sub" => {
let func = Arc::new(spark_date_sub);
make_comet_scalar_udf!("date_sub", func, without data_type)
}
"array_repeat" => {
let func = Arc::new(spark_array_repeat);
make_comet_scalar_udf!("array_repeat", func, without data_type)
Expand Down
101 changes: 0 additions & 101 deletions native/spark-expr/src/datetime_funcs/date_arithmetic.rs

This file was deleted.

2 changes: 0 additions & 2 deletions native/spark-expr/src/datetime_funcs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@
// specific language governing permissions and limitations
// under the License.

mod date_arithmetic;
mod date_trunc;
mod extract_date_part;
mod timestamp_trunc;

pub use date_arithmetic::{spark_date_add, spark_date_sub};
pub use date_trunc::SparkDateTrunc;
pub use extract_date_part::SparkHour;
pub use extract_date_part::SparkMinute;
Expand Down
5 changes: 1 addition & 4 deletions native/spark-expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,7 @@ pub use comet_scalar_funcs::{
create_comet_physical_fun, create_comet_physical_fun_with_eval_mode,
register_all_comet_functions,
};
pub use datetime_funcs::{
spark_date_add, spark_date_sub, SparkDateTrunc, SparkHour, SparkMinute, SparkSecond,
TimestampTruncExpr,
};
pub use datetime_funcs::{SparkDateTrunc, SparkHour, SparkMinute, SparkSecond, TimestampTruncExpr};
pub use error::{SparkError, SparkResult};
pub use hash_funcs::*;
pub use json_funcs::ToJson;
Expand Down
28 changes: 3 additions & 25 deletions spark/src/main/scala/org/apache/comet/serde/datetime.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.sql.types.{DateType, IntegerType}
import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.CometGetDateField.CometGetDateField
import org.apache.comet.serde.ExprOuterClass.Expr
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, serializeDataType}
import org.apache.comet.serde.QueryPlanSerde._

private object CometGetDateField extends Enumeration {
type CometGetDateField = Value
Expand Down Expand Up @@ -251,31 +251,9 @@ object CometSecond extends CometExpressionSerde[Second] {
}
}

object CometDateAdd extends CometExpressionSerde[DateAdd] {
override def convert(
expr: DateAdd,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
val leftExpr = exprToProtoInternal(expr.left, inputs, binding)
val rightExpr = exprToProtoInternal(expr.right, inputs, binding)
val optExpr =
scalarFunctionExprToProtoWithReturnType("date_add", DateType, leftExpr, rightExpr)
optExprWithInfo(optExpr, expr, expr.left, expr.right)
}
}
object CometDateAdd extends CometScalarFunction[DateAdd]("date_add")

object CometDateSub extends CometExpressionSerde[DateSub] {
override def convert(
expr: DateSub,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
val leftExpr = exprToProtoInternal(expr.left, inputs, binding)
val rightExpr = exprToProtoInternal(expr.right, inputs, binding)
val optExpr =
scalarFunctionExprToProtoWithReturnType("date_sub", DateType, leftExpr, rightExpr)
optExprWithInfo(optExpr, expr, expr.left, expr.right)
}
}
object CometDateSub extends CometScalarFunction[DateSub]("date_sub")

object CometTruncDate extends CometExpressionSerde[TruncDate] {
override def convert(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
} else {
assert(sparkErr.get.getMessage.contains("integer overflow"))
}
assert(cometErr.get.getMessage.contains("`NaiveDate + TimeDelta` overflowed"))
assert(cometErr.get.getMessage.contains("attempt to add with overflow"))
}
}
}
Expand Down Expand Up @@ -296,10 +296,11 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
checkSparkMaybeThrows(sql(s"SELECT _20 - ${Int.MaxValue} FROM tbl"))
if (isSpark40Plus) {
assert(sparkErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
assert(cometErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
} else {
assert(sparkErr.get.getMessage.contains("integer overflow"))
assert(cometErr.get.getMessage.contains("integer overflow"))
}
assert(cometErr.get.getMessage.contains("`NaiveDate - TimeDelta` overflowed"))
}
}
}
Expand Down
Loading