
Commit 52e6b34

fix: Add strict floating point mode and fallback to Spark for min/max/sort on floating point inputs when enabled (#2747)
1 parent ddf788c commit 52e6b34

File tree: 8 files changed, +100 -13 lines


common/src/main/scala/org/apache/comet/CometConf.scala

Lines changed: 9 additions & 0 deletions
@@ -674,6 +674,15 @@ object CometConf extends ShimCometConf {
       .booleanConf
       .createWithDefault(false)
 
+  val COMET_EXEC_STRICT_FLOATING_POINT: ConfigEntry[Boolean] =
+    conf("spark.comet.exec.strictFloatingPoint")
+      .category(CATEGORY_EXEC)
+      .doc(
+        "When enabled, fall back to Spark for floating-point operations that may differ from " +
+          s"Spark, such as when comparing or sorting -0.0 and 0.0. $COMPAT_GUIDE.")
+      .booleanConf
+      .createWithDefault(false)
+
   val COMET_REGEXP_ALLOW_INCOMPATIBLE: ConfigEntry[Boolean] =
     conf("spark.comet.regexp.allowIncompatible")
       .category(CATEGORY_EXEC)
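As a usage sketch (assuming a SparkSession with the Comet extensions already loaded), the new flag is set like any other Comet conf; the key below is the one defined in this commit:

    // Enable strict floating-point mode at runtime
    spark.conf.set("spark.comet.exec.strictFloatingPoint", "true")

    // Or at submit time:
    //   spark-submit --conf spark.comet.exec.strictFloatingPoint=true ...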

docs/source/user-guide/latest/compatibility.md

Lines changed: 4 additions & 5 deletions
@@ -47,11 +47,10 @@ Spark normalizes NaN and zero for floating point numbers for several cases. See
 However, one exception is comparison. Spark does not normalize NaN and zero when comparing values
 because they are handled well in Spark (e.g., `SQLOrderingUtil.compareFloats`). But the comparison
 functions of arrow-rs used by DataFusion do not normalize NaN and zero (e.g., [arrow::compute::kernels::cmp::eq](https://docs.rs/arrow/latest/arrow/compute/kernels/cmp/fn.eq.html#)).
-So Comet will add additional normalization expression of NaN and zero for comparison.
-
-Sorting on floating-point data types (or complex types containing floating-point values) is not compatible with
-Spark if the data contains both zero and negative zero. This is likely an edge case that is not of concern for many users
-and sorting on floating-point data can be enabled by setting `spark.comet.expression.SortOrder.allowIncompatible=true`.
+So Comet adds additional normalization expressions for NaN and zero in comparisons, but may still differ
+from Spark in some cases, especially when the data contains both positive and negative zero. This is likely
+an edge case that is not of concern for many users. If it is a concern, setting `spark.comet.exec.strictFloatingPoint=true`
+will make the relevant operations fall back to Spark.
 
 ## Incompatible Expressions
 
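To see why the two zeros are an edge case at all, a minimal JVM-side illustration (plain Scala, not Comet code): IEEE 754 equality treats -0.0 and 0.0 as equal, while a total ordering of the kind used for sorting distinguishes them.

    // -0.0 and 0.0 compare equal under IEEE 754 primitive equality...
    val ieeeEqual = -0.0 == 0.0                        // true
    // ...but java.lang.Double.compare imposes a total order that places
    // -0.0 strictly before 0.0
    val ordered = java.lang.Double.compare(-0.0, 0.0)  // -1

Engines that sort or take min/max over raw floating-point values can therefore disagree with Spark whenever a column contains both zeros.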

docs/source/user-guide/latest/configs.md

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ Comet provides the following configuration settings.
 | `spark.comet.exceptionOnDatetimeRebase` | Whether to throw exception when seeing dates/timestamps from the legacy hybrid (Julian + Gregorian) calendar. Since Spark 3, dates/timestamps were written according to the Proleptic Gregorian calendar. When this is true, Comet will throw exceptions when seeing these dates/timestamps that were written by Spark version before 3.0. If this is false, these dates/timestamps will be read as if they were written to the Proleptic Gregorian calendar and will not be rebased. | false |
 | `spark.comet.exec.enabled` | Whether to enable Comet native vectorized execution for Spark. This controls whether Spark should convert operators into their Comet counterparts and execute them in native space. Note: each operator is associated with a separate config in the format of `spark.comet.exec.<operator_name>.enabled` at the moment, and both the config and this need to be turned on, in order for the operator to be executed in native. | true |
 | `spark.comet.exec.replaceSortMergeJoin` | Experimental feature to force Spark to replace SortMergeJoin with ShuffledHashJoin for improved performance. This feature is not stable yet. For more information, refer to the [Comet Tuning Guide](https://datafusion.apache.org/comet/user-guide/tuning.html). | false |
+| `spark.comet.exec.strictFloatingPoint` | When enabled, fall back to Spark for floating-point operations that may differ from Spark, such as when comparing or sorting -0.0 and 0.0. For more information, refer to the [Comet Compatibility Guide](https://datafusion.apache.org/comet/user-guide/compatibility.html). | false |
 | `spark.comet.expression.allowIncompatible` | Comet is not currently fully compatible with Spark for all expressions. Set this config to true to allow them anyway. For more information, refer to the [Comet Compatibility Guide](https://datafusion.apache.org/comet/user-guide/compatibility.html). | false |
 | `spark.comet.maxTempDirectorySize` | The maximum amount of data (in bytes) stored inside the temporary directories. | 107374182400b |
 | `spark.comet.metrics.updateInterval` | The interval in milliseconds to update metrics. If interval is negative, metrics will be updated upon task completion. | 3000 |

spark/src/main/scala/org/apache/comet/serde/CometSortOrder.scala

Lines changed: 8 additions & 3 deletions
@@ -41,9 +41,14 @@ object CometSortOrder extends CometExpressionSerde[SortOrder] {
       }
     }
 
-    if (containsFloatingPoint(expr.child.dataType)) {
-      Incompatible(Some(
-        s"Sorting on floating-point is not 100% compatible with Spark. ${CometConf.COMPAT_GUIDE}"))
+    if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() &&
+      containsFloatingPoint(expr.child.dataType)) {
+      // https://github.com/apache/datafusion-comet/issues/2626
+      Incompatible(
+        Some(
+          "Sorting on floating-point is not 100% compatible with Spark, and Comet is running " +
+            s"with ${CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key}=true. " +
+            s"${CometConf.COMPAT_GUIDE}"))
     } else {
       Compatible()
     }
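The guard calls containsFloatingPoint, which is defined elsewhere in this file and not shown in the hunk. A plausible sketch of such a recursive check over nested Spark types (an illustration under that assumption, not the commit's actual implementation):

    import org.apache.spark.sql.types._

    // Returns true if the type is FloatType/DoubleType or contains one
    // anywhere inside an array, map, or struct.
    def containsFloatingPoint(dt: DataType): Boolean = dt match {
      case FloatType | DoubleType => true
      case ArrayType(elementType, _) => containsFloatingPoint(elementType)
      case MapType(keyType, valueType, _) =>
        containsFloatingPoint(keyType) || containsFloatingPoint(valueType)
      case StructType(fields) => fields.exists(f => containsFloatingPoint(f.dataType))
      case _ => false
    }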

spark/src/main/scala/org/apache/comet/serde/aggregates.scala

Lines changed: 24 additions & 1 deletion
@@ -24,9 +24,10 @@ import scala.jdk.CollectionConverters._
 import org.apache.spark.sql.catalyst.expressions.{Attribute, EvalMode}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{ByteType, DecimalType, IntegerType, LongType, ShortType, StringType}
+import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType}
 
 import org.apache.comet.CometConf
+import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT
 import org.apache.comet.CometSparkSessionExtensions.withInfo
 import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}
 
@@ -42,6 +43,17 @@ object CometMin extends CometAggregateExpressionSerde[Min] {
       withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}")
       return None
     }
+
+    if (expr.dataType == DataTypes.FloatType || expr.dataType == DataTypes.DoubleType) {
+      if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get()) {
+        // https://github.com/apache/datafusion-comet/issues/2448
+        withInfo(
+          aggExpr,
+          s"floating-point not supported when ${COMET_EXEC_STRICT_FLOATING_POINT.key}=true")
+        return None
+      }
+    }
+
     val child = expr.children.head
     val childExpr = exprToProto(child, inputs, binding)
     val dataType = serializeDataType(expr.dataType)
@@ -78,6 +90,17 @@ object CometMax extends CometAggregateExpressionSerde[Max] {
       withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}")
       return None
     }
+
+    if (expr.dataType == DataTypes.FloatType || expr.dataType == DataTypes.DoubleType) {
+      if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get()) {
+        // https://github.com/apache/datafusion-comet/issues/2448
+        withInfo(
+          aggExpr,
+          s"floating-point not supported when ${COMET_EXEC_STRICT_FLOATING_POINT.key}=true")
+        return None
+      }
+    }
+
     val child = expr.children.head
     val childExpr = exprToProto(child, inputs, binding)
     val dataType = serializeDataType(expr.dataType)

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 9 additions & 0 deletions
@@ -1044,6 +1044,15 @@ trait CometBaseAggregate {
 
     val aggExprs =
       aggregateExpressions.map(aggExprToProto(_, output, binding, aggregate.conf))
+
+    if (aggExprs.exists(_.isEmpty)) {
+      withInfo(
+        aggregate,
+        "Unsupported aggregate expression(s)",
+        aggregateExpressions ++ aggregateExpressions.map(_.aggregateFunction): _*)
+      return None
+    }
+
     if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) &&
       aggExprs.forall(_.isDefined)) {
       val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
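This early return appears to exist to surface a diagnosable fallback reason: when any aggregate expression fails to convert, the plan is tagged with "Unsupported aggregate expression(s)" and the offending expressions, rather than silently failing the later aggExprs.forall(_.isDefined) check, which is what the new aggregate test asserts on.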

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 9 additions & 3 deletions
@@ -72,7 +72,9 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       DataGenOptions(generateNegativeZero = true))
     df.createOrReplaceTempView("tbl")
 
-    withSQLConf(CometConf.getExprAllowIncompatConfigKey("SortOrder") -> "false") {
+    withSQLConf(
+      CometConf.getExprAllowIncompatConfigKey("SortOrder") -> "false",
+      CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key -> "true") {
       checkSparkAnswerAndFallbackReasons(
         "select * from tbl order by 1, 2",
         Set(
@@ -94,7 +96,9 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       DataGenOptions(generateNegativeZero = true))
     df.createOrReplaceTempView("tbl")
 
-    withSQLConf(CometConf.getExprAllowIncompatConfigKey("SortOrder") -> "false") {
+    withSQLConf(
+      CometConf.getExprAllowIncompatConfigKey("SortOrder") -> "false",
+      CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key -> "true") {
       checkSparkAnswerAndFallbackReason(
         "select * from tbl order by 1, 2",
         "unsupported range partitioning sort order")
@@ -118,7 +122,9 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       DataGenOptions(generateNegativeZero = true))
     df.createOrReplaceTempView("tbl")
 
-    withSQLConf(CometConf.getExprAllowIncompatConfigKey("SortOrder") -> "false") {
+    withSQLConf(
+      CometConf.getExprAllowIncompatConfigKey("SortOrder") -> "false",
+      CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key -> "true") {
       checkSparkAnswerAndFallbackReason(
         "select * from tbl order by 1, 2",
         "unsupported range partitioning sort order")

spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala

Lines changed: 36 additions & 1 deletion
@@ -28,16 +28,51 @@ import org.apache.spark.sql.comet.CometHashAggregateExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.functions.{avg, count_distinct, sum}
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
 
 import org.apache.comet.CometConf
-import org.apache.comet.testing.{DataGenOptions, ParquetGenerator, SchemaGenOptions}
+import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT
+import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator, ParquetGenerator, SchemaGenOptions}
 
 /**
  * Test suite dedicated to Comet native aggregate operator
  */
 class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   import testImplicits._
 
+  test("min/max floating point with negative zero") {
+    val r = new Random(42)
+    val schema = StructType(
+      Seq(
+        StructField("float_col", DataTypes.FloatType, nullable = true),
+        StructField("double_col", DataTypes.DoubleType, nullable = true)))
+    val df = FuzzDataGenerator.generateDataFrame(
+      r,
+      spark,
+      schema,
+      1000,
+      DataGenOptions(generateNegativeZero = true))
+    df.createOrReplaceTempView("tbl")
+
+    for (col <- Seq("float_col", "double_col")) {
+      // assert that data contains positive and negative zero
+      assert(spark.sql(s"select * from tbl where cast($col as string) = '0.0'").count() > 0)
+      assert(spark.sql(s"select * from tbl where cast($col as string) = '-0.0'").count() > 0)
+      for (agg <- Seq("min", "max")) {
+        withSQLConf(COMET_EXEC_STRICT_FLOATING_POINT.key -> "true") {
+          checkSparkAnswerAndFallbackReasons(
+            s"select $agg($col) from tbl where cast($col as string) in ('0.0', '-0.0')",
+            Set(
+              "Unsupported aggregate expression(s)",
+              s"floating-point not supported when ${COMET_EXEC_STRICT_FLOATING_POINT.key}=true"))
+        }
+        checkSparkAnswer(
+          s"select $col, count(*) from tbl " +
+            s"where cast($col as string) in ('0.0', '-0.0') group by $col")
+      }
+    }
+  }
+
   test("avg decimal") {
     withTempDir { dir =>
       val path = new Path(dir.toURI.toString, "test.parquet")
