Skip to content

Commit 362d0c2

Browse files
marin-ma authored and QCLyu committed
[GLUTEN-11251] Fix incorrect whole stage id in WholeStageTransformerExec (#11252)
1 parent 2e655ff commit 362d0c2

File tree

18 files changed

+259
-23
lines changed

18 files changed

+259
-23
lines changed

backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
514514
val childWithAdapter = ColumnarCollapseTransformStages.wrapInputIteratorTransformer(child)
515515
WholeStageTransformer(
516516
ProjectExecTransformer(child.output ++ appendedProjections, childWithAdapter))(
517-
ColumnarCollapseTransformStages.transformStageCounter.incrementAndGet()
517+
ColumnarCollapseTransformStages
518+
.getTransformStageCounter(childWithAdapter)
519+
.incrementAndGet()
518520
)
519521
}
520522

backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHAggAndShuffleBenchmark.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,10 @@ object CHAggAndShuffleBenchmark extends SqlBasedBenchmark with CHSqlBasedBenchma
156156
// Get the `FileSourceScanExecTransformer`
157157
val fileScan = executedPlan.collect { case scan: FileSourceScanExecTransformer => scan }.head
158158
val scanStage = WholeStageTransformer(fileScan)(
159-
ColumnarCollapseTransformStages.transformStageCounter.incrementAndGet())
159+
ColumnarCollapseTransformStages
160+
.getTransformStageCounter(fileScan)
161+
.incrementAndGet()
162+
)
160163
val scanStageRDD = scanStage.executeColumnar()
161164

162165
// Get the total row count
@@ -200,7 +203,9 @@ object CHAggAndShuffleBenchmark extends SqlBasedBenchmark with CHSqlBasedBenchma
200203
val projectFilter = executedPlan.collect { case project: ProjectExecTransformer => project }
201204
if (projectFilter.nonEmpty) {
202205
val projectFilterStage = WholeStageTransformer(projectFilter.head)(
203-
ColumnarCollapseTransformStages.transformStageCounter.incrementAndGet())
206+
ColumnarCollapseTransformStages
207+
.getTransformStageCounter(projectFilter.head)
208+
.incrementAndGet())
204209
val projectFilterStageRDD = projectFilterStage.executeColumnar()
205210

