Skip to content

Commit 327bb30

Browse files
committed
[SPARK-23911][SQL] Add aggregate function.
## What changes were proposed in this pull request?

This PR adds an `aggregate` function, which applies a binary operator to an initial state and all elements in the array, and reduces this to a single state. The final state is converted into the final result by applying a finish function.

```sql
> SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x);
6
> SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x, acc -> acc * 10);
60
```

## How was this patch tested?

Added tests.

Author: Takuya UESHIN <[email protected]>

Closes apache#21982 from ueshin/issues/SPARK-23911/aggregate.
1 parent 5f9633d commit 327bb30

File tree

6 files changed

+318
-1
lines changed

6 files changed

+318
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,7 @@ object FunctionRegistry {
442442
expression[ArrayDistinct]("array_distinct"),
443443
expression[ArrayTransform]("transform"),
444444
expression[ArrayFilter]("filter"),
445+
expression[ArrayAggregate]("aggregate"),
445446
CreateStruct.registryEntry,
446447

447448
// misc functions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicReference
2222
import scala.collection.mutable
2323

2424
import org.apache.spark.sql.catalyst.InternalRow
25+
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
2526
import org.apache.spark.sql.catalyst.expressions.codegen._
2627
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2728
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
@@ -76,6 +77,13 @@ case class LambdaFunction(
7677
override def eval(input: InternalRow): Any = function.eval(input)
7778
}
7879

80+
object LambdaFunction {
  /**
   * A one-argument lambda whose body is its own argument, i.e. `id -> id`.
   *
   * The argument is an unresolved attribute quoted as "id"; the same attribute instance is used
   * both as the lambda body and as its single argument.
   */
  val identity: LambdaFunction = {
    val arg = UnresolvedAttribute.quoted("id")
    LambdaFunction(arg, arg :: Nil)
  }
}
86+
7987
/**
8088
* A higher order function takes one or more (lambda) functions and applies these to some objects.
8189
* The function produces a number of variables which can be consumed by some lambda function.
@@ -270,3 +278,90 @@ case class ArrayFilter(
270278

271279
override def prettyName: String = "filter"
272280
}
281+
282+
/**
 * Applies a binary operator to a start value and all elements in the array, then converts the
 * final accumulator into the result with a finish function.
 *
 * @param input  expression producing the array to fold over
 * @param zero   expression producing the initial accumulator value
 * @param merge  two-argument lambda `(acc, element) -> acc`
 * @param finish one-argument lambda `acc -> result`
 */
@ExpressionDescription(
  usage =
    """
      _FUNC_(expr, start, merge, finish) - Applies a binary operator to an initial state and all
      elements in the array, and reduces this to a single state. The final state is converted
      into the final result by applying a finish function.
    """,
  examples = """
    Examples:
      > SELECT _FUNC_(array(1, 2, 3), 0, (acc, x) -> acc + x);
       6
      > SELECT _FUNC_(array(1, 2, 3), 0, (acc, x) -> acc + x, acc -> acc * 10);
       60
  """,
  since = "2.4.0")
case class ArrayAggregate(
    input: Expression,
    zero: Expression,
    merge: Expression,
    finish: Expression)
  extends HigherOrderFunction with CodegenFallback {

  /** Three-argument form: uses `LambdaFunction.identity` as the finish function. */
  def this(input: Expression, zero: Expression, merge: Expression) = {
    this(input, zero, merge, LambdaFunction.identity)
  }

  override def inputs: Seq[Expression] = input :: zero :: Nil

  override def functions: Seq[Expression] = merge :: finish :: Nil

  // A null input array evaluates to null (see eval); otherwise the result is whatever the
  // finish function produces.
  override def nullable: Boolean = input.nullable || finish.nullable

  override def dataType: DataType = finish.dataType

  /**
   * Requires the first argument to be an array and the merge function's result type to match
   * the type of the start value structurally (nullability is ignored).
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    if (!ArrayType.acceptsType(input.dataType)) {
      TypeCheckResult.TypeCheckFailure(
        s"argument 1 requires ${ArrayType.simpleString} type, " +
          s"however, '${input.sql}' is of ${input.dataType.catalogString} type.")
    } else if (!DataType.equalsStructurally(
        zero.dataType, merge.dataType, ignoreNullability = true)) {
      TypeCheckResult.TypeCheckFailure(
        s"argument 3 requires ${zero.dataType.simpleString} type, " +
          s"however, '${merge.sql}' is of ${merge.dataType.catalogString} type.")
    } else {
      TypeCheckResult.TypeCheckSuccess
    }
  }

  /**
   * Binds both lambdas: `merge` receives (accumulator, element) and `finish` receives the
   * accumulator only.
   */
  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayAggregate = {
    // Be very conservative with nullable. We cannot be sure that the accumulator does not
    // evaluate to null. So we always set nullable to true here.
    val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType)
    val acc = zero.dataType -> true
    val newMerge = f(merge, acc :: elem :: Nil)
    val newFinish = f(finish, acc :: Nil)
    copy(merge = newMerge, finish = newFinish)
  }

  // Lambda variables extracted from the bound lambdas; eval writes the current accumulator and
  // element into them before evaluating the lambda bodies.
  @transient lazy val LambdaFunction(_,
    Seq(accForMergeVar: NamedLambdaVariable, elementVar: NamedLambdaVariable), _) = merge
  @transient lazy val LambdaFunction(_, Seq(accForFinishVar: NamedLambdaVariable), _) = finish

  override def eval(input: InternalRow): Any = {
    val arr = this.input.eval(input).asInstanceOf[ArrayData]
    if (arr == null) {
      // A null array yields null without evaluating zero, merge or finish.
      null
    } else {
      val Seq(mergeForEval, finishForEval) = functionsForEval
      // Seed the accumulator with the start value.
      accForMergeVar.value.set(zero.eval(input))
      var i = 0
      while (i < arr.numElements()) {
        // Expose the current element, then replace the accumulator with the merge result.
        elementVar.value.set(arr.get(i, elementVar.dataType))
        accForMergeVar.value.set(mergeForEval.eval(input))
        i += 1
      }
      // Hand the final accumulator to the finish lambda's own variable.
      accForFinishVar.value.set(accForMergeVar.value.get)
      finishForEval.eval(input)
    }
  }

  override def prettyName: String = "aggregate"
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,27 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
5959
ArrayFilter(expr, createLambda(at.elementType, at.containsNull, f))
6060
}
6161

