Skip to content

Commit 1f75eda

Browse files
authored
chore: Implement date_trunc as ScalarUDFImpl (apache#1880)
1 parent 6bf80b1 commit 1f75eda

File tree

6 files changed

+43
-94
lines changed

6 files changed

+43
-94
lines changed

native/core/src/execution/planner.rs

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ use datafusion::{
6666
};
6767
use datafusion_comet_spark_expr::{
6868
create_comet_physical_fun, create_negate_expr, SparkBitwiseCount, SparkBitwiseNot,
69+
SparkDateTrunc,
6970
};
7071

7172
use crate::execution::operators::ExecutionError::GeneralError;
@@ -105,10 +106,10 @@ use datafusion_comet_proto::{
105106
};
106107
use datafusion_comet_spark_expr::{
107108
ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Contains, Correlation, Covariance,
108-
CreateNamedStruct, DateTruncExpr, EndsWith, GetArrayStructFields, GetStructField, HourExpr,
109-
IfExpr, Like, ListExtract, MinuteExpr, NormalizeNaNAndZero, RLike, SecondExpr,
110-
SparkCastOptions, StartsWith, Stddev, StringSpaceExpr, SubstringExpr, SumDecimal,
111-
TimestampTruncExpr, ToJson, UnboundColumn, Variance,
109+
CreateNamedStruct, EndsWith, GetArrayStructFields, GetStructField, HourExpr, IfExpr, Like,
110+
ListExtract, MinuteExpr, NormalizeNaNAndZero, RLike, SecondExpr, SparkCastOptions, StartsWith,
111+
Stddev, StringSpaceExpr, SubstringExpr, SumDecimal, TimestampTruncExpr, ToJson, UnboundColumn,
112+
Variance,
112113
};
113114
use datafusion_spark::function::math::expm1::SparkExpm1;
114115
use itertools::Itertools;
@@ -158,6 +159,7 @@ impl PhysicalPlanner {
158159
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkExpm1::default()));
159160
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseNot::default()));
160161
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseCount::default()));
162+
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateTrunc::default()));
161163
Self {
162164
exec_context_id: TEST_EXEC_CONTEXT_ID,
163165
session_ctx,
@@ -475,13 +477,6 @@ impl PhysicalPlanner {
475477

476478
Ok(Arc::new(SecondExpr::new(child, timezone)))
477479
}
478-
ExprStruct::TruncDate(expr) => {
479-
let child =
480-
self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
481-
let format = self.create_expr(expr.format.as_ref().unwrap(), input_schema)?;
482-
483-
Ok(Arc::new(DateTruncExpr::new(child, format)))
484-
}
485480
ExprStruct::TruncTimestamp(expr) => {
486481
let child =
487482
self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;

native/proto/src/proto/expr.proto

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ message Expr {
7070
BinaryExpr bitwiseShiftLeft = 43;
7171
IfExpr if = 44;
7272
NormalizeNaNAndZero normalize_nan_and_zero = 45;
73-
TruncDate truncDate = 46;
7473
TruncTimestamp truncTimestamp = 47;
7574
Abs abs = 49;
7675
Subquery subquery = 50;
@@ -344,11 +343,6 @@ message IfExpr {
344343
Expr false_expr = 3;
345344
}
346345

347-
message TruncDate {
348-
Expr child = 1;
349-
Expr format = 2;
350-
}
351-
352346
message TruncTimestamp {
353347
Expr format = 1;
354348
Expr child = 2;

native/spark-expr/src/datetime_funcs/date_trunc.rs

Lines changed: 32 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -15,76 +15,58 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::datatypes::{DataType, Schema};
19-
use arrow::record_batch::RecordBatch;
20-
use datafusion::common::{DataFusionError, ScalarValue::Utf8};
21-
use datafusion::logical_expr::ColumnarValue;
22-
use datafusion::physical_expr::PhysicalExpr;
23-
use std::hash::Hash;
24-
use std::{
25-
any::Any,
26-
fmt::{Debug, Display, Formatter},
27-
sync::Arc,
18+
use arrow::datatypes::DataType;
19+
use datafusion::common::{utils::take_function_args, DataFusionError, Result, ScalarValue::Utf8};
20+
use datafusion::logical_expr::{
21+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
2822
};
23+
use std::any::Any;
2924

3025
use crate::kernels::temporal::{date_trunc_array_fmt_dyn, date_trunc_dyn};
3126

32-
#[derive(Debug, Eq)]
33-
pub struct DateTruncExpr {
34-
/// An array with DataType::Date32
35-
child: Arc<dyn PhysicalExpr>,
36-
/// Scalar UTF8 string matching the valid values in Spark SQL: https://spark.apache.org/docs/latest/api/sql/index.html#trunc
37-
format: Arc<dyn PhysicalExpr>,
27+
/// Scalar UDF implementation that truncates `Date32` values according to a
/// UTF-8 format string. It replaces the former `DateTruncExpr` physical
/// expression; the valid format values are those of Spark SQL's `trunc`.
#[derive(Debug)]
28+
pub struct SparkDateTrunc {
29+
/// Accepted argument types and volatility; built once in `new()`.
signature: Signature,
30+
/// Alternative names for the function; `new()` leaves this empty.
aliases: Vec<String>,
3831
}
3932

40-
impl Hash for DateTruncExpr {
41-
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
42-
self.child.hash(state);
43-
self.format.hash(state);
44-
}
45-
}
46-
impl PartialEq for DateTruncExpr {
47-
fn eq(&self, other: &Self) -> bool {
48-
self.child.eq(&other.child) && self.format.eq(&other.format)
49-
}
50-
}
51-
52-
impl DateTruncExpr {
53-
pub fn new(child: Arc<dyn PhysicalExpr>, format: Arc<dyn PhysicalExpr>) -> Self {
54-
DateTruncExpr { child, format }
33+
impl SparkDateTrunc {
34+
/// Builds the UDF with an exact `(Date32, Utf8)` signature declared as
/// `Volatility::Immutable`, and with no aliases.
pub fn new() -> Self {
35+
Self {
36+
signature: Signature::exact(
37+
// Argument order is (date, format), matching the destructuring in
// `invoke_with_args`.
vec![DataType::Date32, DataType::Utf8],
38+
Volatility::Immutable,
39+
),
40+
aliases: vec![],
41+
}
5542
}
5643
}
5744

58-
impl Display for DateTruncExpr {
59-
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
60-
write!(
61-
f,
62-
"DateTrunc [child:{}, format: {}]",
63-
self.child, self.format
64-
)
45+
/// `Default` simply delegates to `SparkDateTrunc::new()`; the planner
/// registers the UDF via `ScalarUDF::new_from_impl(SparkDateTrunc::default())`.
impl Default for SparkDateTrunc {
46+
fn default() -> Self {
47+
Self::new()
6548
}
6649
}
6750

68-
impl PhysicalExpr for DateTruncExpr {
51+
impl ScalarUDFImpl for SparkDateTrunc {
6952
/// Exposes this UDF as `&dyn Any` (presumably so callers can downcast to the
/// concrete `SparkDateTrunc` type).
fn as_any(&self) -> &dyn Any {
7053
self
7154
}
7255

73-
fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
74-
unimplemented!()
56+
/// Returns the function name, `date_trunc`. The Scala serde for `TruncDate`
/// emits a scalar function call under this same name, so the two strings
/// must stay in sync.
///
/// NOTE(review): this UDF backs Spark's `TruncDate` (SQL `trunc(date, fmt)`),
/// while Spark SQL's own `date_trunc` takes the format first and operates on
/// timestamps — confirm the name is intentional and cannot be confused with it.
fn name(&self) -> &str {
57+
"date_trunc"
7558
}
7659

77-
fn data_type(&self, input_schema: &Schema) -> datafusion::common::Result<DataType> {
78-
self.child.data_type(input_schema)
60+
/// Borrows the signature built in `new()`: exactly `(Date32, Utf8)`.
fn signature(&self) -> &Signature {
61+
&self.signature
7962
}
8063

81-
fn nullable(&self, _: &Schema) -> datafusion::common::Result<bool> {
82-
Ok(true)
64+
/// Always reports `Date32` as the result type; the argument types passed in
/// are ignored.
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
65+
Ok(DataType::Date32)
8366
}
8467

85-
fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
86-
let date = self.child.evaluate(batch)?;
87-
let format = self.format.evaluate(batch)?;
68+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
69+
let [date, format] = take_function_args(self.name(), args.args)?;
8870
match (date, format) {
8971
(ColumnarValue::Array(date), ColumnarValue::Scalar(Utf8(Some(format)))) => {
9072
let result = date_trunc_dyn(&date, format)?;
@@ -101,17 +83,7 @@ impl PhysicalExpr for DateTruncExpr {
10183
}
10284
}
10385

104-
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
105-
vec![&self.child]
106-
}
107-
108-
fn with_new_children(
109-
self: Arc<Self>,
110-
children: Vec<Arc<dyn PhysicalExpr>>,
111-
) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
112-
Ok(Arc::new(DateTruncExpr::new(
113-
Arc::clone(&children[0]),
114-
Arc::clone(&self.format),
115-
)))
86+
/// Borrows the alias list stored on the struct (empty when built via `new()`).
fn aliases(&self) -> &[String] {
87+
&self.aliases
11688
}
11789
}

native/spark-expr/src/datetime_funcs/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ mod second;
2323
mod timestamp_trunc;
2424

2525
pub use date_arithmetic::{spark_date_add, spark_date_sub};
26-
pub use date_trunc::DateTruncExpr;
26+
pub use date_trunc::SparkDateTrunc;
2727
pub use hour::HourExpr;
2828
pub use minute::MinuteExpr;
2929
pub use second::SecondExpr;

native/spark-expr/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ pub use conversion_funcs::*;
6060

6161
pub use comet_scalar_funcs::create_comet_physical_fun;
6262
pub use datetime_funcs::{
63-
spark_date_add, spark_date_sub, DateTruncExpr, HourExpr, MinuteExpr, SecondExpr,
63+
spark_date_add, spark_date_sub, HourExpr, MinuteExpr, SecondExpr, SparkDateTrunc,
6464
TimestampTruncExpr,
6565
};
6666
pub use error::{SparkError, SparkResult};

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,21 +1044,9 @@ object QueryPlanSerde extends Logging with CometExprShim {
10441044
case TruncDate(child, format) =>
10451045
val childExpr = exprToProtoInternal(child, inputs, binding)
10461046
val formatExpr = exprToProtoInternal(format, inputs, binding)
1047-
1048-
if (childExpr.isDefined && formatExpr.isDefined) {
1049-
val builder = ExprOuterClass.TruncDate.newBuilder()
1050-
builder.setChild(childExpr.get)
1051-
builder.setFormat(formatExpr.get)
1052-
1053-
Some(
1054-
ExprOuterClass.Expr
1055-
.newBuilder()
1056-
.setTruncDate(builder)
1057-
.build())
1058-
} else {
1059-
withInfo(expr, child, format)
1060-
None
1061-
}
1047+
val optExpr =
1048+
scalarFunctionExprToProtoWithReturnType("date_trunc", DateType, childExpr, formatExpr)
1049+
optExprWithInfo(optExpr, expr, child, format)
10621050

10631051
case TruncTimestamp(format, child, timeZoneId) =>
10641052
val childExpr = exprToProtoInternal(child, inputs, binding)

0 commit comments

Comments
 (0)