Skip to content
This repository was archived by the owner on Jan 9, 2020. It is now read-only.

Commit 9909be3

Browse files
ConeyLiucloud-fan
authored andcommitted
[SPARK-21072][SQL] TreeNode.mapChildren should only apply to the children node.
## What changes were proposed in this pull request? Just as the function name and comments of `TreeNode.mapChildren` mentioned, the function should be apply to all currently node children. So, the follow code should judge whether it is the children node. https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala#L342 ## How was this patch tested? Existing tests. Author: Xianyang Liu <[email protected]> Closes apache#18284 from ConeyLiu/treenode. (cherry picked from commit 87ab0ce) Signed-off-by: Wenchen Fan <[email protected]>
1 parent a585c87 commit 9909be3

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,8 +340,18 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
340340
arg
341341
}
342342
case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) =>
343-
val newChild1 = f(arg1.asInstanceOf[BaseType])
344-
val newChild2 = f(arg2.asInstanceOf[BaseType])
343+
val newChild1 = if (containsChild(arg1)) {
344+
f(arg1.asInstanceOf[BaseType])
345+
} else {
346+
arg1.asInstanceOf[BaseType]
347+
}
348+
349+
val newChild2 = if (containsChild(arg2)) {
350+
f(arg2.asInstanceOf[BaseType])
351+
} else {
352+
arg2.asInstanceOf[BaseType]
353+
}
354+
345355
if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
346356
changed = true
347357
(newChild1, newChild2)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,21 @@ case class ComplexPlan(exprs: Seq[Seq[Expression]])
5454
override def output: Seq[Attribute] = Nil
5555
}
5656

57-
case class ExpressionInMap(map: Map[String, Expression]) extends Expression with Unevaluable {
57+
case class ExpressionInMap(map: Map[String, Expression]) extends Unevaluable {
5858
override def children: Seq[Expression] = map.values.toSeq
5959
override def nullable: Boolean = true
6060
override def dataType: NullType = NullType
6161
override lazy val resolved = true
6262
}
6363

64+
case class SeqTupleExpression(sons: Seq[(Expression, Expression)],
65+
nonSons: Seq[(Expression, Expression)]) extends Unevaluable {
66+
override def children: Seq[Expression] = sons.flatMap(t => Iterator(t._1, t._2))
67+
override def nullable: Boolean = true
68+
override def dataType: NullType = NullType
69+
override lazy val resolved = true
70+
}
71+
6472
case class JsonTestTreeNode(arg: Any) extends LeafNode {
6573
override def output: Seq[Attribute] = Seq.empty[Attribute]
6674
}
@@ -146,6 +154,17 @@ class TreeNodeSuite extends SparkFunSuite {
146154
assert(actual === Dummy(None))
147155
}
148156

157+
test("mapChildren should only works on children") {
158+
val children = Seq((Literal(1), Literal(2)))
159+
val nonChildren = Seq((Literal(3), Literal(4)))
160+
val before = SeqTupleExpression(children, nonChildren)
161+
val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) }
162+
val expect = SeqTupleExpression(Seq((Literal(0), Literal(0))), nonChildren)
163+
164+
val actual = before mapChildren toZero
165+
assert(actual === expect)
166+
}
167+
149168
test("preserves origin") {
150169
CurrentOrigin.setPosition(1, 1)
151170
val add = Add(Literal(1), Literal(1))

0 commit comments

Comments
 (0)