
Commit 83d951d

kazantsev-maksim and Kazantsev Maksim authored
Chore: Used DataFusion impl of date_add and date_sub functions (#2473)
* Date_add and date_sub to DataFusion impl
* Fix tests

Co-authored-by: Kazantsev Maksim <[email protected]>
1 parent f83e8db commit 83d951d

File tree

7 files changed: 15 additions & 146 deletions


native/core/src/execution/jni_api.rs

Lines changed: 4 additions & 0 deletions
@@ -41,6 +41,8 @@ use datafusion::{
 };
 use datafusion_comet_proto::spark_operator::Operator;
 use datafusion_spark::function::bitwise::bit_get::SparkBitGet;
+use datafusion_spark::function::datetime::date_add::SparkDateAdd;
+use datafusion_spark::function::datetime::date_sub::SparkDateSub;
 use datafusion_spark::function::hash::sha2::SparkSha2;
 use datafusion_spark::function::math::expm1::SparkExpm1;
 use datafusion_spark::function::string::char::CharFunc;
@@ -303,6 +305,8 @@ fn prepare_datafusion_session_context(
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha2::default()));
     session_ctx.register_udf(ScalarUDF::new_from_impl(CharFunc::default()));
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitGet::default()));
+    session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateAdd::default()));
+    session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateSub::default()));
 
     // Must be the last one to override existing functions with the same name
     datafusion_comet_spark_expr::register_all_comet_functions(&mut session_ctx)?;
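
For context, a minimal standalone sketch of how the two datafusion-spark UDFs registered above could be exercised directly against a DataFusion SessionContext. The crate paths and registration calls mirror the diff; the tokio runtime, the date literals, the column aliases, and the assumption that the functions are exposed under the SQL names date_add / date_sub are illustrative and not part of this change.

    use datafusion::error::Result;
    use datafusion::logical_expr::ScalarUDF;
    use datafusion::prelude::SessionContext;
    use datafusion_spark::function::datetime::date_add::SparkDateAdd;
    use datafusion_spark::function::datetime::date_sub::SparkDateSub;

    #[tokio::main]
    async fn main() -> Result<()> {
        let ctx = SessionContext::new();

        // Same registration calls as in prepare_datafusion_session_context above.
        ctx.register_udf(ScalarUDF::new_from_impl(SparkDateAdd::default()));
        ctx.register_udf(ScalarUDF::new_from_impl(SparkDateSub::default()));

        // Assumes the UDFs are registered under Spark's SQL names date_add / date_sub.
        let df = ctx
            .sql("SELECT date_add(DATE '2024-01-01', 7) AS plus_week, \
                         date_sub(DATE '2024-01-01', 7) AS minus_week")
            .await?;
        df.show().await?;
        Ok(())
    }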

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 4 additions & 12 deletions
@@ -19,10 +19,10 @@ use crate::hash_funcs::*;
 use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub};
 use crate::math_funcs::modulo_expr::spark_modulo;
 use crate::{
-    spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
-    spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
-    spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, EvalMode,
-    SparkBitwiseCount, SparkBitwiseNot, SparkDateTrunc, SparkStringSpace,
+    spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
+    spark_hex, spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad,
+    spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkBitwiseNot,
+    SparkDateTrunc, SparkStringSpace,
 };
 use arrow::datatypes::DataType;
 use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -166,14 +166,6 @@ pub fn create_comet_physical_fun_with_eval_mode(
             let func = Arc::new(spark_isnan);
             make_comet_scalar_udf!("isnan", func, without data_type)
         }
-        "date_add" => {
-            let func = Arc::new(spark_date_add);
-            make_comet_scalar_udf!("date_add", func, without data_type)
-        }
-        "date_sub" => {
-            let func = Arc::new(spark_date_sub);
-            make_comet_scalar_udf!("date_sub", func, without data_type)
-        }
         "array_repeat" => {
             let func = Arc::new(spark_array_repeat);
             make_comet_scalar_udf!("array_repeat", func, without data_type)

native/spark-expr/src/datetime_funcs/date_arithmetic.rs

Lines changed: 0 additions & 101 deletions
This file was deleted.

native/spark-expr/src/datetime_funcs/mod.rs

Lines changed: 0 additions & 2 deletions
@@ -15,12 +15,10 @@
 // specific language governing permissions and limitations
 // under the License.
 
-mod date_arithmetic;
 mod date_trunc;
 mod extract_date_part;
 mod timestamp_trunc;
 
-pub use date_arithmetic::{spark_date_add, spark_date_sub};
 pub use date_trunc::SparkDateTrunc;
 pub use extract_date_part::SparkHour;
 pub use extract_date_part::SparkMinute;

native/spark-expr/src/lib.rs

Lines changed: 1 addition & 4 deletions
@@ -68,10 +68,7 @@ pub use comet_scalar_funcs::{
     create_comet_physical_fun, create_comet_physical_fun_with_eval_mode,
     register_all_comet_functions,
 };
-pub use datetime_funcs::{
-    spark_date_add, spark_date_sub, SparkDateTrunc, SparkHour, SparkMinute, SparkSecond,
-    TimestampTruncExpr,
-};
+pub use datetime_funcs::{SparkDateTrunc, SparkHour, SparkMinute, SparkSecond, TimestampTruncExpr};
 pub use error::{SparkError, SparkResult};
 pub use hash_funcs::*;
 pub use json_funcs::ToJson;

spark/src/main/scala/org/apache/comet/serde/datetime.scala

Lines changed: 3 additions & 25 deletions
@@ -25,7 +25,7 @@ import org.apache.spark.sql.types.{DateType, IntegerType}
 import org.apache.comet.CometSparkSessionExtensions.withInfo
 import org.apache.comet.serde.CometGetDateField.CometGetDateField
 import org.apache.comet.serde.ExprOuterClass.Expr
-import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, serializeDataType}
+import org.apache.comet.serde.QueryPlanSerde._
 
 private object CometGetDateField extends Enumeration {
   type CometGetDateField = Value
@@ -251,31 +251,9 @@ object CometSecond extends CometExpressionSerde[Second] {
   }
 }
 
-object CometDateAdd extends CometExpressionSerde[DateAdd] {
-  override def convert(
-      expr: DateAdd,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[ExprOuterClass.Expr] = {
-    val leftExpr = exprToProtoInternal(expr.left, inputs, binding)
-    val rightExpr = exprToProtoInternal(expr.right, inputs, binding)
-    val optExpr =
-      scalarFunctionExprToProtoWithReturnType("date_add", DateType, leftExpr, rightExpr)
-    optExprWithInfo(optExpr, expr, expr.left, expr.right)
-  }
-}
+object CometDateAdd extends CometScalarFunction[DateAdd]("date_add")
 
-object CometDateSub extends CometExpressionSerde[DateSub] {
-  override def convert(
-      expr: DateSub,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[ExprOuterClass.Expr] = {
-    val leftExpr = exprToProtoInternal(expr.left, inputs, binding)
-    val rightExpr = exprToProtoInternal(expr.right, inputs, binding)
-    val optExpr =
-      scalarFunctionExprToProtoWithReturnType("date_sub", DateType, leftExpr, rightExpr)
-    optExprWithInfo(optExpr, expr, expr.left, expr.right)
-  }
-}
+object CometDateSub extends CometScalarFunction[DateSub]("date_sub")
 
 object CometTruncDate extends CometExpressionSerde[TruncDate] {
   override def convert(

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 3 additions & 2 deletions
@@ -252,7 +252,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         } else {
           assert(sparkErr.get.getMessage.contains("integer overflow"))
         }
-        assert(cometErr.get.getMessage.contains("`NaiveDate + TimeDelta` overflowed"))
+        assert(cometErr.get.getMessage.contains("attempt to add with overflow"))
       }
     }
   }
@@ -296,10 +296,11 @@
           checkSparkMaybeThrows(sql(s"SELECT _20 - ${Int.MaxValue} FROM tbl"))
         if (isSpark40Plus) {
           assert(sparkErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
+          assert(cometErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
        } else {
          assert(sparkErr.get.getMessage.contains("integer overflow"))
+         assert(cometErr.get.getMessage.contains("integer overflow"))
        }
-       assert(cometErr.get.getMessage.contains("`NaiveDate - TimeDelta` overflowed"))
      }
    }
  }
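
A note on the updated expected error text in these assertions: the old "`NaiveDate + TimeDelta` overflowed" wording comes from chrono's date types, while "attempt to add with overflow" matches Rust's standard arithmetic-overflow panic message, which the DataFusion-based implementation now appears to surface. The sketch below is plain Rust with no Comet or DataFusion dependency; it only illustrates where that wording comes from, and the use of catch_unwind and black_box is purely for demonstration.

    use std::hint::black_box;
    use std::panic;

    fn main() {
        // checked_add reports overflow as None instead of panicking.
        assert_eq!(black_box(i32::MAX).checked_add(black_box(1)), None);

        // With overflow checks enabled (the default for debug builds), a plain `+`
        // panics with the message "attempt to add with overflow", the substring the
        // updated test asserts on. The default panic hook prints that message to
        // stderr, and catch_unwind keeps the program running.
        let result = panic::catch_unwind(|| black_box(i32::MAX) + black_box(1));
        assert!(result.is_err());
    }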
