Skip to content

Commit 9475d94

Browse files
mgaido91Robert Kruszewski
authored andcommitted
[SPARK-24957][SQL] Average with decimal followed by aggregation returns wrong result
## What changes were proposed in this pull request? When we do an average, the result is computed dividing the sum of the values by their count. In the case the result is a DecimalType, the way we are casting/managing the precision and scale is not really optimized and it is not coherent with what we do normally. In particular, a problem can happen when the `Divide` operand returns a result which contains a precision and scale different by the ones which are expected as output of the `Divide` operand. In the case reported in the JIRA, for instance, the result of the `Divide` operand is a `Decimal(38, 36)`, while the output data type for `Divide` is 38, 22. This is not an issue when the `Divide` is followed by a `CheckOverflow` or a `Cast` to the right data type, as these operations return a decimal with the defined precision and scale. Despite in the `Average` operator we do have a `Cast`, this may be bypassed if the result of `Divide` is the same type which it is casted to, hence the issue reported in the JIRA may arise. The PR proposes to use the normal rules/handling of the arithmetic operators with Decimal data type, so we both reuse the existing code (having a single logic for operations between decimals) and we fix this problem as the result is always guarded by `CheckOverflow`. ## How was this patch tested? added UT Author: Marco Gaido <[email protected]> Closes apache#21910 from mgaido91/SPARK-24957.
1 parent d68b3af commit 9475d94

File tree

3 files changed

+18
-6
lines changed

3 files changed

+18
-6
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ object DecimalPrecision extends TypeCoercionRule {
8989
}
9090

9191
/** Decimal precision promotion for +, -, *, /, %, pmod, and binary comparison. */
92-
private val decimalAndDecimal: PartialFunction[Expression, Expression] = {
92+
private[catalyst] val decimalAndDecimal: PartialFunction[Expression, Expression] = {
9393
// Skip nodes whose children have not been resolved yet
9494
case e if !e.childrenResolved => e
9595

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions.aggregate
1919

20-
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
20+
import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, TypeCheckResult}
2121
import org.apache.spark.sql.catalyst.dsl.expressions._
2222
import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.util.TypeUtils
@@ -57,10 +57,9 @@ abstract class AverageLike(child: Expression) extends DeclarativeAggregate {
5757

5858
// If all input are nulls, count will be 0 and we will get null after the division.
5959
override lazy val evaluateExpression = child.dataType match {
60-
case DecimalType.Fixed(p, s) =>
61-
// increase the precision and scale to prevent precision loss
62-
val dt = DecimalType.bounded(p + 14, s + 4)
63-
Cast(Cast(sum, dt) / Cast(count, DecimalType.bounded(DecimalType.MAX_PRECISION, 0)),
60+
case _: DecimalType =>
61+
Cast(
62+
DecimalPrecision.decimalAndDecimal(sum / Cast(count, DecimalType.LongDecimal)),
6463
resultType)
6564
case _ =>
6665
Cast(sum, resultType) / Cast(count, resultType)

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,6 +1005,19 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
10051005
)
10061006
)
10071007
}
1008+
1009+
test("SPARK-24957: average with decimal followed by aggregation returning wrong result") {
1010+
val df = Seq(("a", BigDecimal("12.0")),
1011+
("a", BigDecimal("12.0")),
1012+
("a", BigDecimal("11.9999999988")),
1013+
("a", BigDecimal("12.0")),
1014+
("a", BigDecimal("12.0")),
1015+
("a", BigDecimal("11.9999999988")),
1016+
("a", BigDecimal("11.9999999988"))).toDF("text", "number")
1017+
val agg1 = df.groupBy($"text").agg(avg($"number").as("avg_res"))
1018+
val agg2 = agg1.groupBy($"text").agg(sum($"avg_res"))
1019+
checkAnswer(agg2, Row("a", BigDecimal("11.9999999994857142860000")))
1020+
}
10081021
}
10091022

10101023

0 commit comments

Comments
 (0)