@@ -19,6 +19,8 @@ package org.apache.spark.sql
 
 import java.sql.{Date, Timestamp}
 
+import scala.collection.mutable
+
 import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
 import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
 import org.apache.spark.sql.functions._
@@ -86,6 +88,236 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
     assert(e.message.contains("requires window to be ordered"))
   }
 
+ test(" corr, covar_pop, stddev_pop functions in specific window" ) {
92
+ val df = Seq (
93
+ (" a" , " p1" , 10.0 , 20.0 ),
94
+ (" b" , " p1" , 20.0 , 10.0 ),
95
+ (" c" , " p2" , 20.0 , 20.0 ),
96
+ (" d" , " p2" , 20.0 , 20.0 ),
97
+ (" e" , " p3" , 0.0 , 0.0 ),
98
+ (" f" , " p3" , 6.0 , 12.0 ),
99
+ (" g" , " p3" , 6.0 , 12.0 ),
100
+ (" h" , " p3" , 8.0 , 16.0 ),
101
+ (" i" , " p4" , 5.0 , 5.0 )).toDF(" key" , " partitionId" , " value1" , " value2" )
102
+ checkAnswer(
103
+ df.select(
104
+ $" key" ,
105
+ corr(" value1" , " value2" ).over(Window .partitionBy(" partitionId" )
106
+ .orderBy(" key" ).rowsBetween(Window .unboundedPreceding, Window .unboundedFollowing)),
107
+ covar_pop(" value1" , " value2" )
108
+ .over(Window .partitionBy(" partitionId" )
109
+ .orderBy(" key" ).rowsBetween(Window .unboundedPreceding, Window .unboundedFollowing)),
110
+ var_pop(" value1" )
111
+ .over(Window .partitionBy(" partitionId" )
112
+ .orderBy(" key" ).rowsBetween(Window .unboundedPreceding, Window .unboundedFollowing)),
113
+ stddev_pop(" value1" )
114
+ .over(Window .partitionBy(" partitionId" )
115
+ .orderBy(" key" ).rowsBetween(Window .unboundedPreceding, Window .unboundedFollowing)),
116
+ var_pop(" value2" )
117
+ .over(Window .partitionBy(" partitionId" )
118
+ .orderBy(" key" ).rowsBetween(Window .unboundedPreceding, Window .unboundedFollowing)),
119
+ stddev_pop(" value2" )
120
+ .over(Window .partitionBy(" partitionId" )
121
+ .orderBy(" key" ).rowsBetween(Window .unboundedPreceding, Window .unboundedFollowing))),
122
+
123
+ // As stddev_pop(expr) = sqrt(var_pop(expr))
124
+ // the "stddev_pop" column can be calculated from the "var_pop" column.
125
+ //
126
+ // As corr(expr1, expr2) = covar_pop(expr1, expr2) / (stddev_pop(expr1) * stddev_pop(expr2))
127
+ // the "corr" column can be calculated from the "covar_pop" and the two "stddev_pop" columns.
128
+ Seq (
129
+ Row (" a" , - 1.0 , - 25.0 , 25.0 , 5.0 , 25.0 , 5.0 ),
130
+ Row (" b" , - 1.0 , - 25.0 , 25.0 , 5.0 , 25.0 , 5.0 ),
131
+ Row (" c" , null , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ),
132
+ Row (" d" , null , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ),
133
+ Row (" e" , 1.0 , 18.0 , 9.0 , 3.0 , 36.0 , 6.0 ),
134
+ Row (" f" , 1.0 , 18.0 , 9.0 , 3.0 , 36.0 , 6.0 ),
135
+ Row (" g" , 1.0 , 18.0 , 9.0 , 3.0 , 36.0 , 6.0 ),
136
+ Row (" h" , 1.0 , 18.0 , 9.0 , 3.0 , 36.0 , 6.0 ),
137
+ Row (" i" , Double .NaN , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 )))
138
+ }
139
+
140
+ test(" covar_samp, var_samp (variance), stddev_samp (stddev) functions in specific window" ) {
141
+ val df = Seq (
142
+ (" a" , " p1" , 10.0 , 20.0 ),
143
+ (" b" , " p1" , 20.0 , 10.0 ),
144
+ (" c" , " p2" , 20.0 , 20.0 ),
145
+ (" d" , " p2" , 20.0 , 20.0 ),
146
+ (" e" , " p3" , 0.0 , 0.0 ),
147
+ (" f" , " p3" , 6.0 , 12.0 ),
148
+ (" g" , " p3" , 6.0 , 12.0 ),
149
+ (" h" , " p3" , 8.0 , 16.0 ),
150
+ (" i" , " p4" , 5.0 , 5.0 )).toDF(" key" , " partitionId" , " value1" , " value2" )
151
+ checkAnswer(
152
+ df.select(
153
+ $" key" ,
154
+ covar_samp(" value1" , " value2" ).over(Window .partitionBy(" partitionId" )
155
+ .orderBy(" key" ).rowsBetween(Window .unboundedPreceding, Window .unboundedFollowing)),
156
+ var_samp(" value1" ).over(Window .partitionBy(" partitionId" )
157
+ .orderBy(" key" ).rowsBetween(Window .unboundedPreceding, Window .unboundedFollowing)),
158
+ variance(" value1" ).over(Window .partitionBy(" partitionId" )
159
+ .orderBy(" key" ).rowsBetween(Window .unboundedPreceding, Window .unboundedFollowing)),
160
+ stddev_samp(" value1" ).over(Window .partitionBy(" partitionId" )
161
+ .orderBy(" key" ).rowsBetween(Window .unboundedPreceding, Window .unboundedFollowing)),
162
+ stddev(" value1" ).over(Window .partitionBy(" partitionId" )
163
+ .orderBy(" key" ).rowsBetween(Window .unboundedPreceding, Window .unboundedFollowing))
164
+ ),
165
+ Seq (
166
+ Row (" a" , - 50.0 , 50.0 , 50.0 , 7.0710678118654755 , 7.0710678118654755 ),
167
+ Row (" b" , - 50.0 , 50.0 , 50.0 , 7.0710678118654755 , 7.0710678118654755 ),
168
+ Row (" c" , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ),
169
+ Row (" d" , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ),
170
+ Row (" e" , 24.0 , 12.0 , 12.0 , 3.4641016151377544 , 3.4641016151377544 ),
171
+ Row (" f" , 24.0 , 12.0 , 12.0 , 3.4641016151377544 , 3.4641016151377544 ),
172
+ Row (" g" , 24.0 , 12.0 , 12.0 , 3.4641016151377544 , 3.4641016151377544 ),
173
+ Row (" h" , 24.0 , 12.0 , 12.0 , 3.4641016151377544 , 3.4641016151377544 ),
174
+ Row (" i" , Double .NaN , Double .NaN , Double .NaN , Double .NaN , Double .NaN )))
175
+ }
176
+
177
+ test(" collect_list in ascending ordered window" ) {
178
+ val df = Seq (
179
+ (" a" , " p1" , " 1" ),
180
+ (" b" , " p1" , " 2" ),
181
+ (" c" , " p1" , " 2" ),
182
+ (" d" , " p1" , null ),
183
+ (" e" , " p1" , " 3" ),
184
+ (" f" , " p2" , " 10" ),
185
+ (" g" , " p2" , " 11" ),
186
+ (" h" , " p3" , " 20" ),
187
+ (" i" , " p4" , null )).toDF(" key" , " partition" , " value" )
188
+ checkAnswer(
189
+ df.select(
190
+ $" key" ,
191
+ sort_array(
192
+ collect_list(" value" ).over(Window .partitionBy($" partition" ).orderBy($" value" )
193
+ .rowsBetween(Window .unboundedPreceding, Window .unboundedFollowing)))),
194
+ Seq (
195
+ Row (" a" , Array (" 1" , " 2" , " 2" , " 3" )),
196
+ Row (" b" , Array (" 1" , " 2" , " 2" , " 3" )),
197
+ Row (" c" , Array (" 1" , " 2" , " 2" , " 3" )),
198
+ Row (" d" , Array (" 1" , " 2" , " 2" , " 3" )),
199
+ Row (" e" , Array (" 1" , " 2" , " 2" , " 3" )),
200
+ Row (" f" , Array (" 10" , " 11" )),
201
+ Row (" g" , Array (" 10" , " 11" )),
202
+ Row (" h" , Array (" 20" )),
203
+ Row (" i" , Array ())))
204
+ }
205
+
206
+ test(" collect_list in descending ordered window" ) {
207
+ val df = Seq (
208
+ (" a" , " p1" , " 1" ),
209
+ (" b" , " p1" , " 2" ),
210
+ (" c" , " p1" , " 2" ),
211
+ (" d" , " p1" , null ),
212
+ (" e" , " p1" , " 3" ),
213
+ (" f" , " p2" , " 10" ),
214
+ (" g" , " p2" , " 11" ),
215
+ (" h" , " p3" , " 20" ),
216
+ (" i" , " p4" , null )).toDF(" key" , " partition" , " value" )
217
+ checkAnswer(
218
+ df.select(
219
+ $" key" ,
220
+ sort_array(
221
+ collect_list(" value" ).over(Window .partitionBy($" partition" ).orderBy($" value" .desc)
222
+ .rowsBetween(Window .unboundedPreceding, Window .unboundedFollowing)))),
223
+ Seq (
224
+ Row (" a" , Array (" 1" , " 2" , " 2" , " 3" )),
225
+ Row (" b" , Array (" 1" , " 2" , " 2" , " 3" )),
226
+ Row (" c" , Array (" 1" , " 2" , " 2" , " 3" )),
227
+ Row (" d" , Array (" 1" , " 2" , " 2" , " 3" )),
228
+ Row (" e" , Array (" 1" , " 2" , " 2" , " 3" )),
229
+ Row (" f" , Array (" 10" , " 11" )),
230
+ Row (" g" , Array (" 10" , " 11" )),
231
+ Row (" h" , Array (" 20" )),
232
+ Row (" i" , Array ())))
233
+ }
234
+
235
+ test(" collect_set in window" ) {
236
+ val df = Seq (
237
+ (" a" , " p1" , " 1" ),
238
+ (" b" , " p1" , " 2" ),
239
+ (" c" , " p1" , " 2" ),
240
+ (" d" , " p1" , " 3" ),
241
+ (" e" , " p1" , " 3" ),
242
+ (" f" , " p2" , " 10" ),
243
+ (" g" , " p2" , " 11" ),
244
+ (" h" , " p3" , " 20" )).toDF(" key" , " partition" , " value" )
245
+ checkAnswer(
246
+ df.select(
247
+ $" key" ,
248
+ sort_array(
249
+ collect_set(" value" ).over(Window .partitionBy($" partition" ).orderBy($" value" )
250
+ .rowsBetween(Window .unboundedPreceding, Window .unboundedFollowing)))),
251
+ Seq (
252
+ Row (" a" , Array (" 1" , " 2" , " 3" )),
253
+ Row (" b" , Array (" 1" , " 2" , " 3" )),
254
+ Row (" c" , Array (" 1" , " 2" , " 3" )),
255
+ Row (" d" , Array (" 1" , " 2" , " 3" )),
256
+ Row (" e" , Array (" 1" , " 2" , " 3" )),
257
+ Row (" f" , Array (" 10" , " 11" )),
258
+ Row (" g" , Array (" 10" , " 11" )),
259
+ Row (" h" , Array (" 20" ))))
260
+ }
261
+
262
+ test(" skewness and kurtosis functions in window" ) {
263
+ val df = Seq (
264
+ (" a" , " p1" , 1.0 ),
265
+ (" b" , " p1" , 1.0 ),
266
+ (" c" , " p1" , 2.0 ),
267
+ (" d" , " p1" , 2.0 ),
268
+ (" e" , " p1" , 3.0 ),
269
+ (" f" , " p1" , 3.0 ),
270
+ (" g" , " p1" , 3.0 ),
271
+ (" h" , " p2" , 1.0 ),
272
+ (" i" , " p2" , 2.0 ),
273
+ (" j" , " p2" , 5.0 )).toDF(" key" , " partition" , " value" )
274
+ checkAnswer(
275
+ df.select(
276
+ $" key" ,
277
+ skewness(" value" ).over(Window .partitionBy(" partition" ).orderBy($" key" )
278
+ .rowsBetween(Window .unboundedPreceding, Window .unboundedFollowing)),
279
+ kurtosis(" value" ).over(Window .partitionBy(" partition" ).orderBy($" key" )
280
+ .rowsBetween(Window .unboundedPreceding, Window .unboundedFollowing))),
281
+ // results are checked by scipy.stats.skew() and scipy.stats.kurtosis()
282
+ Seq (
283
+ Row (" a" , - 0.27238010581457267 , - 1.506920415224914 ),
284
+ Row (" b" , - 0.27238010581457267 , - 1.506920415224914 ),
285
+ Row (" c" , - 0.27238010581457267 , - 1.506920415224914 ),
286
+ Row (" d" , - 0.27238010581457267 , - 1.506920415224914 ),
287
+ Row (" e" , - 0.27238010581457267 , - 1.506920415224914 ),
288
+ Row (" f" , - 0.27238010581457267 , - 1.506920415224914 ),
289
+ Row (" g" , - 0.27238010581457267 , - 1.506920415224914 ),
290
+ Row (" h" , 0.5280049792181881 , - 1.5000000000000013 ),
291
+ Row (" i" , 0.5280049792181881 , - 1.5000000000000013 ),
292
+ Row (" j" , 0.5280049792181881 , - 1.5000000000000013 )))
293
+ }
294
+
295
+ test(" aggregation function on invalid column" ) {
296
+ val df = Seq ((1 , " 1" )).toDF(" key" , " value" )
297
+ val e = intercept[AnalysisException ](
298
+ df.select($" key" , count(" invalid" ).over()))
299
+ assert(e.message.contains(" cannot resolve '`invalid`' given input columns: [key, value]" ))
300
+ }
301
+
302
+ test(" numerical aggregate functions on string column" ) {
303
+ val df = Seq ((1 , " a" , " b" )).toDF(" key" , " value1" , " value2" )
304
+ checkAnswer(
305
+ df.select($" key" ,
306
+ var_pop(" value1" ).over(),
307
+ variance(" value1" ).over(),
308
+ stddev_pop(" value1" ).over(),
309
+ stddev(" value1" ).over(),
310
+ sum(" value1" ).over(),
311
+ mean(" value1" ).over(),
312
+ avg(" value1" ).over(),
313
+ corr(" value1" , " value2" ).over(),
314
+ covar_pop(" value1" , " value2" ).over(),
315
+ covar_samp(" value1" , " value2" ).over(),
316
+ skewness(" value1" ).over(),
317
+ kurtosis(" value1" ).over()),
318
+ Seq (Row (1 , null , null , null , null , null , null , null , null , null , null , null , null )))
319
+ }
320
+
   test("statistical functions") {
     val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)).
       toDF("key", "value")
