Skip to content

Commit 977a189

Browse files
authored
fix: Fall back to Spark for trunc / date_trunc functions when format string is unsupported, or is not a literal value (#2634)
1 parent f826b65 commit 977a189

File tree

6 files changed

+202
-13
lines changed

6 files changed

+202
-13
lines changed

.github/workflows/pr_build_linux.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ jobs:
141141
org.apache.spark.CometPluginsDefaultSuite
142142
org.apache.spark.CometPluginsNonOverrideSuite
143143
org.apache.spark.CometPluginsUnifiedModeOverrideSuite
144+
org.apache.comet.CometTemporalExpressionSuite
144145
org.apache.spark.sql.CometTPCDSQuerySuite
145146
org.apache.spark.sql.CometTPCDSQueryTestSuite
146147
org.apache.spark.sql.CometTPCHQuerySuite

.github/workflows/pr_build_macos.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ jobs:
106106
org.apache.spark.CometPluginsDefaultSuite
107107
org.apache.spark.CometPluginsNonOverrideSuite
108108
org.apache.spark.CometPluginsUnifiedModeOverrideSuite
109+
org.apache.comet.CometTemporalExpressionSuite
109110
org.apache.spark.sql.CometTPCDSQuerySuite
110111
org.apache.spark.sql.CometTPCDSQueryTestSuite
111112
org.apache.spark.sql.CometTPCHQuerySuite

common/src/main/scala/org/apache/comet/CometConf.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,10 @@ object CometConf extends ShimCometConf {
740740
s"${CometConf.COMET_EXPR_CONFIG_PREFIX}.$name.allowIncompatible"
741741
}
742742

743+
/**
 * Returns the per-expression config key that lets users opt in to a known-incompatible
 * expression, derived from the expression's simple class name
 * (e.g. `<prefix>.TruncDate.allowIncompatible`).
 */
def getExprAllowIncompatConfigKey(exprClass: Class[_]): String = {
  val exprName = exprClass.getSimpleName
  s"${CometConf.COMET_EXPR_CONFIG_PREFIX}.$exprName.allowIncompatible"
}
746+
743747
def getBooleanConf(name: String, defaultValue: Boolean, conf: SQLConf): Boolean = {
744748
conf.getConfString(name, defaultValue.toString).toLowerCase(Locale.ROOT) == "true"
745749
}

spark/src/main/scala/org/apache/comet/serde/datetime.scala

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,11 @@
1919

2020
package org.apache.comet.serde
2121

22+
import java.util.Locale
23+
2224
import org.apache.spark.sql.catalyst.expressions.{Attribute, DateAdd, DateSub, DayOfMonth, DayOfWeek, DayOfYear, GetDateField, Hour, Literal, Minute, Month, Quarter, Second, TruncDate, TruncTimestamp, WeekDay, WeekOfYear, Year}
2325
import org.apache.spark.sql.types.{DateType, IntegerType}
26+
import org.apache.spark.unsafe.types.UTF8String
2427

2528
import org.apache.comet.CometSparkSessionExtensions.withInfo
2629
import org.apache.comet.serde.CometGetDateField.CometGetDateField
@@ -256,6 +259,24 @@ object CometDateAdd extends CometScalarFunction[DateAdd]("date_add")
256259
object CometDateSub extends CometScalarFunction[DateSub]("date_sub")
257260

258261
object CometTruncDate extends CometExpressionSerde[TruncDate] {
262+
263+
// Date-truncation units handled natively; any other literal format falls back to Spark
// (membership is checked case-insensitively in getSupportLevel).
val supportedFormats: Seq[String] =
  Seq("year", "yyyy", "yy", "quarter", "mon", "month", "mm", "week")
265+
266+
/**
 * `trunc` is only natively supported when the format argument is a string literal naming
 * one of [[supportedFormats]]. A non-literal format is marked incompatible because an
 * invalid format value would raise an error natively instead of returning NULL as Spark does.
 */
override def getSupportLevel(expr: TruncDate): SupportLevel = {
  expr.format match {
    case Literal(fmt: UTF8String, _)
        if supportedFormats.contains(fmt.toString.toLowerCase(Locale.ROOT)) =>
      Compatible()
    case Literal(fmt: UTF8String, _) =>
      Unsupported(Some(s"Format $fmt is not supported"))
    case _ =>
      Incompatible(
        Some("Invalid format strings will throw an exception instead of returning NULL"))
  }
}
279+
259280
override def convert(
260281
expr: TruncDate,
261282
inputs: Seq[Attribute],
@@ -274,6 +295,39 @@ object CometTruncDate extends CometExpressionSerde[TruncDate] {
274295
}
275296

276297
object CometTruncTimestamp extends CometExpressionSerde[TruncTimestamp] {
298+
299+
// Timestamp-truncation units handled natively (superset of the date units, adding
// sub-day granularities); any other literal format falls back to Spark
// (membership is checked case-insensitively in getSupportLevel).
val supportedFormats: Seq[String] =
  Seq(
    "year",
    "yyyy",
    "yy",
    "quarter",
    "mon",
    "month",
    "mm",
    "week",
    "day",
    "dd",
    "hour",
    "minute",
    "second",
    "millisecond",
    "microsecond")
316+
317+
/**
 * `date_trunc` is only natively supported when the format argument is a string literal
 * naming one of [[supportedFormats]]. A non-literal format is marked incompatible because
 * an invalid format value would raise an error natively instead of returning NULL as
 * Spark does.
 */
override def getSupportLevel(expr: TruncTimestamp): SupportLevel = {
  expr.format match {
    case Literal(fmt: UTF8String, _) =>
      val normalized = fmt.toString.toLowerCase(Locale.ROOT)
      if (!supportedFormats.contains(normalized)) {
        Unsupported(Some(s"Format $fmt is not supported"))
      } else {
        Compatible()
      }
    case _ =>
      Incompatible(
        Some("Invalid format strings will throw an exception instead of returning NULL"))
  }
}
330+
277331
override def convert(
278332
expr: TruncTimestamp,
279333
inputs: Seq[Attribute],

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import org.scalatest.Tag
3030

3131
import org.apache.hadoop.fs.Path
3232
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
33-
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
33+
import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Literal, TruncDate, TruncTimestamp}
3434
import org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps
3535
import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometProjectExec, CometWindowExec}
3636
import org.apache.spark.sql.execution.{InputAdapter, ProjectExec, SparkPlan, WholeStageCodegenExec}
@@ -706,11 +706,13 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
706706
val path = new Path(dir.toURI.toString, "date_trunc_with_format.parquet")
707707
makeDateTimeWithFormatTable(path, dictionaryEnabled = dictionaryEnabled, numRows)
708708
withParquetTable(path.toString, "dateformattbl") {
709-
checkSparkAnswerAndOperator(
710-
"SELECT " +
711-
"dateformat, _7, " +
712-
"trunc(_7, dateformat) " +
713-
" from dateformattbl ")
709+
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[TruncDate]) -> "true") {
710+
checkSparkAnswerAndOperator(
711+
"SELECT " +
712+
"dateformat, _7, " +
713+
"trunc(_7, dateformat) " +
714+
" from dateformattbl ")
715+
}
714716
}
715717
}
716718
}
@@ -787,13 +789,15 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
787789
}
788790

