diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 7801cd347f7dc..df2762ae92160 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -630,6 +630,20 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
     foreach(actualFunc)
   }
 
+  /**
+   * A variant of [[foreachWithSubqueries]] with pruning support. Only nodes satisfying the
+   * given condition are traversed; a subtree is skipped entirely once its root fails it.
+   */
+  def foreachWithSubqueriesAndPruning(
+      cond: TreePatternBits => Boolean)(f: PlanType => Unit): Unit = {
+    if (!cond.apply(this)) {
+      return
+    }
+    f(this)
+    subqueries.foreach(_.foreachWithSubqueriesAndPruning(cond)(f))
+    children.foreach(_.foreachWithSubqueriesAndPruning(cond)(f))
+  }
+
   /**
    * A variant of `collect`. This method not only apply the given function to all elements in this
    * plan, also considering all the plans in its (nested) subqueries.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala
index 03ed466e2b039..91f990be7bb23 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
@@ -25,7 +27,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, ListQuery, Literal, NamedExpression, Rand}
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan, Project, Union}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, TreePattern}
 import org.apache.spark.sql.types.IntegerType
 
 class QueryPlanSuite extends SparkFunSuite {
@@ -160,4 +162,27 @@ class QueryPlanSuite extends SparkFunSuite {
     val planAfterTestRule = testRule(plan)
     assert(planAfterTestRule.output(0).nullable)
   }
+
+  test("SPARK-54865: pruning works correctly in foreachWithSubqueriesAndPruning") {
+    val a: NamedExpression = AttributeReference("a", IntegerType)()
+    val plan = Project(
+      Seq(a),
+      Filter(
+        ListQuery(Project(
+          Seq(a),
+          UnresolvedRelation(TableIdentifier("t", None))
+        )),
+        UnresolvedRelation(TableIdentifier("t", None))
+      )
+    )
+
+    val visited = ArrayBuffer[LogicalPlan]()
+    plan.foreachWithSubqueriesAndPruning(_.containsPattern(TreePattern.FILTER)) { p =>
+      visited += p
+    }
+
+    // Only two nodes contain the FILTER pattern: the outer Project and the Filter itself.
+    assert(visited.size == 2)
+    assert(visited.forall(_.containsPattern(TreePattern.FILTER)))
+  }
 }
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 0d95fe31e0637..b2c32df4d8635 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -242,7 +242,7 @@ class SparkConnectPlanner(
     }
 
     if (executeHolderOpt.isDefined) {
-      plan.transformUpWithSubqueriesAndPruning(_.containsPattern(TreePattern.COLLECT_METRICS)) {
+      plan.foreachWithSubqueriesAndPruning(_.containsPattern(TreePattern.COLLECT_METRICS)) {
         case collectMetrics: CollectMetrics if !collectMetrics.child.isStreaming =>
           // TODO this might be too complex for no good reason. It might
           // be easier to inspect the plan after it completes.
@@ -250,9 +250,10 @@
             collectMetrics.name, collectMetrics.dataframeId)
           executeHolder.addObservation(collectMetrics.name, observation)
-          collectMetrics
+        case _ =>
       }
-    } else plan
+    }
+    plan
   }
 
   private def transformRelationPlugin(extension: ProtoAny): LogicalPlan = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/ObservationManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/ObservationManager.scala
index 308651b449fd0..b5ec18d5ff122 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/ObservationManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/ObservationManager.scala
@@ -20,6 +20,7 @@ import java.util.concurrent.ConcurrentHashMap
 
 import org.apache.spark.sql.{Observation, Row}
 import org.apache.spark.sql.catalyst.plans.logical.CollectMetrics
+import org.apache.spark.sql.catalyst.trees.TreePattern
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.util.QueryExecutionListener
 
@@ -54,7 +55,8 @@ private[sql] class ObservationManager(session: SparkSession) {
 
   private def tryComplete(qe: QueryExecution): Unit = {
     val allMetrics = qe.observedMetrics
-    qe.logical.foreach {
+    qe.logical.foreachWithSubqueriesAndPruning(
+      _.containsPattern(TreePattern.COLLECT_METRICS)) {
       case c: CollectMetrics =>
         val keyExists = observations.containsKey((c.name, c.dataframeId))
         val metrics = allMetrics.get(c.name)
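
Reviewer note (not part of the patch): a minimal sketch of how the new traversal composes with pattern-bit pruning, in the same style as the SparkConnectPlanner change above. The helper collectMetricNames is hypothetical and only illustrates the API added in QueryPlan.scala.

import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, LogicalPlan}
import org.apache.spark.sql.catalyst.trees.TreePattern

// Hypothetical helper: gather the names of all CollectMetrics nodes, including
// those nested inside subqueries. Subtrees whose pattern bits show they cannot
// contain COLLECT_METRICS are never descended into.
def collectMetricNames(plan: LogicalPlan): Seq[String] = {
  val names = ArrayBuffer[String]()
  plan.foreachWithSubqueriesAndPruning(_.containsPattern(TreePattern.COLLECT_METRICS)) {
    case c: CollectMetrics => names += c.name
    // A visited node is only guaranteed to contain the pattern somewhere in its
    // subtree, not to be a CollectMetrics itself, hence the catch-all case.
    case _ =>
  }
  names.toSeq
}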