@@ -232,6 +464,40 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
         Row("b", 2, null, null, null, null, null, null)))
   }
 
+ test(" last/first on descending ordered window" ) {
468
+ val nullStr : String = null
469
+ val df = Seq (
470
+ (" a" , 0 , nullStr),
471
+ (" a" , 1 , " x" ),
472
+ (" a" , 2 , " y" ),
473
+ (" a" , 3 , " z" ),
474
+ (" a" , 4 , " v" ),
475
+ (" b" , 1 , " k" ),
476
+ (" b" , 2 , " l" ),
477
+ (" b" , 3 , nullStr)).
478
+ toDF(" key" , " order" , " value" )
479
+ val window = Window .partitionBy($" key" ).orderBy($" order" .desc)
480
+ checkAnswer(
481
+ df.select(
482
+ $" key" ,
483
+ $" order" ,
484
+ first($" value" ).over(window),
485
+ first($" value" , ignoreNulls = false ).over(window),
486
+ first($" value" , ignoreNulls = true ).over(window),
487
+ last($" value" ).over(window),
488
+ last($" value" , ignoreNulls = false ).over(window),
489
+ last($" value" , ignoreNulls = true ).over(window)),
490
+ Seq (
491
+ Row (" a" , 0 , " v" , " v" , " v" , null , null , " x" ),
492
+ Row (" a" , 1 , " v" , " v" , " v" , " x" , " x" , " x" ),
493
+ Row (" a" , 2 , " v" , " v" , " v" , " y" , " y" , " y" ),
494
+ Row (" a" , 3 , " v" , " v" , " v" , " z" , " z" , " z" ),
495
+ Row (" a" , 4 , " v" , " v" , " v" , " v" , " v" , " v" ),
496
+ Row (" b" , 1 , null , null , " l" , " k" , " k" , " k" ),
497
+ Row (" b" , 2 , null , null , " l" , " l" , " l" , " l" ),
498
+ Row (" b" , 3 , null , null , null , null , null , null )))
499
+ }
500
+
   test("SPARK-12989 ExtractWindowExpressions treats alias as regular attribute") {
     val src = Seq((0, 3, 5)).toDF("a", "b", "c")
       .withColumn("Data", struct("a", "b"))