Commit 9ea8d3d

attilapiros authored and hvanhovell committed
[SPARK-22362][SQL] Add unit test for Window Aggregate Functions
## What changes were proposed in this pull request?

Improves the test coverage of window functions, focusing on the missing tests for window aggregate functions. No new UDAF test is added, as UDAFs are already covered.

## How was this patch tested?

Only new tests were added; the automated tests were executed.

Author: "attilapiros" <[email protected]>
Author: Attila Zsolt Piros <[email protected]>

Closes apache#20046 from attilapiros/SPARK-22362.
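For orientation, the pattern the new tests exercise is an aggregate function turned into a window function with `.over(windowSpec)` on a whole-partition frame. A minimal, self-contained sketch (the object name, local session, and toy data are illustrative assumptions, not part of this commit):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object WindowAggregateSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("window-agg-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq(("a", 1.0, 2.0), ("a", 2.0, 1.0), ("b", 3.0, 3.0)).toDF("cate", "val1", "val2")

    // An aggregate becomes a window function via .over(spec). This frame spans
    // the whole partition, so every row of a partition sees the same aggregate.
    val w = Window
      .partitionBy("cate")
      .orderBy("val1")
      .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

    df.select(
      $"cate", $"val1",
      corr("val1", "val2").over(w).as("corr"),
      collect_list("val1").over(w).as("vals")).show()

    spark.stop()
  }
}
```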
1 parent a471880 commit 9ea8d3d

File tree

3 files changed: +294 -12 lines changed

sql/core/src/test/resources/sql-tests/inputs/window.sql

Lines changed: 9 additions & 1 deletion

@@ -76,7 +76,15 @@ ntile(2) OVER w AS ntile,
 row_number() OVER w AS row_number,
 var_pop(val) OVER w AS var_pop,
 var_samp(val) OVER w AS var_samp,
-approx_count_distinct(val) OVER w AS approx_count_distinct
+approx_count_distinct(val) OVER w AS approx_count_distinct,
+covar_pop(val, val_long) OVER w AS covar_pop,
+corr(val, val_long) OVER w AS corr,
+stddev_samp(val) OVER w AS stddev_samp,
+stddev_pop(val) OVER w AS stddev_pop,
+collect_list(val) OVER w AS collect_list,
+collect_set(val) OVER w AS collect_set,
+skewness(val_double) OVER w AS skewness,
+kurtosis(val_double) OVER w AS kurtosis
 FROM testData
 WINDOW w AS (PARTITION BY cate ORDER BY val)
 ORDER BY cate, val;
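For context: Spark's SQLQueryTestSuite executes each query in this input file and compares the result against the recorded golden output, which is why the schema and rows in window.sql.out below change in lockstep with the query above.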

sql/core/src/test/resources/sql-tests/results/window.sql.out

Lines changed: 19 additions & 11 deletions

@@ -273,22 +273,30 @@ ntile(2) OVER w AS ntile,
 row_number() OVER w AS row_number,
 var_pop(val) OVER w AS var_pop,
 var_samp(val) OVER w AS var_samp,
-approx_count_distinct(val) OVER w AS approx_count_distinct
+approx_count_distinct(val) OVER w AS approx_count_distinct,
+covar_pop(val, val_long) OVER w AS covar_pop,
+corr(val, val_long) OVER w AS corr,
+stddev_samp(val) OVER w AS stddev_samp,
+stddev_pop(val) OVER w AS stddev_pop,
+collect_list(val) OVER w AS collect_list,
+collect_set(val) OVER w AS collect_set,
+skewness(val_double) OVER w AS skewness,
+kurtosis(val_double) OVER w AS kurtosis
 FROM testData
 WINDOW w AS (PARTITION BY cate ORDER BY val)
 ORDER BY cate, val
 -- !query 17 schema
-struct<val:int,cate:string,max:int,min:int,min:int,count:bigint,sum:bigint,avg:double,stddev:double,first_value:int,first_value_ignore_null:int,first_value_contain_null:int,last_value:int,last_value_ignore_null:int,last_value_contain_null:int,rank:int,dense_rank:int,cume_dist:double,percent_rank:double,ntile:int,row_number:int,var_pop:double,var_samp:double,approx_count_distinct:bigint>
+struct<val:int,cate:string,max:int,min:int,min:int,count:bigint,sum:bigint,avg:double,stddev:double,first_value:int,first_value_ignore_null:int,first_value_contain_null:int,last_value:int,last_value_ignore_null:int,last_value_contain_null:int,rank:int,dense_rank:int,cume_dist:double,percent_rank:double,ntile:int,row_number:int,var_pop:double,var_samp:double,approx_count_distinct:bigint,covar_pop:double,corr:double,stddev_samp:double,stddev_pop:double,collect_list:array<int>,collect_set:array<int>,skewness:double,kurtosis:double>
 -- !query 17 output
-NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0
-3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1
-NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0
-1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 1 2 0.0 0.0 1
-1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 2 3 0.0 0.0 1
-2 a 2 1 1 3 4 1.3333333333333333 0.5773502691896258 NULL 1 NULL 2 2 2 4 3 1.0 1.0 2 4 0.22222222222222224 0.33333333333333337 2
-1 b 1 1 1 1 1 1.0 NaN 1 1 1 1 1 1 1 1 0.3333333333333333 0.0 1 1 0.0 NaN 1
-2 b 2 1 1 2 3 1.5 0.7071067811865476 1 1 1 2 2 2 2 2 0.6666666666666666 0.5 1 2 0.25 0.5 2
-3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3
+NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0 NULL NULL NULL NULL [] [] NULL NULL
+3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1 0.0 NaN NaN 0.0 [3] [3] NaN NaN
+NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0 NULL NULL NULL NULL [] [] NaN NaN
+1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 1 2 0.0 0.0 1 0.0 NULL 0.0 0.0 [1,1] [1] 0.7071067811865476 -1.5
+1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 2 3 0.0 0.0 1 0.0 NULL 0.0 0.0 [1,1] [1] 0.7071067811865476 -1.5
+2 a 2 1 1 3 4 1.3333333333333333 0.5773502691896258 NULL 1 NULL 2 2 2 4 3 1.0 1.0 2 4 0.22222222222222224 0.33333333333333337 2 4.772185885555555E8 1.0 0.5773502691896258 0.4714045207910317 [1,1,2] [1,2] 1.1539890888012805 -0.6672217220327235
+1 b 1 1 1 1 1 1.0 NaN 1 1 1 1 1 1 1 1 0.3333333333333333 0.0 1 1 0.0 NaN 1 NULL NULL NaN 0.0 [1] [1] NaN NaN
+2 b 2 1 1 2 3 1.5 0.7071067811865476 1 1 1 2 2 2 2 2 0.6666666666666666 0.5 1 2 0.25 0.5 2 0.0 NaN 0.7071067811865476 0.5 [1,2] [1,2] 0.0 -2.0000000000000013
+3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3 5.3687091175E8 1.0 1.0 0.816496580927726 [1,2,3] [1,2,3] 0.7057890433107311 -1.4999999999999984
 
 
 -- !query 18

sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala

Lines changed: 266 additions & 0 deletions

@@ -19,6 +19,8 @@ package org.apache.spark.sql
 
 import java.sql.{Date, Timestamp}
 
+import scala.collection.mutable
+
 import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
 import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
 import org.apache.spark.sql.functions._
@@ -86,6 +88,236 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
     assert(e.message.contains("requires window to be ordered"))
   }
 
+  test("corr, covar_pop, stddev_pop functions in specific window") {
+    val df = Seq(
+      ("a", "p1", 10.0, 20.0),
+      ("b", "p1", 20.0, 10.0),
+      ("c", "p2", 20.0, 20.0),
+      ("d", "p2", 20.0, 20.0),
+      ("e", "p3", 0.0, 0.0),
+      ("f", "p3", 6.0, 12.0),
+      ("g", "p3", 6.0, 12.0),
+      ("h", "p3", 8.0, 16.0),
+      ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2")
+    checkAnswer(
+      df.select(
+        $"key",
+        corr("value1", "value2").over(Window.partitionBy("partitionId")
+          .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
+        covar_pop("value1", "value2")
+          .over(Window.partitionBy("partitionId")
+            .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
+        var_pop("value1")
+          .over(Window.partitionBy("partitionId")
+            .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
+        stddev_pop("value1")
+          .over(Window.partitionBy("partitionId")
+            .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
+        var_pop("value2")
+          .over(Window.partitionBy("partitionId")
+            .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
+        stddev_pop("value2")
+          .over(Window.partitionBy("partitionId")
+            .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))),
+
+      // As stddev_pop(expr) = sqrt(var_pop(expr))
+      // the "stddev_pop" column can be calculated from the "var_pop" column.
+      //
+      // As corr(expr1, expr2) = covar_pop(expr1, expr2) / (stddev_pop(expr1) * stddev_pop(expr2))
+      // the "corr" column can be calculated from the "covar_pop" and the two "stddev_pop" columns.
+      Seq(
+        Row("a", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0),
+        Row("b", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0),
+        Row("c", null, 0.0, 0.0, 0.0, 0.0, 0.0),
+        Row("d", null, 0.0, 0.0, 0.0, 0.0, 0.0),
+        Row("e", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0),
+        Row("f", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0),
+        Row("g", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0),
+        Row("h", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0),
+        Row("i", Double.NaN, 0.0, 0.0, 0.0, 0.0, 0.0)))
+  }
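For reference, the two identities the in-code comment relies on, checked against the p1 rows (10.0, 20.0) and (20.0, 10.0):

$$
\operatorname{stddev\_pop}(x) = \sqrt{\operatorname{var\_pop}(x)}, \qquad
\operatorname{corr}(x, y) = \frac{\operatorname{covar\_pop}(x, y)}{\operatorname{stddev\_pop}(x)\,\operatorname{stddev\_pop}(y)}
$$

For p1, both columns have mean 15 and var_pop = 25, so stddev_pop = 5; covar_pop = (10*20 + 20*10)/2 - 15*15 = -25; hence corr = -25 / (5*5) = -1.0, matching the first two expected rows above.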
+
+  test("covar_samp, var_samp (variance), stddev_samp (stddev) functions in specific window") {
+    val df = Seq(
+      ("a", "p1", 10.0, 20.0),
+      ("b", "p1", 20.0, 10.0),
+      ("c", "p2", 20.0, 20.0),
+      ("d", "p2", 20.0, 20.0),
+      ("e", "p3", 0.0, 0.0),
+      ("f", "p3", 6.0, 12.0),
+      ("g", "p3", 6.0, 12.0),
+      ("h", "p3", 8.0, 16.0),
+      ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2")
+    checkAnswer(
+      df.select(
+        $"key",
+        covar_samp("value1", "value2").over(Window.partitionBy("partitionId")
+          .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
+        var_samp("value1").over(Window.partitionBy("partitionId")
+          .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
+        variance("value1").over(Window.partitionBy("partitionId")
+          .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
+        stddev_samp("value1").over(Window.partitionBy("partitionId")
+          .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
+        stddev("value1").over(Window.partitionBy("partitionId")
+          .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))
+      ),
+      Seq(
+        Row("a", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755),
+        Row("b", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755),
+        Row("c", 0.0, 0.0, 0.0, 0.0, 0.0),
+        Row("d", 0.0, 0.0, 0.0, 0.0, 0.0),
+        Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
+        Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
+        Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
+        Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
+        Row("i", Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)))
+  }
+
+  test("collect_list in ascending ordered window") {
+    val df = Seq(
+      ("a", "p1", "1"),
+      ("b", "p1", "2"),
+      ("c", "p1", "2"),
+      ("d", "p1", null),
+      ("e", "p1", "3"),
+      ("f", "p2", "10"),
+      ("g", "p2", "11"),
+      ("h", "p3", "20"),
+      ("i", "p4", null)).toDF("key", "partition", "value")
+    checkAnswer(
+      df.select(
+        $"key",
+        sort_array(
+          collect_list("value").over(Window.partitionBy($"partition").orderBy($"value")
+            .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))),
+      Seq(
+        Row("a", Array("1", "2", "2", "3")),
+        Row("b", Array("1", "2", "2", "3")),
+        Row("c", Array("1", "2", "2", "3")),
+        Row("d", Array("1", "2", "2", "3")),
+        Row("e", Array("1", "2", "2", "3")),
+        Row("f", Array("10", "11")),
+        Row("g", Array("10", "11")),
+        Row("h", Array("20")),
+        Row("i", Array())))
+  }
+
+  test("collect_list in descending ordered window") {
+    val df = Seq(
+      ("a", "p1", "1"),
+      ("b", "p1", "2"),
+      ("c", "p1", "2"),
+      ("d", "p1", null),
+      ("e", "p1", "3"),
+      ("f", "p2", "10"),
+      ("g", "p2", "11"),
+      ("h", "p3", "20"),
+      ("i", "p4", null)).toDF("key", "partition", "value")
+    checkAnswer(
+      df.select(
+        $"key",
+        sort_array(
+          collect_list("value").over(Window.partitionBy($"partition").orderBy($"value".desc)
+            .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))),
+      Seq(
+        Row("a", Array("1", "2", "2", "3")),
+        Row("b", Array("1", "2", "2", "3")),
+        Row("c", Array("1", "2", "2", "3")),
+        Row("d", Array("1", "2", "2", "3")),
+        Row("e", Array("1", "2", "2", "3")),
+        Row("f", Array("10", "11")),
+        Row("g", Array("10", "11")),
+        Row("h", Array("20")),
+        Row("i", Array())))
+  }
+
+  test("collect_set in window") {
+    val df = Seq(
+      ("a", "p1", "1"),
+      ("b", "p1", "2"),
+      ("c", "p1", "2"),
+      ("d", "p1", "3"),
+      ("e", "p1", "3"),
+      ("f", "p2", "10"),
+      ("g", "p2", "11"),
+      ("h", "p3", "20")).toDF("key", "partition", "value")
+    checkAnswer(
+      df.select(
+        $"key",
+        sort_array(
+          collect_set("value").over(Window.partitionBy($"partition").orderBy($"value")
+            .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))),
+      Seq(
+        Row("a", Array("1", "2", "3")),
+        Row("b", Array("1", "2", "3")),
+        Row("c", Array("1", "2", "3")),
+        Row("d", Array("1", "2", "3")),
+        Row("e", Array("1", "2", "3")),
+        Row("f", Array("10", "11")),
+        Row("g", Array("10", "11")),
+        Row("h", Array("20"))))
+  }
+
+  test("skewness and kurtosis functions in window") {
+    val df = Seq(
+      ("a", "p1", 1.0),
+      ("b", "p1", 1.0),
+      ("c", "p1", 2.0),
+      ("d", "p1", 2.0),
+      ("e", "p1", 3.0),
+      ("f", "p1", 3.0),
+      ("g", "p1", 3.0),
+      ("h", "p2", 1.0),
+      ("i", "p2", 2.0),
+      ("j", "p2", 5.0)).toDF("key", "partition", "value")
+    checkAnswer(
+      df.select(
+        $"key",
+        skewness("value").over(Window.partitionBy("partition").orderBy($"key")
+          .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
+        kurtosis("value").over(Window.partitionBy("partition").orderBy($"key")
+          .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))),
+      // results are checked by scipy.stats.skew() and scipy.stats.kurtosis()
+      Seq(
+        Row("a", -0.27238010581457267, -1.506920415224914),
+        Row("b", -0.27238010581457267, -1.506920415224914),
+        Row("c", -0.27238010581457267, -1.506920415224914),
+        Row("d", -0.27238010581457267, -1.506920415224914),
+        Row("e", -0.27238010581457267, -1.506920415224914),
+        Row("f", -0.27238010581457267, -1.506920415224914),
+        Row("g", -0.27238010581457267, -1.506920415224914),
+        Row("h", 0.5280049792181881, -1.5000000000000013),
+        Row("i", 0.5280049792181881, -1.5000000000000013),
+        Row("j", 0.5280049792181881, -1.5000000000000013)))
+  }
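The expected values above are the population (biased) moments, which is also what scipy.stats.skew and scipy.stats.kurtosis compute by default (kurtosis as excess kurtosis):

$$
m_k = \frac{1}{n}\sum_{i=1}^{n}(x_i - \bar{x})^k, \qquad
\text{skewness} = \frac{m_3}{m_2^{3/2}}, \qquad
\text{kurtosis} = \frac{m_4}{m_2^{2}} - 3
$$

Working the p2 partition (1.0, 2.0, 5.0) through: the mean is 8/3, m2 = 26/9, m3 = 70/27, and m4 = 338/27, giving skewness = (70/27) / (26/9)^{3/2} ≈ 0.52800 and kurtosis = (338/27) / (26/9)^2 - 3 = 1.5 - 3 = -1.5, matching the last three expected rows.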
+
+  test("aggregation function on invalid column") {
+    val df = Seq((1, "1")).toDF("key", "value")
+    val e = intercept[AnalysisException](
+      df.select($"key", count("invalid").over()))
+    assert(e.message.contains("cannot resolve '`invalid`' given input columns: [key, value]"))
+  }
+
+  test("numerical aggregate functions on string column") {
+    val df = Seq((1, "a", "b")).toDF("key", "value1", "value2")
+    checkAnswer(
+      df.select($"key",
+        var_pop("value1").over(),
+        variance("value1").over(),
+        stddev_pop("value1").over(),
+        stddev("value1").over(),
+        sum("value1").over(),
+        mean("value1").over(),
+        avg("value1").over(),
+        corr("value1", "value2").over(),
+        covar_pop("value1", "value2").over(),
+        covar_samp("value1", "value2").over(),
+        skewness("value1").over(),
+        kurtosis("value1").over()),
+      Seq(Row(1, null, null, null, null, null, null, null, null, null, null, null, null)))
+  }
+
   test("statistical functions") {
     val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)).
       toDF("key", "value")
@@ -232,6 +464,40 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
       Row("b", 2, null, null, null, null, null, null)))
   }
 
+  test("last/first on descending ordered window") {
+    val nullStr: String = null
+    val df = Seq(
+      ("a", 0, nullStr),
+      ("a", 1, "x"),
+      ("a", 2, "y"),
+      ("a", 3, "z"),
+      ("a", 4, "v"),
+      ("b", 1, "k"),
+      ("b", 2, "l"),
+      ("b", 3, nullStr)).
+      toDF("key", "order", "value")
+    val window = Window.partitionBy($"key").orderBy($"order".desc)
+    checkAnswer(
+      df.select(
+        $"key",
+        $"order",
+        first($"value").over(window),
+        first($"value", ignoreNulls = false).over(window),
+        first($"value", ignoreNulls = true).over(window),
+        last($"value").over(window),
+        last($"value", ignoreNulls = false).over(window),
+        last($"value", ignoreNulls = true).over(window)),
+      Seq(
+        Row("a", 0, "v", "v", "v", null, null, "x"),
+        Row("a", 1, "v", "v", "v", "x", "x", "x"),
+        Row("a", 2, "v", "v", "v", "y", "y", "y"),
+        Row("a", 3, "v", "v", "v", "z", "z", "z"),
+        Row("a", 4, "v", "v", "v", "v", "v", "v"),
+        Row("b", 1, null, null, "l", "k", "k", "k"),
+        Row("b", 2, null, null, "l", "l", "l", "l"),
+        Row("b", 3, null, null, null, null, null, null)))
+  }
+
   test("SPARK-12989 ExtractWindowExpressions treats alias as regular attribute") {
     val src = Seq((0, 3, 5)).toDF("a", "b", "c")
       .withColumn("Data", struct("a", "b"))
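Why the last/first expectations come out this way: `window` has an ORDER BY but no explicit frame, so the default frame RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW applies. With `$"order".desc`, the frame of row ("a", 0, null) ends at that row after covering the whole partition, so `last` returns its own null, while `last(..., ignoreNulls = true)` skips back to "x"; `first` always sees the frame's first row, ("a", 4, "v").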
