Skip to content

Commit a3fc246

Browse files
committed
fallback project
1 parent eff6286 commit a3fc246

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/RewriteSparkPlanRulesManager.scala

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@
1717
package org.apache.gluten.extension.columnar.heuristic
1818

1919
import org.apache.gluten.extension.columnar.{FallbackTag, FallbackTags}
20+
import org.apache.gluten.extension.columnar.FallbackTag.Converter
2021
import org.apache.gluten.extension.columnar.rewrite.RewriteSingleNode
2122

2223
import org.apache.spark.rdd.RDD
2324
import org.apache.spark.sql.catalyst.InternalRow
2425
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
2526
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
2627
import org.apache.spark.sql.catalyst.rules.Rule
27-
import org.apache.spark.sql.execution.{LeafExecNode, ProjectExec, SparkPlan}
28+
import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan}
2829

2930
case class RewrittenNodeWall(originalChild: SparkPlan) extends LeafExecNode {
3031
override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException()
@@ -50,16 +51,14 @@ class RewriteSparkPlanRulesManager private (
5051
FallbackTags.maybeOffloadable(plan) && rewriteRules.exists(_.isRewritable(plan))
5152
}
5253

53-
private def getFallbackTagBack(rewrittenPlan: SparkPlan): Option[FallbackTag] = {
54+
private def getFallbackTags(rewrittenPlan: SparkPlan): Seq[Option[FallbackTag]] = {
5455
// The rewritten plan may contain more nodes than origin, for now it should only be
5556
// `ProjectExec`.
5657
// TODO: Find a better approach than checking `p.isInstanceOf[ProjectExec]` which is not
5758
// general.
58-
val target = rewrittenPlan.collect {
59-
case p if !p.isInstanceOf[ProjectExec] && !p.isInstanceOf[RewrittenNodeWall] => p
59+
rewrittenPlan.collect {
60+
case p if !p.isInstanceOf[RewrittenNodeWall] => FallbackTags.getOption(p)
6061
}
61-
assert(target.size == 1)
62-
FallbackTags.getOption(target.head)
6362
}
6463

6564
private def applyRewriteRules(origin: SparkPlan): (SparkPlan, Option[String]) = {
@@ -99,10 +98,12 @@ class RewriteSparkPlanRulesManager private (
9998
origin
10099
} else {
101100
validateRule.apply(rewrittenPlan)
102-
val tag = getFallbackTagBack(rewrittenPlan)
103-
if (tag.isDefined) {
104-
// If the rewritten plan is still not transformable, return the original plan.
105-
FallbackTags.add(origin, tag.get)
101+
val fallbackTags = getFallbackTags(rewrittenPlan)
102+
if (fallbackTags.exists(_.isDefined)) {
103+
// If the rewritten origin node or inserted project is still not
104+
// transformable, return the original plan.
105+
val reason = fallbackTags.collect { case Some(s) => s.reason() }.mkString(", ")
106+
FallbackTags.add(origin, Converter.FromString.from(reason).get)
106107
origin
107108
} else {
108109
rewrittenPlan.transformUp {

0 commit comments

Comments
 (0)