Skip to content

Commit 0f3fa2f

Browse files
mgaido91gatorsmile
authored andcommitted
[SPARK-24996][SQL] Use DSL in DeclarativeAggregate
## What changes were proposed in this pull request? The PR refactors the aggregate expressions which were not using DSL in order to simplify them. ## How was this patch tested? NA Author: Marco Gaido <[email protected]> Closes apache#21970 from mgaido91/SPARK-24996.
1 parent 408a3ff commit 0f3fa2f

File tree

11 files changed

+65
-69
lines changed

11 files changed

+65
-69
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ package object dsl {
167167
def upper(e: Expression): Expression = Upper(e)
168168
def lower(e: Expression): Expression = Lower(e)
169169
def coalesce(args: Expression*): Expression = Coalesce(args)
170+
def greatest(args: Expression*): Expression = Greatest(args)
171+
def least(args: Expression*): Expression = Least(args)
170172
def sqrt(e: Expression): Expression = Sqrt(e)
171173
def abs(e: Expression): Expression = Abs(e)
172174
def star(names: String*): Expression = names match {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ abstract class AverageLike(child: Expression) extends DeclarativeAggregate {
6868
Add(
6969
sum,
7070
coalesce(child.cast(sumDataType), Literal(0).cast(sumDataType))),
71-
/* count = */ If(IsNull(child), count, count + 1L)
71+
/* count = */ If(child.isNull, count, count + 1L)
7272
)
7373

7474
override lazy val updateExpressions = updateExpressionsDef

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ abstract class CentralMomentAgg(child: Expression)
7575
val n2 = n.right
7676
val newN = n1 + n2
7777
val delta = avg.right - avg.left
78-
val deltaN = If(newN === Literal(0.0), Literal(0.0), delta / newN)
78+
val deltaN = If(newN === 0.0, 0.0, delta / newN)
7979
val newAvg = avg.left + deltaN * n2
8080

8181
// higher order moments computed according to:
@@ -102,7 +102,7 @@ abstract class CentralMomentAgg(child: Expression)
102102
}
103103

104104
protected def updateExpressionsDef: Seq[Expression] = {
105-
val newN = n + Literal(1.0)
105+
val newN = n + 1.0
106106
val delta = child - avg
107107
val deltaN = delta / newN
108108
val newAvg = avg + deltaN
@@ -123,11 +123,11 @@ abstract class CentralMomentAgg(child: Expression)
123123
}
124124

125125
trimHigherOrder(Seq(
126-
If(IsNull(child), n, newN),
127-
If(IsNull(child), avg, newAvg),
128-
If(IsNull(child), m2, newM2),
129-
If(IsNull(child), m3, newM3),
130-
If(IsNull(child), m4, newM4)
126+
If(child.isNull, n, newN),
127+
If(child.isNull, avg, newAvg),
128+
If(child.isNull, m2, newM2),
129+
If(child.isNull, m3, newM3),
130+
If(child.isNull, m4, newM4)
131131
))
132132
}
133133
}
@@ -142,8 +142,7 @@ case class StddevPop(child: Expression) extends CentralMomentAgg(child) {
142142
override protected def momentOrder = 2
143143

144144
override val evaluateExpression: Expression = {
145-
If(n === Literal(0.0), Literal.create(null, DoubleType),
146-
Sqrt(m2 / n))
145+
If(n === 0.0, Literal.create(null, DoubleType), sqrt(m2 / n))
147146
}
148147

149148
override def prettyName: String = "stddev_pop"
@@ -159,9 +158,8 @@ case class StddevSamp(child: Expression) extends CentralMomentAgg(child) {
159158
override protected def momentOrder = 2
160159

161160
override val evaluateExpression: Expression = {
162-
If(n === Literal(0.0), Literal.create(null, DoubleType),
163-
If(n === Literal(1.0), Literal(Double.NaN),
164-
Sqrt(m2 / (n - Literal(1.0)))))
161+
If(n === 0.0, Literal.create(null, DoubleType),
162+
If(n === 1.0, Double.NaN, sqrt(m2 / (n - 1.0))))
165163
}
166164

167165
override def prettyName: String = "stddev_samp"
@@ -175,8 +173,7 @@ case class VariancePop(child: Expression) extends CentralMomentAgg(child) {
175173
override protected def momentOrder = 2
176174

177175
override val evaluateExpression: Expression = {
178-
If(n === Literal(0.0), Literal.create(null, DoubleType),
179-
m2 / n)
176+
If(n === 0.0, Literal.create(null, DoubleType), m2 / n)
180177
}
181178

182179
override def prettyName: String = "var_pop"
@@ -190,9 +187,8 @@ case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) {
190187
override protected def momentOrder = 2
191188

192189
override val evaluateExpression: Expression = {
193-
If(n === Literal(0.0), Literal.create(null, DoubleType),
194-
If(n === Literal(1.0), Literal(Double.NaN),
195-
m2 / (n - Literal(1.0))))
190+
If(n === 0.0, Literal.create(null, DoubleType),
191+
If(n === 1.0, Double.NaN, m2 / (n - 1.0)))
196192
}
197193

198194
override def prettyName: String = "var_samp"
@@ -207,9 +203,8 @@ case class Skewness(child: Expression) extends CentralMomentAgg(child) {
207203
override protected def momentOrder = 3
208204

209205
override val evaluateExpression: Expression = {
210-
If(n === Literal(0.0), Literal.create(null, DoubleType),
211-
If(m2 === Literal(0.0), Literal(Double.NaN),
212-
Sqrt(n) * m3 / Sqrt(m2 * m2 * m2)))
206+
If(n === 0.0, Literal.create(null, DoubleType),
207+
If(m2 === 0.0, Double.NaN, sqrt(n) * m3 / sqrt(m2 * m2 * m2)))
213208
}
214209
}
215210

@@ -220,9 +215,8 @@ case class Kurtosis(child: Expression) extends CentralMomentAgg(child) {
220215
override protected def momentOrder = 4
221216

222217
override val evaluateExpression: Expression = {
223-
If(n === Literal(0.0), Literal.create(null, DoubleType),
224-
If(m2 === Literal(0.0), Literal(Double.NaN),
225-
n * m4 / (m2 * m2) - Literal(3.0)))
218+
If(n === 0.0, Literal.create(null, DoubleType),
219+
If(m2 === 0.0, Double.NaN, n * m4 / (m2 * m2) - 3.0))
226220
}
227221

228222
override def prettyName: String = "kurtosis"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ abstract class PearsonCorrelation(x: Expression, y: Expression)
5454
val n2 = n.right
5555
val newN = n1 + n2
5656
val dx = xAvg.right - xAvg.left
57-
val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN)
57+
val dxN = If(newN === 0.0, 0.0, dx / newN)
5858
val dy = yAvg.right - yAvg.left
59-
val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN)
59+
val dyN = If(newN === 0.0, 0.0, dy / newN)
6060
val newXAvg = xAvg.left + dxN * n2
6161
val newYAvg = yAvg.left + dyN * n2
6262
val newCk = ck.left + ck.right + dx * dyN * n1 * n2
@@ -67,7 +67,7 @@ abstract class PearsonCorrelation(x: Expression, y: Expression)
6767
}
6868

6969
protected def updateExpressionsDef: Seq[Expression] = {
70-
val newN = n + Literal(1.0)
70+
val newN = n + 1.0
7171
val dx = x - xAvg
7272
val dxN = dx / newN
7373
val dy = y - yAvg
@@ -78,7 +78,7 @@ abstract class PearsonCorrelation(x: Expression, y: Expression)
7878
val newXMk = xMk + dx * (x - newXAvg)
7979
val newYMk = yMk + dy * (y - newYAvg)
8080

81-
val isNull = IsNull(x) || IsNull(y)
81+
val isNull = x.isNull || y.isNull
8282
Seq(
8383
If(isNull, n, newN),
8484
If(isNull, xAvg, newXAvg),
@@ -99,9 +99,8 @@ case class Corr(x: Expression, y: Expression)
9999
extends PearsonCorrelation(x, y) {
100100

101101
override val evaluateExpression: Expression = {
102-
If(n === Literal(0.0), Literal.create(null, DoubleType),
103-
If(n === Literal(1.0), Literal(Double.NaN),
104-
ck / Sqrt(xMk * yMk)))
102+
If(n === 0.0, Literal.create(null, DoubleType),
103+
If(n === 1.0, Double.NaN, ck / sqrt(xMk * yMk)))
105104
}
106105

107106
override def prettyName: String = "corr"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ abstract class Covariance(x: Expression, y: Expression)
5050
val n2 = n.right
5151
val newN = n1 + n2
5252
val dx = xAvg.right - xAvg.left
53-
val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN)
53+
val dxN = If(newN === 0.0, 0.0, dx / newN)
5454
val dy = yAvg.right - yAvg.left
55-
val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN)
55+
val dyN = If(newN === 0.0, 0.0, dy / newN)
5656
val newXAvg = xAvg.left + dxN * n2
5757
val newYAvg = yAvg.left + dyN * n2
5858
val newCk = ck.left + ck.right + dx * dyN * n1 * n2
@@ -61,15 +61,15 @@ abstract class Covariance(x: Expression, y: Expression)
6161
}
6262

6363
protected def updateExpressionsDef: Seq[Expression] = {
64-
val newN = n + Literal(1.0)
64+
val newN = n + 1.0
6565
val dx = x - xAvg
6666
val dy = y - yAvg
6767
val dyN = dy / newN
6868
val newXAvg = xAvg + dx / newN
6969
val newYAvg = yAvg + dyN
7070
val newCk = ck + dx * (y - newYAvg)
7171

72-
val isNull = IsNull(x) || IsNull(y)
72+
val isNull = x.isNull || y.isNull
7373
Seq(
7474
If(isNull, n, newN),
7575
If(isNull, xAvg, newXAvg),
@@ -83,8 +83,7 @@ abstract class Covariance(x: Expression, y: Expression)
8383
usage = "_FUNC_(expr1, expr2) - Returns the population covariance of a set of number pairs.")
8484
case class CovPopulation(left: Expression, right: Expression) extends Covariance(left, right) {
8585
override val evaluateExpression: Expression = {
86-
If(n === Literal(0.0), Literal.create(null, DoubleType),
87-
ck / n)
86+
If(n === 0.0, Literal.create(null, DoubleType), ck / n)
8887
}
8988
override def prettyName: String = "covar_pop"
9089
}
@@ -94,9 +93,8 @@ case class CovPopulation(left: Expression, right: Expression) extends Covariance
9493
usage = "_FUNC_(expr1, expr2) - Returns the sample covariance of a set of number pairs.")
9594
case class CovSample(left: Expression, right: Expression) extends Covariance(left, right) {
9695
override val evaluateExpression: Expression = {
97-
If(n === Literal(0.0), Literal.create(null, DoubleType),
98-
If(n === Literal(1.0), Literal(Double.NaN),
99-
ck / (n - Literal(1.0))))
96+
If(n === 0.0, Literal.create(null, DoubleType),
97+
If(n === 1.0, Double.NaN, ck / (n - 1.0)))
10098
}
10199
override def prettyName: String = "covar_samp"
102100
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
1919

2020
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2121
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
22+
import org.apache.spark.sql.catalyst.dsl.expressions._
2223
import org.apache.spark.sql.catalyst.expressions._
2324
import org.apache.spark.sql.types._
2425

@@ -80,8 +81,8 @@ case class First(child: Expression, ignoreNullsExpr: Expression)
8081
override lazy val updateExpressions: Seq[Expression] = {
8182
if (ignoreNulls) {
8283
Seq(
83-
/* first = */ If(Or(valueSet, IsNull(child)), first, child),
84-
/* valueSet = */ Or(valueSet, IsNotNull(child))
84+
/* first = */ If(valueSet || child.isNull, first, child),
85+
/* valueSet = */ valueSet || child.isNotNull
8586
)
8687
} else {
8788
Seq(
@@ -97,7 +98,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression)
9798
// false, we are safe to do so because first.right will be null in this case).
9899
Seq(
99100
/* first = */ If(valueSet.left, first.left, first.right),
100-
/* valueSet = */ Or(valueSet.left, valueSet.right)
101+
/* valueSet = */ valueSet.left || valueSet.right
101102
)
102103
}
103104

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
1919

2020
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2121
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
22+
import org.apache.spark.sql.catalyst.dsl.expressions._
2223
import org.apache.spark.sql.catalyst.expressions._
2324
import org.apache.spark.sql.types._
2425

@@ -80,8 +81,8 @@ case class Last(child: Expression, ignoreNullsExpr: Expression)
8081
override lazy val updateExpressions: Seq[Expression] = {
8182
if (ignoreNulls) {
8283
Seq(
83-
/* last = */ If(IsNull(child), last, child),
84-
/* valueSet = */ Or(valueSet, IsNotNull(child))
84+
/* last = */ If(child.isNull, last, child),
85+
/* valueSet = */ valueSet || child.isNotNull
8586
)
8687
} else {
8788
Seq(
@@ -95,7 +96,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression)
9596
// Prefer the right hand expression if it has been set.
9697
Seq(
9798
/* last = */ If(valueSet.right, last.right, last.left),
98-
/* valueSet = */ Or(valueSet.right, valueSet.left)
99+
/* valueSet = */ valueSet.right || valueSet.left
99100
)
100101
}
101102

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions.aggregate
1919

2020
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
21+
import org.apache.spark.sql.catalyst.dsl.expressions._
2122
import org.apache.spark.sql.catalyst.expressions._
2223
import org.apache.spark.sql.catalyst.util.TypeUtils
2324
import org.apache.spark.sql.types._
@@ -45,12 +46,12 @@ case class Max(child: Expression) extends DeclarativeAggregate {
4546
)
4647

4748
override lazy val updateExpressions: Seq[Expression] = Seq(
48-
/* max = */ Greatest(Seq(max, child))
49+
/* max = */ greatest(max, child)
4950
)
5051

5152
override lazy val mergeExpressions: Seq[Expression] = {
5253
Seq(
53-
/* max = */ Greatest(Seq(max.left, max.right))
54+
/* max = */ greatest(max.left, max.right)
5455
)
5556
}
5657

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions.aggregate
1919

2020
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
21+
import org.apache.spark.sql.catalyst.dsl.expressions._
2122
import org.apache.spark.sql.catalyst.expressions._
2223
import org.apache.spark.sql.catalyst.util.TypeUtils
2324
import org.apache.spark.sql.types._
@@ -45,12 +46,12 @@ case class Min(child: Expression) extends DeclarativeAggregate {
4546
)
4647

4748
override lazy val updateExpressions: Seq[Expression] = Seq(
48-
/* min = */ Least(Seq(min, child))
49+
/* min = */ least(min, child)
4950
)
5051

5152
override lazy val mergeExpressions: Seq[Expression] = {
5253
Seq(
53-
/* min = */ Least(Seq(min.left, min.right))
54+
/* min = */ least(min.left, min.right)
5455
)
5556
}
5657

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions.aggregate
1919

2020
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
21+
import org.apache.spark.sql.catalyst.dsl.expressions._
2122
import org.apache.spark.sql.catalyst.expressions._
2223
import org.apache.spark.sql.catalyst.util.TypeUtils
2324
import org.apache.spark.sql.types._
@@ -61,20 +62,20 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
6162
if (child.nullable) {
6263
Seq(
6364
/* sum = */
64-
Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum))
65+
coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
6566
)
6667
} else {
6768
Seq(
6869
/* sum = */
69-
Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType))
70+
coalesce(sum, zero) + child.cast(sumDataType)
7071
)
7172
}
7273
}
7374

7475
override lazy val mergeExpressions: Seq[Expression] = {
7576
Seq(
7677
/* sum = */
77-
Coalesce(Seq(Add(Coalesce(Seq(sum.left, zero)), sum.right), sum.left))
78+
coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
7879
)
7980
}
8081

0 commit comments

Comments
 (0)