62+
/**
 * Builds an [[ArrayAggregate]] over `expr`, turning the Scala `merge` and `finish` functions
 * into lambdas. The accumulator argument is always declared nullable; the element argument
 * takes its type and nullability from the array type of `expr`.
 */
def aggregate(
    expr: Expression,
    zero: Expression,
    merge: (Expression, Expression) => Expression,
    finish: Expression => Expression): Expression = {
  val arrayType = expr.dataType.asInstanceOf[ArrayType]
  val accType = zero.dataType
  val mergeLambda =
    createLambda(accType, true, arrayType.elementType, arrayType.containsNull, merge)
  val finishLambda = createLambda(accType, true, finish)
  ArrayAggregate(expr, zero, mergeLambda, finishLambda)
}
75+
76+
/**
 * Builds an [[ArrayAggregate]] without a finish step: the final accumulator is returned as is.
 */
def aggregate(
    expr: Expression,
    zero: Expression,
    merge: (Expression, Expression) => Expression): Expression = {
  val passThrough: Expression => Expression = identity
  aggregate(expr, zero, merge, passThrough)
}
82+
6283
test("ArrayTransform") {
6384
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
6485
val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
@@ -131,4 +152,33 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
131152
checkEvaluation(transform(aai, ix => filter(ix, isNullOrOdd)),
132153
Seq(Seq(1, 3), null, Seq(5)))
133154
}
155+
156+
test("ArrayAggregate") {
  // Lambdas shared by several of the checks below.
  val addUp: (Expression, Expression) => Expression = (acc, elem) => acc + elem
  val timesTen: Expression => Expression = acc => acc * 10
  val append: (Expression, Expression) => Expression = (acc, elem) => Concat(Seq(acc, elem))

  // Integer arrays: plain, containing a null element, empty, and a null array.
  val ints = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
  val intsWithNull =
    Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
  val noInts = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false))
  val nullInts = Literal.create(null, ArrayType(IntegerType, containsNull = false))

  checkEvaluation(aggregate(ints, 0, addUp, timesTen), 60)
  checkEvaluation(
    aggregate(intsWithNull, 0, (acc, elem) => acc + coalesce(elem, 0), timesTen), 40)
  checkEvaluation(aggregate(noInts, 0, addUp, timesTen), 0)
  checkEvaluation(aggregate(nullInts, 0, addUp, timesTen), null)

  // String arrays: plain, containing a null element, empty, and a null array.
  val strs = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false))
  val strsWithNull =
    Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true))
  val noStrs = Literal.create(Seq.empty[String], ArrayType(StringType, containsNull = false))
  val nullStrs = Literal.create(null, ArrayType(StringType, containsNull = false))

  checkEvaluation(aggregate(strs, "", append), "abc")
  checkEvaluation(
    aggregate(strsWithNull, "", (acc, elem) => Concat(Seq(acc, coalesce(elem, "x")))), "axc")
  checkEvaluation(aggregate(noStrs, "", append), "")
  checkEvaluation(aggregate(nullStrs, "", append), null)

  // Nested aggregation: the inner aggregate yields null for the null inner array, in which case
  // the outer accumulator is kept.
  val nested = Literal.create(Seq[Seq[Integer]](Seq(1, 2, 3), null, Seq(4, 5)),
    ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true))
  checkEvaluation(
    aggregate(nested, 0, (acc, inner) => coalesce(aggregate(inner, acc, addUp), acc)),
    15)
}
134184
}

sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,15 @@ select filter(cast(null as array<int>), y -> true) as v;
3333

3434
-- Filter nested arrays
3535
select transform(zs, z -> filter(z, zz -> zz > 50)) as v from nested;
36+
37+
-- Aggregate.
38+
select aggregate(ys, 0, (y, a) -> y + a + x) as v from nested;
39+
40+
-- Aggregate average.
41+
select aggregate(ys, (0 as sum, 0 as n), (acc, x) -> (acc.sum + x, acc.n + 1), acc -> acc.sum / acc.n) as v from nested;
42+
43+
-- Aggregate nested arrays
44+
select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as v from nested;
45+
46+
-- Aggregate a null array
47+
select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) as v;

sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 11
2+
-- Number of queries: 15
33

44

55
-- !query 0
@@ -107,3 +107,41 @@ struct<v:array<array<int>>>
107107
[[96,65],[]]
108108
[[99],[123],[]]
109109
[[]]
110+
111+
112+
-- !query 11
113+
select aggregate(ys, 0, (y, a) -> y + a + x) as v from nested
114+
-- !query 11 schema
115+
struct<v:int>
116+
-- !query 11 output
117+
131
118+
15
119+
5
120+
121+
122+
-- !query 12
123+
select aggregate(ys, (0 as sum, 0 as n), (acc, x) -> (acc.sum + x, acc.n + 1), acc -> acc.sum / acc.n) as v from nested
124+
-- !query 12 schema
125+
struct<v:double>
126+
-- !query 12 output
127+
0.5
128+
12.0
129+
64.5
130+
131+
132+
-- !query 13
133+
select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as v from nested
134+
-- !query 13 schema
135+
struct<v:array<int>>
136+
-- !query 13 output
137+
[1010880,8]
138+
[17]
139+
[4752,20664,1]
140+
141+
142+
-- !query 14
143+
select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) as v
144+
-- !query 14 schema
145+
struct<v:int>
146+
-- !query 14 output
147+
NULL

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1896,6 +1896,127 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
18961896
assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type"))
18971897
}
18981898

