Skip to content

Commit bc8d093

Browse files
gatorsmile authored and cloud-fan committed
[SPARK-23500][SQL][FOLLOWUP] Fix complex type simplification rules to apply to entire plan
## What changes were proposed in this pull request?

This PR improves the test coverage of the original PR apache#20687.

## How was this patch tested?

N/A

Author: gatorsmile <[email protected]>

Closes apache#20911 from gatorsmile/addTests.
1 parent 5b5a36e commit bc8d093

File tree

2 files changed

+233
-52
lines changed

2 files changed

+233
-52
lines changed

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala

Lines changed: 124 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -47,42 +47,46 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
4747
SimplifyExtractValueOps) :: Nil
4848
}
4949

50-
val idAtt = ('id).long.notNull
51-
val nullableIdAtt = ('nullable_id).long
50+
private val idAtt = ('id).long.notNull
51+
private val nullableIdAtt = ('nullable_id).long
5252

53-
lazy val relation = LocalRelation(idAtt, nullableIdAtt)
53+
private val relation = LocalRelation(idAtt, nullableIdAtt)
54+
private val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.double, 'e.int)
55+
56+
private def checkRule(originalQuery: LogicalPlan, correctAnswer: LogicalPlan) = {
57+
val optimized = Optimizer.execute(originalQuery.analyze)
58+
assert(optimized.resolved, "optimized plans must be still resolvable")
59+
comparePlans(optimized, correctAnswer.analyze)
60+
}
5461

5562
test("explicit get from namedStruct") {
5663
val query = relation
5764
.select(
5865
GetStructField(
5966
CreateNamedStruct(Seq("att", 'id )),
6067
0,
61-
None) as "outerAtt").analyze
62-
val expected = relation.select('id as "outerAtt").analyze
68+
None) as "outerAtt")
69+
val expected = relation.select('id as "outerAtt")
6370

64-
comparePlans(Optimizer execute query, expected)
71+
checkRule(query, expected)
6572
}
6673

6774
test("explicit get from named_struct- expression maintains original deduced alias") {
6875
val query = relation
6976
.select(GetStructField(CreateNamedStruct(Seq("att", 'id)), 0, None))
70-
.analyze
7177

7278
val expected = relation
7379
.select('id as "named_struct(att, id).att")
74-
.analyze
7580

76-
comparePlans(Optimizer execute query, expected)
81+
checkRule(query, expected)
7782
}
7883

7984
test("collapsed getStructField ontop of namedStruct") {
8085
val query = relation
8186
.select(CreateNamedStruct(Seq("att", 'id)) as "struct1")
8287
.select(GetStructField('struct1, 0, None) as "struct1Att")
83-
.analyze
84-
val expected = relation.select('id as "struct1Att").analyze
85-
comparePlans(Optimizer execute query, expected)
88+
val expected = relation.select('id as "struct1Att")
89+
checkRule(query, expected)
8690
}
8791

8892
test("collapse multiple CreateNamedStruct/GetStructField pairs") {
@@ -94,16 +98,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
9498
.select(
9599
GetStructField('struct1, 0, None) as "struct1Att1",
96100
GetStructField('struct1, 1, None) as "struct1Att2")
97-
.analyze
98101

99102
val expected =
100103
relation.
101104
select(
102105
'id as "struct1Att1",
103106
('id * 'id) as "struct1Att2")
104-
.analyze
105107

106-
comparePlans(Optimizer execute query, expected)
108+
checkRule(query, expected)
107109
}
108110

109111
test("collapsed2 - deduced names") {
@@ -115,16 +117,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
115117
.select(
116118
GetStructField('struct1, 0, None),
117119
GetStructField('struct1, 1, None))
118-
.analyze
119120

120121
val expected =
121122
relation.
122123
select(
123124
'id as "struct1.att1",
124125
('id * 'id) as "struct1.att2")
125-
.analyze
126126

127-
comparePlans(Optimizer execute query, expected)
127+
checkRule(query, expected)
128128
}
129129

130130
test("simplified array ops") {
@@ -151,7 +151,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
151151
1,
152152
false),
153153
1) as "a4")
154-
.analyze
155154

156155
val expected = relation
157156
.select(
@@ -161,8 +160,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
161160
"att2", (('id + 1L) * ('id + 1L)))) as "a2",
162161
('id + 1L) as "a3",
163162
('id + 1L) as "a4")
164-
.analyze
165-
comparePlans(Optimizer execute query, expected)
163+
checkRule(query, expected)
166164
}
167165

168166
test("SPARK-22570: CreateArray should not create a lot of global variables") {
@@ -188,7 +186,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
188186
GetStructField(GetMapValue('m, "r1"), 0, None) as "a2",
189187
GetMapValue('m, "r32") as "a3",
190188
GetStructField(GetMapValue('m, "r32"), 0, None) as "a4")
191-
.analyze
192189

193190
val expected =
194191
relation.select(
@@ -201,8 +198,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
201198
)
202199
) as "a3",
203200
Literal.create(null, LongType) as "a4")
204-
.analyze
205-
comparePlans(Optimizer execute query, expected)
201+
checkRule(query, expected)
206202
}
207203

208204
test("simplify map ops, constant lookup, dynamic keys") {
@@ -216,7 +212,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
216212
('id + 3L), ('id + 4L),
217213
('id + 4L), ('id + 5L))),
218214
13L) as "a")
219-
.analyze
220215

221216
val expected = relation
222217
.select(
@@ -225,8 +220,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
225220
(EqualTo(13L, ('id + 1L)), ('id + 2L)),
226221
(EqualTo(13L, ('id + 2L)), ('id + 3L)),
227222
(Literal(true), 'id))) as "a")
228-
.analyze
229-
comparePlans(Optimizer execute query, expected)
223+
checkRule(query, expected)
230224
}
231225

232226
test("simplify map ops, dynamic lookup, dynamic keys, lookup is equivalent to one of the keys") {
@@ -240,16 +234,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
240234
('id + 3L), ('id + 4L),
241235
('id + 4L), ('id + 5L))),
242236
('id + 3L)) as "a")
243-
.analyze
244237
val expected = relation
245238
.select(
246239
CaseWhen(Seq(
247240
(EqualTo('id + 3L, 'id), ('id + 1L)),
248241
(EqualTo('id + 3L, ('id + 1L)), ('id + 2L)),
249242
(EqualTo('id + 3L, ('id + 2L)), ('id + 3L)),
250243
(Literal(true), ('id + 4L)))) as "a")
251-
.analyze
252-
comparePlans(Optimizer execute query, expected)
244+
checkRule(query, expected)
253245
}
254246

255247
test("simplify map ops, no positive match") {
@@ -263,16 +255,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
263255
('id + 3L), ('id + 4L),
264256
('id + 4L), ('id + 5L))),
265257
'id + 30L) as "a")
266-
.analyze
267258
val expected = relation.select(
268259
CaseWhen(Seq(
269260
(EqualTo('id + 30L, 'id), ('id + 1L)),
270261
(EqualTo('id + 30L, ('id + 1L)), ('id + 2L)),
271262
(EqualTo('id + 30L, ('id + 2L)), ('id + 3L)),
272263
(EqualTo('id + 30L, ('id + 3L)), ('id + 4L)),
273264
(EqualTo('id + 30L, ('id + 4L)), ('id + 5L)))) as "a")
274-
.analyze
275-
comparePlans(Optimizer execute rel, expected)
265+
checkRule(rel, expected)
276266
}
277267

278268
test("simplify map ops, constant lookup, mixed keys, eliminated constants") {
@@ -287,7 +277,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
287277
('id + 3L), ('id + 4L),
288278
('id + 4L), ('id + 5L))),
289279
13L) as "a")
290-
.analyze
291280

292281
val expected = relation
293282
.select(
@@ -297,9 +286,8 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
297286
('id + 2L), ('id + 3L),
298287
('id + 3L), ('id + 4L),
299288
('id + 4L), ('id + 5L))) as "a")
300-
.analyze
301289

302-
comparePlans(Optimizer execute rel, expected)
290+
checkRule(rel, expected)
303291
}
304292

305293
test("simplify map ops, potential dynamic match with null value + an absolute constant match") {
@@ -314,7 +302,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
314302
('id + 3L), ('id + 4L),
315303
('id + 4L), ('id + 5L))),
316304
2L ) as "a")
317-
.analyze
318305

319306
val expected = relation
320307
.select(
@@ -327,18 +314,69 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
327314
// but it cannot override a potential match with ('id + 2L),
328315
// which is exactly what [[Coalesce]] would do in this case.
329316
(Literal.TrueLiteral, 'id))) as "a")
330-
.analyze
331-
comparePlans(Optimizer execute rel, expected)
317+
checkRule(rel, expected)
318+
}
319+
320+
test("SPARK-23500: Simplify array ops that are not at the top node") {
321+
val query = LocalRelation('id.long)
322+
.select(
323+
CreateArray(Seq(
324+
CreateNamedStruct(Seq(
325+
"att1", 'id,
326+
"att2", 'id * 'id)),
327+
CreateNamedStruct(Seq(
328+
"att1", 'id + 1,
329+
"att2", ('id + 1) * ('id + 1))
330+
))
331+
) as "arr")
332+
.select(
333+
GetStructField(GetArrayItem('arr, 1), 0, None) as "a1",
334+
GetArrayItem(
335+
GetArrayStructFields('arr,
336+
StructField("att1", LongType, nullable = false),
337+
ordinal = 0,
338+
numFields = 1,
339+
containsNull = false),
340+
ordinal = 1) as "a2")
341+
.orderBy('id.asc)
342+
343+
val expected = LocalRelation('id.long)
344+
.select(
345+
('id + 1L) as "a1",
346+
('id + 1L) as "a2")
347+
.orderBy('id.asc)
348+
checkRule(query, expected)
349+
}
350+
351+
test("SPARK-23500: Simplify map ops that are not top nodes") {
352+
val query =
353+
LocalRelation('id.long)
354+
.select(
355+
CreateMap(Seq(
356+
"r1", 'id,
357+
"r2", 'id + 1L)) as "m")
358+
.select(
359+
GetMapValue('m, "r1") as "a1",
360+
GetMapValue('m, "r32") as "a2")
361+
.orderBy('id.asc)
362+
.select('a1, 'a2)
363+
364+
val expected =
365+
LocalRelation('id.long).select(
366+
'id as "a1",
367+
Literal.create(null, LongType) as "a2")
368+
.orderBy('id.asc)
369+
checkRule(query, expected)
332370
}
333371

334372
test("SPARK-23500: Simplify complex ops that aren't at the plan root") {
335373
val structRel = relation
336374
.select(GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None) as "foo")
337-
.groupBy($"foo")("1").analyze
375+
.groupBy($"foo")("1")
338376
val structExpected = relation
339377
.select('nullable_id as "foo")
340-
.groupBy($"foo")("1").analyze
341-
comparePlans(Optimizer execute structRel, structExpected)
378+
.groupBy($"foo")("1")
379+
checkRule(structRel, structExpected)
342380

343381
// These tests must use nullable attributes from the base relation for the following reason:
344382
// in the 'original' plans below, the Aggregate node produced by groupBy() has a
@@ -351,29 +389,63 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
351389
// SPARK-23634.
352390
val arrayRel = relation
353391
.select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1")
354-
.groupBy($"a1")("1").analyze
355-
val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1").analyze
356-
comparePlans(Optimizer execute arrayRel, arrayExpected)
392+
.groupBy($"a1")("1")
393+
val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1")
394+
checkRule(arrayRel, arrayExpected)
357395

358396
val mapRel = relation
359397
.select(GetMapValue(CreateMap(Seq("id", 'nullable_id)), "id") as "m1")
360-
.groupBy($"m1")("1").analyze
398+
.groupBy($"m1")("1")
361399
val mapExpected = relation
362400
.select('nullable_id as "m1")
363-
.groupBy($"m1")("1").analyze
364-
comparePlans(Optimizer execute mapRel, mapExpected)
401+
.groupBy($"m1")("1")
402+
checkRule(mapRel, mapExpected)
365403
}
366404

367405
test("SPARK-23500: Ensure that aggregation expressions are not simplified") {
368406
// Make sure that aggregation exprs are correctly ignored. Maps can't be used in
369407
// grouping exprs so aren't tested here.
370408
val structAggRel = relation.groupBy(
371409
CreateNamedStruct(Seq("att1", 'nullable_id)))(
372-
GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)).analyze
373-
comparePlans(Optimizer execute structAggRel, structAggRel)
410+
GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None))
411+
checkRule(structAggRel, structAggRel)
374412

375413
val arrayAggRel = relation.groupBy(
376-
CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)).analyze
377-
comparePlans(Optimizer execute arrayAggRel, arrayAggRel)
414+
CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0))
415+
checkRule(arrayAggRel, arrayAggRel)
416+
417+
// This could be done if we had a more complex rule that checks that
418+
// the CreateMap does not come from key.
419+
val originalQuery = relation
420+
.groupBy('id)(
421+
GetMapValue(CreateMap(Seq('id, 'id + 1L)), 0L) as "a"
422+
)
423+
checkRule(originalQuery, originalQuery)
424+
}
425+
426+
test("SPARK-23500: namedStruct and getField in the same Project #1") {
427+
val originalQuery =
428+
testRelation
429+
.select(
430+
namedStruct("col1", 'b, "col2", 'c).as("s1"), 'a, 'b)
431+
.select('s1 getField "col2" as 's1Col2,
432+
namedStruct("col1", 'a, "col2", 'b).as("s2"))
433+
.select('s1Col2, 's2 getField "col2" as 's2Col2)
434+
val correctAnswer =
435+
testRelation
436+
.select('c as 's1Col2, 'b as 's2Col2)
437+
checkRule(originalQuery, correctAnswer)
438+
}
439+
440+
test("SPARK-23500: namedStruct and getField in the same Project #2") {
441+
val originalQuery =
442+
testRelation
443+
.select(
444+
namedStruct("col1", 'b, "col2", 'c) getField "col2" as 'sCol2,
445+
namedStruct("col1", 'a, "col2", 'c) getField "col1" as 'sCol1)
446+
val correctAnswer =
447+
testRelation
448+
.select('c as 'sCol2, 'a as 'sCol1)
449+
checkRule(originalQuery, correctAnswer)
378450
}
379451
}

0 commit comments

Comments
 (0)