789791
test("date_trunc with format array") {
790-
withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
791-
val numRows = 1000
792-
Seq(true, false).foreach { dictionaryEnabled =>
793-
withTempDir { dir =>
794-
val path = new Path(dir.toURI.toString, "timestamp_trunc_with_format.parquet")
795-
makeDateTimeWithFormatTable(path, dictionaryEnabled = dictionaryEnabled, numRows)
796-
withParquetTable(path.toString, "timeformattbl") {
792+
val numRows = 1000
793+
Seq(true, false).foreach { dictionaryEnabled =>
794+
withTempDir { dir =>
795+
val path = new Path(dir.toURI.toString, "timestamp_trunc_with_format.parquet")
796+
makeDateTimeWithFormatTable(path, dictionaryEnabled = dictionaryEnabled, numRows)
797+
withParquetTable(path.toString, "timeformattbl") {
798+
withSQLConf(
799+
CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
800+
CometConf.getExprAllowIncompatConfigKey(classOf[TruncTimestamp]) -> "true") {
797801
checkSparkAnswerAndOperator(
798802
"SELECT " +
799803
"format, _0, _1, _2, _3, _4, _5, " +
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet
21+
22+
import scala.util.Random
23+
24+
import org.apache.spark.sql.{CometTestBase, SaveMode}
25+
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
26+
import org.apache.spark.sql.internal.SQLConf
27+
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
28+
29+
import org.apache.comet.serde.{CometTruncDate, CometTruncTimestamp}
30+
import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator}
31+
32+
/**
 * Tests for temporal expressions (trunc / date_trunc), covering native execution for
 * supported literal formats and fallback to Spark for unsupported or non-literal formats.
 */
class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {

  test("trunc (TruncDate)") {
    val rng = new Random(42)
    val schema = StructType(
      Seq(
        StructField("c0", DataTypes.DateType, true),
        StructField("c1", DataTypes.StringType, true)))
    FuzzDataGenerator
      .generateDataFrame(rng, spark, schema, 1000, DataGenOptions())
      .createOrReplaceTempView("tbl")

    CometTruncDate.supportedFormats.foreach { format =>
      checkSparkAnswerAndOperator(s"SELECT c0, trunc(c0, '$format') from tbl order by c0, c1")
    }
    Seq("invalid").foreach { format =>
      // unsupported or invalid formats must fall back to Spark
      checkSparkAnswerAndFallbackReason(
        s"SELECT c0, trunc(c0, '$format') from tbl order by c0, c1",
        s"Format $format is not supported")
    }

    // a non-literal format must fall back to Spark
    checkSparkAnswerAndFallbackReason(
      "SELECT c0, trunc(c0, c1) from tbl order by c0, c1",
      "Invalid format strings will throw an exception instead of returning NULL")
  }

  test("date_trunc (TruncTimestamp) - reading from DataFrame") {
    createTimestampTestData.createOrReplaceTempView("tbl")
    checkTimestampTruncQueries()
  }

  test("date_trunc (TruncTimestamp) - reading from Parquet") {
    withTempDir { path =>
      createTimestampTestData.write.mode(SaveMode.Overwrite).parquet(path.toString)
      spark.read.parquet(path.toString).createOrReplaceTempView("tbl")
      checkTimestampTruncQueries()
    }
  }

  // Shared date_trunc assertions, run against whichever source is registered as `tbl`.
  private def checkTimestampTruncQueries(): Unit = {
    // TODO test fails with non-UTC timezone
    // https://github.com/apache/datafusion-comet/issues/2649
    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
      CometTruncTimestamp.supportedFormats.foreach { format =>
        checkSparkAnswerAndOperator(s"SELECT c0, date_trunc('$format', c0) from tbl order by c0")
      }
      Seq("invalid").foreach { format =>
        // unsupported or invalid formats must fall back to Spark
        checkSparkAnswerAndFallbackReason(
          s"SELECT c0, date_trunc('$format', c0) from tbl order by c0",
          s"Format $format is not supported")
      }
      // a non-literal format must fall back to Spark
      checkSparkAnswerAndFallbackReason(
        "SELECT c0, date_trunc(fmt, c0) from tbl order by c0, fmt",
        "Invalid format strings will throw an exception instead of returning NULL")
    }
  }

  // Builds a deterministic fuzzed (timestamp, format-string) DataFrame.
  private def createTimestampTestData = {
    val rng = new Random(42)
    val schema = StructType(
      Seq(
        StructField("c0", DataTypes.TimestampType, true),
        StructField("fmt", DataTypes.StringType, true)))
    FuzzDataGenerator.generateDataFrame(rng, spark, schema, 1000, DataGenOptions())
  }
}

0 commit comments

Comments
 (0)