@@ -47,42 +47,46 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
47
47
SimplifyExtractValueOps ) :: Nil
48
48
}
49
49
50
// Attributes of the base relation: `id` is declared not-null, `nullable_id` keeps
// the default nullability of the DSL.
private val idAtt = 'id.long.notNull
private val nullableIdAtt = 'nullable_id.long

// Relations the test plans are built on.
private val relation = LocalRelation(idAtt, nullableIdAtt)
private val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.double, 'e.int)
55
+
56
/**
 * Analyzes and optimizes `originalQuery`, asserts that the optimized plan is still
 * resolved, and compares it against the analyzed `correctAnswer`.
 *
 * @param originalQuery unanalyzed plan to run through `Optimizer`
 * @param correctAnswer unanalyzed plan the optimized query is expected to match
 */
private def checkRule(originalQuery: LogicalPlan, correctAnswer: LogicalPlan): Unit = {
  val analyzedQuery = originalQuery.analyze
  val optimizedPlan = Optimizer.execute(analyzedQuery)
  assert(optimizedPlan.resolved, "optimized plans must be still resolvable")
  // The expected plan is analyzed only after the resolution check has passed.
  val analyzedAnswer = correctAnswer.analyze
  comparePlans(optimizedPlan, analyzedAnswer)
}
54
61
55
62
test("explicit get from namedStruct") {
  // Build a single-field struct around `id` and read that field straight back.
  val struct = CreateNamedStruct(Seq("att", 'id))
  val extracted = GetStructField(struct, 0, None) as "outerAtt"

  val original = relation.select(extracted)
  // Expected plan: the attribute itself under the same alias.
  val simplified = relation.select('id as "outerAtt")

  checkRule(original, simplified)
}
66
73
67
74
test("explicit get from named_struct- expression maintains original deduced alias") {
  // No explicit alias on the extraction: the expected plan spells out the name
  // the un-aliased expression is given.
  val extraction = GetStructField(CreateNamedStruct(Seq("att", 'id)), 0, None)
  val original = relation.select(extraction)

  val simplified = relation.select('id as "named_struct(att, id).att")

  checkRule(original, simplified)
}
78
83
79
84
test("collapsed getStructField ontop of namedStruct") {
  // First projection creates the struct, the second one reads its only field.
  val withStruct = relation.select(CreateNamedStruct(Seq("att", 'id)) as "struct1")
  val original = withStruct.select(GetStructField('struct1, 0, None) as "struct1Att")

  val simplified = relation.select('id as "struct1Att")

  checkRule(original, simplified)
}
87
91
88
92
test(" collapse multiple CreateNamedStruct/GetStructField pairs" ) {
@@ -94,16 +98,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
94
98
.select(
95
99
GetStructField (' struct1 , 0 , None ) as " struct1Att1" ,
96
100
GetStructField (' struct1 , 1 , None ) as " struct1Att2" )
97
- .analyze
98
101
99
102
val expected =
100
103
relation.
101
104
select(
102
105
' id as " struct1Att1" ,
103
106
(' id * ' id ) as " struct1Att2" )
104
- .analyze
105
107
106
- comparePlans( Optimizer execute query, expected)
108
+ checkRule( query, expected)
107
109
}
108
110
109
111
test(" collapsed2 - deduced names" ) {
@@ -115,16 +117,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
115
117
.select(
116
118
GetStructField (' struct1 , 0 , None ),
117
119
GetStructField (' struct1 , 1 , None ))
118
- .analyze
119
120
120
121
val expected =
121
122
relation.
122
123
select(
123
124
' id as " struct1.att1" ,
124
125
(' id * ' id ) as " struct1.att2" )
125
- .analyze
126
126
127
- comparePlans( Optimizer execute query, expected)
127
+ checkRule( query, expected)
128
128
}
129
129
130
130
test(" simplified array ops" ) {
@@ -151,7 +151,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
151
151
1 ,
152
152
false ),
153
153
1 ) as " a4" )
154
- .analyze
155
154
156
155
val expected = relation
157
156
.select(
@@ -161,8 +160,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
161
160
" att2" , ((' id + 1L ) * (' id + 1L )))) as " a2" ,
162
161
(' id + 1L ) as " a3" ,
163
162
(' id + 1L ) as " a4" )
164
- .analyze
165
- comparePlans(Optimizer execute query, expected)
163
+ checkRule(query, expected)
166
164
}
167
165
168
166
test(" SPARK-22570: CreateArray should not create a lot of global variables" ) {
@@ -188,7 +186,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
188
186
GetStructField (GetMapValue (' m , " r1" ), 0 , None ) as " a2" ,
189
187
GetMapValue (' m , " r32" ) as " a3" ,
190
188
GetStructField (GetMapValue (' m , " r32" ), 0 , None ) as " a4" )
191
- .analyze
192
189
193
190
val expected =
194
191
relation.select(
@@ -201,8 +198,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
201
198
)
202
199
) as " a3" ,
203
200
Literal .create(null , LongType ) as " a4" )
204
- .analyze
205
- comparePlans(Optimizer execute query, expected)
201
+ checkRule(query, expected)
206
202
}
207
203
208
204
test(" simplify map ops, constant lookup, dynamic keys" ) {
@@ -216,7 +212,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
216
212
(' id + 3L ), (' id + 4L ),
217
213
(' id + 4L ), (' id + 5L ))),
218
214
13L ) as " a" )
219
- .analyze
220
215
221
216
val expected = relation
222
217
.select(
@@ -225,8 +220,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
225
220
(EqualTo (13L , (' id + 1L )), (' id + 2L )),
226
221
(EqualTo (13L , (' id + 2L )), (' id + 3L )),
227
222
(Literal (true ), ' id ))) as " a" )
228
- .analyze
229
- comparePlans(Optimizer execute query, expected)
223
+ checkRule(query, expected)
230
224
}
231
225
232
226
test(" simplify map ops, dynamic lookup, dynamic keys, lookup is equivalent to one of the keys" ) {
@@ -240,16 +234,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
240
234
(' id + 3L ), (' id + 4L ),
241
235
(' id + 4L ), (' id + 5L ))),
242
236
(' id + 3L )) as " a" )
243
- .analyze
244
237
val expected = relation
245
238
.select(
246
239
CaseWhen (Seq (
247
240
(EqualTo (' id + 3L , ' id ), (' id + 1L )),
248
241
(EqualTo (' id + 3L , (' id + 1L )), (' id + 2L )),
249
242
(EqualTo (' id + 3L , (' id + 2L )), (' id + 3L )),
250
243
(Literal (true ), (' id + 4L )))) as " a" )
251
- .analyze
252
- comparePlans(Optimizer execute query, expected)
244
+ checkRule(query, expected)
253
245
}
254
246
255
247
test(" simplify map ops, no positive match" ) {
@@ -263,16 +255,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
263
255
(' id + 3L ), (' id + 4L ),
264
256
(' id + 4L ), (' id + 5L ))),
265
257
' id + 30L ) as " a" )
266
- .analyze
267
258
val expected = relation.select(
268
259
CaseWhen (Seq (
269
260
(EqualTo (' id + 30L , ' id ), (' id + 1L )),
270
261
(EqualTo (' id + 30L , (' id + 1L )), (' id + 2L )),
271
262
(EqualTo (' id + 30L , (' id + 2L )), (' id + 3L )),
272
263
(EqualTo (' id + 30L , (' id + 3L )), (' id + 4L )),
273
264
(EqualTo (' id + 30L , (' id + 4L )), (' id + 5L )))) as " a" )
274
- .analyze
275
- comparePlans(Optimizer execute rel, expected)
265
+ checkRule(rel, expected)
276
266
}
277
267
278
268
test(" simplify map ops, constant lookup, mixed keys, eliminated constants" ) {
@@ -287,7 +277,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
287
277
(' id + 3L ), (' id + 4L ),
288
278
(' id + 4L ), (' id + 5L ))),
289
279
13L ) as " a" )
290
- .analyze
291
280
292
281
val expected = relation
293
282
.select(
@@ -297,9 +286,8 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
297
286
(' id + 2L ), (' id + 3L ),
298
287
(' id + 3L ), (' id + 4L ),
299
288
(' id + 4L ), (' id + 5L ))) as " a" )
300
- .analyze
301
289
302
- comparePlans( Optimizer execute rel, expected)
290
+ checkRule( rel, expected)
303
291
}
304
292
305
293
test(" simplify map ops, potential dynamic match with null value + an absolute constant match" ) {
@@ -314,7 +302,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
314
302
(' id + 3L ), (' id + 4L ),
315
303
(' id + 4L ), (' id + 5L ))),
316
304
2L ) as " a" )
317
- .analyze
318
305
319
306
val expected = relation
320
307
.select(
@@ -327,18 +314,69 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
327
314
// but it cannot override a potential match with ('id + 2L),
328
315
// which is exactly what [[Coalesce]] would do in this case.
329
316
(Literal .TrueLiteral , ' id ))) as " a" )
330
- .analyze
331
- comparePlans(Optimizer execute rel, expected)
317
+ checkRule(rel, expected)
318
+ }
319
+
320
test("SPARK-23500: Simplify array ops that are not at the top node") {
  // Two structs with fields (att1, att2): one built from `id`, one from `id + 1`.
  val structFromId = CreateNamedStruct(Seq(
    "att1", 'id,
    "att2", 'id * 'id))
  val structFromNextId = CreateNamedStruct(Seq(
    "att1", 'id + 1,
    "att2", ('id + 1) * ('id + 1)))
  val arrColumn = CreateArray(Seq(structFromId, structFromNextId)) as "arr"

  // a1: field 0 of array item 1; a2: item 1 of the array of `att1` fields.
  val a1 = GetStructField(GetArrayItem('arr, 1), 0, None) as "a1"
  val a2 = GetArrayItem(
    GetArrayStructFields('arr,
      StructField("att1", LongType, nullable = false),
      ordinal = 0,
      numFields = 1,
      containsNull = false),
    ordinal = 1) as "a2"

  // The extractions sit below a sort, i.e. not at the root of the plan.
  val original = LocalRelation('id.long)
    .select(arrColumn)
    .select(a1, a2)
    .orderBy('id.asc)

  val simplified = LocalRelation('id.long)
    .select(
      ('id + 1L) as "a1",
      ('id + 1L) as "a2")
    .orderBy('id.asc)

  checkRule(original, simplified)
}
350
+
351
test("SPARK-23500: Simplify map ops that are not top nodes") {
  // Map with keys "r1" and "r2"; the second lookup below uses a key ("r32")
  // that is not among them.
  val mapColumn = CreateMap(Seq(
    "r1", 'id,
    "r2", 'id + 1L)) as "m"

  val lookups = LocalRelation('id.long)
    .select(mapColumn)
    .select(
      GetMapValue('m, "r1") as "a1",
      GetMapValue('m, "r32") as "a2")

  // The lookups are followed by a sort and one more projection.
  val original = lookups
    .orderBy('id.asc)
    .select('a1, 'a2)

  val simplified = LocalRelation('id.long)
    .select(
      'id as "a1",
      Literal.create(null, LongType) as "a2")
    .orderBy('id.asc)

  checkRule(original, simplified)
}
333
371
334
372
test(" SPARK-23500: Simplify complex ops that aren't at the plan root" ) {
335
373
val structRel = relation
336
374
.select(GetStructField (CreateNamedStruct (Seq (" att1" , ' nullable_id )), 0 , None ) as " foo" )
337
- .groupBy($" foo" )(" 1" ).analyze
375
+ .groupBy($" foo" )(" 1" )
338
376
val structExpected = relation
339
377
.select(' nullable_id as " foo" )
340
- .groupBy($" foo" )(" 1" ).analyze
341
- comparePlans( Optimizer execute structRel, structExpected)
378
+ .groupBy($" foo" )(" 1" )
379
+ checkRule( structRel, structExpected)
342
380
343
381
// These tests must use nullable attributes from the base relation for the following reason:
344
382
// in the 'original' plans below, the Aggregate node produced by groupBy() has a
@@ -351,29 +389,63 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
351
389
// SPARK-23634.
352
390
val arrayRel = relation
353
391
.select(GetArrayItem (CreateArray (Seq (' nullable_id , ' nullable_id + 1L )), 0 ) as " a1" )
354
- .groupBy($" a1" )(" 1" ).analyze
355
- val arrayExpected = relation.select(' nullable_id as " a1" ).groupBy($" a1" )(" 1" ).analyze
356
- comparePlans( Optimizer execute arrayRel, arrayExpected)
392
+ .groupBy($" a1" )(" 1" )
393
+ val arrayExpected = relation.select(' nullable_id as " a1" ).groupBy($" a1" )(" 1" )
394
+ checkRule( arrayRel, arrayExpected)
357
395
358
396
val mapRel = relation
359
397
.select(GetMapValue (CreateMap (Seq (" id" , ' nullable_id )), " id" ) as " m1" )
360
- .groupBy($" m1" )(" 1" ).analyze
398
+ .groupBy($" m1" )(" 1" )
361
399
val mapExpected = relation
362
400
.select(' nullable_id as " m1" )
363
- .groupBy($" m1" )(" 1" ).analyze
364
- comparePlans( Optimizer execute mapRel, mapExpected)
401
+ .groupBy($" m1" )(" 1" )
402
+ checkRule( mapRel, mapExpected)
365
403
}
366
404
367
405
test("SPARK-23500: Ensure that aggregation expressions are not simplified") {
  // Each plan below is passed to checkRule as both the query and the expected answer.
  def assertUnchanged(plan: LogicalPlan): Unit = checkRule(plan, plan)

  // Make sure that aggregation exprs are correctly ignored. Maps can't be used in
  // grouping exprs so aren't tested here.
  val structGrouping = CreateNamedStruct(Seq("att1", 'nullable_id))
  val structAggregate =
    GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)
  assertUnchanged(relation.groupBy(structGrouping)(structAggregate))

  val arrayGrouping = CreateArray(Seq('nullable_id))
  val arrayAggregate = GetArrayItem(CreateArray(Seq('nullable_id)), 0)
  assertUnchanged(relation.groupBy(arrayGrouping)(arrayAggregate))

  // This could be done if we had a more complex rule that checks that
  // the CreateMap does not come from key.
  val mapAggregate = GetMapValue(CreateMap(Seq('id, 'id + 1L)), 0L) as "a"
  assertUnchanged(relation.groupBy('id)(mapAggregate))
}
425
+
426
test("SPARK-23500: namedStruct and getField in the same Project #1") {
  // Step 1: pack `b` and `c` into struct s1, keeping `a` and `b` alongside it.
  val packed = testRelation
    .select(namedStruct("col1", 'b, "col2", 'c) as "s1", 'a, 'b)

  // Step 2: read s1.col2 while creating a second struct s2 in the same projection.
  val repacked = packed
    .select(
      's1.getField("col2").as('s1Col2),
      namedStruct("col1", 'a, "col2", 'b) as "s2")

  // Step 3: read s2.col2 next to the value extracted in step 2.
  val originalQuery = repacked
    .select('s1Col2, 's2.getField("col2").as('s2Col2))

  val correctAnswer = testRelation.select('c as 's1Col2, 'b as 's2Col2)

  checkRule(originalQuery, correctAnswer)
}
439
+
440
test("SPARK-23500: namedStruct and getField in the same Project #2") {
  // Two independent create-then-extract pairs inside a single projection.
  val sCol2 = namedStruct("col1", 'b, "col2", 'c).getField("col2").as('sCol2)
  val sCol1 = namedStruct("col1", 'a, "col2", 'c).getField("col1").as('sCol1)
  val originalQuery = testRelation.select(sCol2, sCol1)

  val correctAnswer = testRelation.select('c as 'sCol2, 'a as 'sCol1)

  checkRule(originalQuery, correctAnswer)
}
379
451
}
0 commit comments