1717package org .apache .gluten .extension .columnar .rewrite
1818
1919import org .apache .gluten .backendsapi .BackendsApiManager
20+ import org .apache .gluten .config .GlutenConfig
21+ import org .apache .gluten .extension .columnar .heuristic .RewrittenNodeWall
2022import org .apache .gluten .sql .shims .SparkShimLoader
2123import org .apache .gluten .utils .PullOutProjectHelper
2224
2325import org .apache .spark .sql .catalyst .expressions ._
2426import org .apache .spark .sql .catalyst .expressions .aggregate .{AggregateExpression , Complete , Partial }
2527import org .apache .spark .sql .execution ._
2628import org .apache .spark .sql .execution .aggregate .{BaseAggregateExec , TypedAggregateExpression }
27- import org .apache .spark .sql .execution .joins .{BaseJoinExec , HashJoin }
29+ import org .apache .spark .sql .execution .joins .{BaseJoinExec , BroadcastHashJoinExec , BroadcastNestedLoopJoinExec , HashJoin }
2830import org .apache .spark .sql .execution .python .ArrowEvalPythonExec
2931import org .apache .spark .sql .execution .window .WindowExec
3032
@@ -293,6 +295,17 @@ object PullOutPreProject extends RewriteSingleNode with PullOutProjectHelper {
293295 arrowEvalPythonExec)
294296
295297 case join : BaseJoinExec if needsPreProject(join) =>
298+ join match {
299+ case _ : BroadcastHashJoinExec | _ : BroadcastNestedLoopJoinExec
300+ if ! GlutenConfig .get.enableColumnarProject =>
301+ // If columnar project is disabled, we cannot pull out a pre-project for a broadcast join:
302+ // ProjectExec does not override doExecuteBroadcast, so a project cannot be added between
303+ // the broadcast join and its broadcast exchange.
304+ throw new UnsupportedOperationException (" columnar project is disabled, " +
305+ " broadcast join operator does not support pull out pre-project, and it will fallback." )
306+ case _ =>
307+ }
308+
296309 // Spark has an improvement which would patch integer join keys to a Long value.
297310 // But this improvement would cause an extra project to be added before the hash join
298311 // in Velox; disabling this improvement as below helps reduce the project.
@@ -312,6 +325,12 @@ object PullOutPreProject extends RewriteSingleNode with PullOutProjectHelper {
312325 val preProject = ProjectExec (
313326 eliminateProjectList(joinChild.outputSet, expressionMap.values.toSeq),
314327 joinChild)
328+ joinChild match {
329+ case r : RewrittenNodeWall =>
330+ r.originalChild.logicalLink.foreach(preProject.setLogicalLink)
331+ case _ =>
332+ joinChild.logicalLink.foreach(preProject.setLogicalLink)
333+ }
315334 (preProject, newJoinKeys, expressionMap)
316335 } else {
317336 (joinChild, joinKeys, expressionMap)
@@ -329,9 +348,11 @@ object PullOutPreProject extends RewriteSingleNode with PullOutProjectHelper {
329348 } else {
330349 join.condition
331350 }
332- ProjectExec (
333- join.output,
334- copyBaseJoinExec(join)(newLeft, newRight, newLeftKeys, newRightKeys, newCondition))
351+ val newJoin =
352+ copyBaseJoinExec(join)(newLeft, newRight, newLeftKeys, newRightKeys, newCondition)
353+ val newProject = ProjectExec (join.output, newJoin)
354+ newJoin.logicalLink.foreach(newProject.setLogicalLink)
355+ newProject
335356
336357 case _ => plan
337358 }
0 commit comments