1899+
test("aggregate function - array for primitive type not containing null") {
  val intArrays = Seq(
    Seq(1, 9, 8, 7),
    Seq(5, 8, 9, 7, 2),
    Seq.empty,
    null
  ).toDF("i")

  // Checks the three-argument form and the form with a finish function.
  def verify(): Unit = {
    val sums = Seq(Row(25), Row(31), Row(0), Row(null))
    checkAnswer(intArrays.selectExpr("aggregate(i, 0, (acc, x) -> acc + x)"), sums)

    val sumsTimesTen = Seq(Row(250), Row(310), Row(0), Row(null))
    checkAnswer(
      intArrays.selectExpr("aggregate(i, 0, (acc, x) -> acc + x, acc -> acc * 10)"),
      sumsTimesTen)
  }

  // Test with local relation, the Project will be evaluated without codegen
  verify()
  // Test with cached relation, the Project will be evaluated with codegen
  intArrays.cache()
  verify()
}
1928+
1929+
test("aggregate function - array for primitive type containing null") {
  val nullableIntArrays = Seq[Seq[Integer]](
    Seq(1, 9, 8, 7),
    Seq(5, null, 8, 9, 7, 2),
    Seq.empty,
    null
  ).toDF("i")

  // The row containing a null element is expected to sum to null, unless the finish function
  // coalesces the accumulator.
  def verify(): Unit = {
    val sums = Seq(Row(25), Row(null), Row(0), Row(null))
    checkAnswer(nullableIntArrays.selectExpr("aggregate(i, 0, (acc, x) -> acc + x)"), sums)

    val coalescedTimesTen = Seq(Row(250), Row(0), Row(0), Row(null))
    checkAnswer(
      nullableIntArrays.selectExpr(
        "aggregate(i, 0, (acc, x) -> acc + x, acc -> coalesce(acc, 0) * 10)"),
      coalescedTimesTen)
  }

  // Test with local relation, the Project will be evaluated without codegen
  verify()
  // Test with cached relation, the Project will be evaluated with codegen
  nullableIntArrays.cache()
  verify()
}
1959+
1960+
test("aggregate function - array for non-primitive type") {
  // Column "ss" is the string array to fold; column "s" supplies the start value.
  val stringArrays = Seq(
    (Seq("c", "a", "b"), "a"),
    (Seq("b", null, "c", null), "b"),
    (Seq.empty, "c"),
    (null, "d")
  ).toDF("ss", "s")

  def verify(): Unit = {
    val concatenated = Seq(Row("acab"), Row(null), Row("c"), Row(null))
    checkAnswer(
      stringArrays.selectExpr("aggregate(ss, s, (acc, x) -> concat(acc, x))"),
      concatenated)

    val concatenatedOrEmpty = Seq(Row("acab"), Row(""), Row("c"), Row(null))
    checkAnswer(
      stringArrays.selectExpr(
        "aggregate(ss, s, (acc, x) -> concat(acc, x), acc -> coalesce(acc , ''))"),
      concatenatedOrEmpty)
  }

  // Test with local relation, the Project will be evaluated without codegen
  verify()
  // Test with cached relation, the Project will be evaluated with codegen
  stringArrays.cache()
  verify()
}
1990+
1991+
test("aggregate function - invalid") {
  val df = Seq(
    (Seq("c", "a", "b"), 1),
    (Seq("b", null, "c", null), 2),
    (Seq.empty, 3),
    (null, 4)
  ).toDF("s", "i")

  // Runs the expression, expects analysis to fail, and returns the failure message.
  def analysisErrorOf(sqlExpr: String): String = {
    val failure = intercept[AnalysisException] {
      df.selectExpr(sqlExpr)
    }
    failure.getMessage
  }

  // Merge lambda with one argument instead of two.
  assert(analysisErrorOf("aggregate(s, '', x -> x)")
    .contains("The number of lambda function arguments '1' does not match"))

  // Finish lambda with two arguments instead of one.
  assert(analysisErrorOf("aggregate(s, '', (acc, x) -> x, (acc, x) -> x)")
    .contains("The number of lambda function arguments '2' does not match"))

  // First argument is not an array.
  assert(analysisErrorOf("aggregate(i, 0, (acc, x) -> x)")
    .contains("data type mismatch: argument 1 requires array type"))

  // Merge result type (string element) does not match the int start value.
  assert(analysisErrorOf("aggregate(s, 0, (acc, x) -> x)")
    .contains("data type mismatch: argument 3 requires int type"))
}
2019+
18992020
private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
19002021
import DataFrameFunctionsSuite.CodegenFallbackExpr
19012022
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {

0 commit comments

Comments
 (0)