1717package org .apache .gluten .extension .columnar .heuristic
1818
1919import org .apache .gluten .extension .columnar .{FallbackTag , FallbackTags }
20+ import org .apache .gluten .extension .columnar .FallbackTag .Converter
2021import org .apache .gluten .extension .columnar .rewrite .RewriteSingleNode
2122
2223import org .apache .spark .rdd .RDD
2324import org .apache .spark .sql .catalyst .InternalRow
2425import org .apache .spark .sql .catalyst .expressions .{Attribute , SortOrder }
2526import org .apache .spark .sql .catalyst .plans .physical .Partitioning
2627import org .apache .spark .sql .catalyst .rules .Rule
27- import org .apache .spark .sql .execution .{LeafExecNode , ProjectExec , SparkPlan }
28+ import org .apache .spark .sql .execution .{LeafExecNode , SparkPlan }
2829
2930case class RewrittenNodeWall (originalChild : SparkPlan ) extends LeafExecNode {
3031 override protected def doExecute (): RDD [InternalRow ] = throw new UnsupportedOperationException ()
@@ -50,16 +51,14 @@ class RewriteSparkPlanRulesManager private (
5051 FallbackTags .maybeOffloadable(plan) && rewriteRules.exists(_.isRewritable(plan))
5152 }
5253
53- private def getFallbackTagBack (rewrittenPlan : SparkPlan ): Option [FallbackTag ] = {
54+ private def getFallbackTags (rewrittenPlan : SparkPlan ): Seq [ Option [FallbackTag ] ] = {
5455 // The rewritten plan may contain more nodes than origin, for now it should only be
5556 // `ProjectExec`.
5657 // TODO: Find a better approach than checking `p.isInstanceOf[ProjectExec]` which is not
5758 // general.
58- val target = rewrittenPlan.collect {
59- case p if ! p.isInstanceOf [ProjectExec ] && ! p. isInstanceOf [ RewrittenNodeWall ] => p
59+ rewrittenPlan.collect {
60+ case p if ! p.isInstanceOf [RewrittenNodeWall ] => FallbackTags .getOption(p)
6061 }
61- assert(target.size == 1 )
62- FallbackTags .getOption(target.head)
6362 }
6463
6564 private def applyRewriteRules (origin : SparkPlan ): (SparkPlan , Option [String ]) = {
@@ -99,10 +98,12 @@ class RewriteSparkPlanRulesManager private (
9998 origin
10099 } else {
101100 validateRule.apply(rewrittenPlan)
102- val tag = getFallbackTagBack(rewrittenPlan)
103- if (tag.isDefined) {
104- // If the rewritten plan is still not transformable, return the original plan.
105- FallbackTags .add(origin, tag.get)
101+ val fallbackTags = getFallbackTags(rewrittenPlan)
102+ if (fallbackTags.exists(_.isDefined)) {
103+ // If the rewritten origin node or inserted project is still not
104+ // transformable, return the original plan.
105+ val reason = fallbackTags.collect { case Some (s) => s.reason() }.mkString(" , " )
106+ FallbackTags .add(origin, Converter .FromString .from(reason).get)
106107 origin
107108 } else {
108109 rewrittenPlan.transformUp {
0 commit comments