206211
chAllStagesBenchmark.addCase(s"Project Stage", executedCnt) {

backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/files/GlutenDeltaFileFormatWriter.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,8 @@ object GlutenDeltaFileFormatWriter extends LoggingShims {
260260
nativeSortPlan
261261
}
262262
val newPlan = sortPlan.child match {
263-
case WholeStageTransformer(wholeStageChild, materializeInput) =>
264-
WholeStageTransformer(addNativeSort(wholeStageChild),
265-
materializeInput)(ColumnarCollapseTransformStages.transformStageCounter.incrementAndGet())
263+
case wst @ WholeStageTransformer(wholeStageChild, _) =>
264+
wst.withNewChildren(Seq(addNativeSort(wholeStageChild)))
266265
case other =>
267266
Transitions.toBatchPlan(sortPlan, VeloxBatchType)
268267
}

backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ object VeloxRuleApi {
139139
.getExtendedColumnarPostRules()
140140
.foreach(each => injector.injectPost(c => each(c.session)))
141141
injector.injectPost(c => ColumnarCollapseTransformStages(new GlutenConfig(c.sqlConf)))
142+
injector.injectPost(_ => GenerateTransformStageId())
142143
injector.injectPost(c => CudfNodeValidationRule(new GlutenConfig(c.sqlConf)))
143144

144145
injector.injectPost(c => GlutenNoopWriterRule(c.session))
@@ -240,6 +241,7 @@ object VeloxRuleApi {
240241
.getExtendedColumnarPostRules()
241242
.foreach(each => injector.injectPostTransform(c => each(c.session)))
242243
injector.injectPostTransform(c => ColumnarCollapseTransformStages(new GlutenConfig(c.sqlConf)))
244+
injector.injectPostTransform(_ => GenerateTransformStageId())
243245
injector.injectPostTransform(c => CudfNodeValidationRule(new GlutenConfig(c.sqlConf)))
244246
injector.injectPostTransform(c => GlutenNoopWriterRule(c.session))
245247
injector.injectPostTransform(c => RemoveGlutenTableCacheColumnarToRow(c.session))

gluten-substrait/src/main/scala/org/apache/gluten/execution/TakeOrderedAndProjectExecTransformer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ case class TakeOrderedAndProjectExecTransformer(
137137
LimitExecTransformer(localSortPlan, limitBeforeShuffleOffset, limit)
138138
}
139139
val transformStageCounter: AtomicInteger =
140-
ColumnarCollapseTransformStages.transformStageCounter
140+
ColumnarCollapseTransformStages.getTransformStageCounter(child)
141141
val finalLimitPlan = if (hasShuffle) {
142142
limitBeforeShuffle
143143
} else {

gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ trait UnaryTransformSupport extends TransformSupport with UnaryExecNode {
155155
}
156156

157157
case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = false)(
158-
val transformStageId: Int
158+
var transformStageId: Int
159159
) extends WholeStageTransformerGenerateTreeStringShim
160160
with UnaryTransformSupport {
161161

gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarCollapseTransformStages.scala

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,7 @@ case class InputIteratorTransformer(child: SparkPlan) extends UnaryTransformSupp
139139
* created, e.g. for special fallback handling when an existing WholeStageTransformer failed to
140140
* generate/compile code.
141141
*/
142-
case class ColumnarCollapseTransformStages(
143-
glutenConf: GlutenConfig,
144-
transformStageCounter: AtomicInteger = ColumnarCollapseTransformStages.transformStageCounter)
145-
extends Rule[SparkPlan] {
142+
case class ColumnarCollapseTransformStages(glutenConf: GlutenConfig) extends Rule[SparkPlan] {
146143

147144
def apply(plan: SparkPlan): SparkPlan = {
148145
insertWholeStageTransformer(plan)
@@ -176,8 +173,8 @@ case class ColumnarCollapseTransformStages(
176173
private def insertWholeStageTransformer(plan: SparkPlan): SparkPlan = {
177174
plan match {
178175
case t if supportTransform(t) =>
179-
WholeStageTransformer(t.withNewChildren(t.children.map(insertInputIteratorTransformer)))(
180-
transformStageCounter.incrementAndGet())
176+
// transformStageId will be updated by rule `GenerateTransformStageId`.
177+
WholeStageTransformer(t.withNewChildren(t.children.map(insertInputIteratorTransformer)))(-1)
181178
case other =>
182179
other.withNewChildren(other.children.map(insertWholeStageTransformer))
183180
}
@@ -213,9 +210,20 @@ case class ColumnarInputAdapter(child: SparkPlan)
213210
}
214211

215212
object ColumnarCollapseTransformStages {
216-
val transformStageCounter = new AtomicInteger(0)
217-
218213
def wrapInputIteratorTransformer(plan: SparkPlan): TransformSupport = {
219214
InputIteratorTransformer(ColumnarInputAdapter(plan))
220215
}
216+
217+
def getTransformStageCounter(plan: SparkPlan): AtomicInteger = {
218+
new AtomicInteger(findMaxTransformStageId(plan))
219+
}
220+
221+
private def findMaxTransformStageId(plan: SparkPlan): Int = {
222+
plan match {
223+
case wst: WholeStageTransformer =>
224+
wst.transformStageId
225+
case _ =>
226+
plan.children.map(findMaxTransformStageId).foldLeft(0)(Math.max)
227+
}
228+
}
221229
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.spark.sql.execution
18+
19+
import org.apache.gluten.exception.GlutenException
20+
import org.apache.gluten.execution.WholeStageTransformer
21+
import org.apache.gluten.sql.shims.SparkShimLoader
22+
23+
import org.apache.spark.sql.catalyst.rules.Rule
24+
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, BroadcastQueryStageExec, ShuffleQueryStageExec}
25+
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec, ShuffleExchangeLike}
26+
27+
import java.util
28+
import java.util.Collections.newSetFromMap
29+
import java.util.concurrent.atomic.AtomicInteger
30+
31+
/**
32+
* Generate `transformStageId` for `WholeStageTransformerExec`. This rule updates the whole plan
33+
* tree with incremental and unique transform stage ids before the final execution.
34+
*
35+
* In Spark, the whole stage id is generated by incrementing a global counter. In Gluten, it's not
36+
* possible to use a global counter for id generation, especially in the case of AQE.
37+
*/
38+
case class GenerateTransformStageId() extends Rule[SparkPlan] with AdaptiveSparkPlanHelper {
39+
private val transformStageCounter: AtomicInteger = new AtomicInteger(0)
40+
41+
private val wholeStageTransformerCache =
42+
newSetFromMap[WholeStageTransformer](new util.IdentityHashMap())
43+
44+
def apply(plan: SparkPlan): SparkPlan = {
45+
updateStageId(plan)
46+
plan
47+
}
48+
49+
private def updateStageId(plan: SparkPlan): Unit = {
50+
plan match {
51+
case b: BroadcastQueryStageExec =>
52+
b.plan match {
53+
case b: BroadcastExchangeLike => updateStageId(b)
54+
case _: ReusedExchangeExec =>
55+
case _ =>
56+
throw new GlutenException(s"wrong plan for broadcast stage:\n ${plan.treeString}")
57+
}
58+
case s: ShuffleQueryStageExec =>
59+
s.plan match {
60+
case s: ShuffleExchangeLike => updateStageId(s)
61+
case _: ReusedExchangeExec =>
62+
case _ =>
63+
throw new GlutenException(s"wrong plan for shuffle stage:\n ${plan.treeString}")
64+
}
65+
case aqe: AdaptiveSparkPlanExec if SparkShimLoader.getSparkShims.isFinalAdaptivePlan(aqe) =>
66+
updateStageId(stripAQEPlan(aqe))
67+
case wst: WholeStageTransformer if !wholeStageTransformerCache.contains(wst) =>
68+
updateStageId(wst.child)
69+
wst.transformStageId = transformStageCounter.incrementAndGet()
70+
wholeStageTransformerCache.add(wst)
71+
case plan =>
72+
plan.subqueries.foreach(updateStageId)
73+
plan.children.foreach(updateStageId)
74+
}
75+
}
76+
}

gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenImplicits.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,7 @@ object GlutenImplicits {
9595
}
9696

9797
private def isFinalAdaptivePlan(p: AdaptiveSparkPlanExec): Boolean = {
98-
val args = p.argString(Int.MaxValue)
99-
val index = args.indexOf("isFinalPlan=")
100-
assert(index >= 0)
101-
args.substring(index + "isFinalPlan=".length).trim.toBoolean
98+
SparkShimLoader.getSparkShims.isFinalAdaptivePlan(p)
10299
}
103100

104101
private def collectFallbackNodes(

gluten-substrait/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenFormatWriterInjectsBase.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import org.apache.gluten.extension.columnar.heuristic.HeuristicTransform
2323
import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions}
2424

2525
import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages, SparkPlan}
26-
import org.apache.spark.sql.execution.ColumnarCollapseTransformStages.transformStageCounter
2726

2827
trait GlutenFormatWriterInjectsBase extends GlutenFormatWriterInjects {
2928
private lazy val transform = HeuristicTransform.static()
@@ -66,7 +65,7 @@ trait GlutenFormatWriterInjectsBase extends GlutenFormatWriterInjects {
6665
// and cannot provide const-ness.
6766
val transformedWithAdapter = injectAdapter(transformed)
6867
val wst = WholeStageTransformer(transformedWithAdapter, materializeInput = true)(
69-
transformStageCounter.incrementAndGet())
68+
ColumnarCollapseTransformStages.getTransformStageCounter(transformed).incrementAndGet())
7069
val wstWithTransitions = BackendsApiManager.getSparkPlanExecApiInstance.genColumnarToCarrierRow(
7170
InsertTransitions.create(outputsColumnar = true, wst.batchType()).apply(wst))
7271
wstWithTransitions

0 commit comments

Comments
 (0)