Skip to content

Commit c1e1120

Browse files
authored
fix: Avoid spark plan execution cache preventing CometBatchRDD numPartitions change (#2420)
* fix: Avoid spark plan execution cache preventing CometBatchRDD numPartitions change
* refactor
1 parent 9e1f70f commit c1e1120

File tree

2 files changed

+19
-31
lines changed

2 files changed

+19
-31
lines changed

spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import scala.concurrent.{ExecutionContext, Promise}
2626
import scala.concurrent.duration.NANOSECONDS
2727
import scala.util.control.NonFatal
2828

29-
import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext}
29+
import org.apache.spark.{broadcast, Partition, SparkContext, SparkException, TaskContext}
3030
import org.apache.spark.rdd.RDD
3131
import org.apache.spark.sql.catalyst.InternalRow
3232
import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -102,14 +102,8 @@ case class CometBroadcastExchangeExec(
102102
@transient
103103
private lazy val maxBroadcastRows = 512000000
104104

105-
private var numPartitions: Option[Int] = None
106-
107-
def setNumPartitions(numPartitions: Int): CometBroadcastExchangeExec = {
108-
this.numPartitions = Some(numPartitions)
109-
this
110-
}
111105
def getNumPartitions(): Int = {
112-
numPartitions.getOrElse(child.executeColumnar().getNumPartitions)
106+
child.executeColumnar().getNumPartitions
113107
}
114108

115109
@transient
@@ -224,6 +218,18 @@ case class CometBroadcastExchangeExec(
224218
new CometBatchRDD(sparkContext, getNumPartitions(), broadcasted)
225219
}
226220

221+
// Since https://issues.apache.org/jira/browse/SPARK-48195, a Spark plan caches the RDD it creates.
222+
// Since we may change the number of partitions in CometBatchRDD,
223+
// we need a method that always creates a new CometBatchRDD.
224+
def executeColumnar(numPartitions: Int): RDD[ColumnarBatch] = {
225+
if (isCanonicalizedPlan) {
226+
throw SparkException.internalError("A canonicalized plan is not supposed to be executed.")
227+
}
228+
229+
val broadcasted = executeBroadcast[Array[ChunkedByteBuffer]]()
230+
new CometBatchRDD(sparkContext, numPartitions, broadcasted)
231+
}
232+
227233
override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
228234
try {
229235
relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]
@@ -276,7 +282,7 @@ object CometBroadcastExchangeExec {
276282
*/
277283
class CometBatchRDD(
278284
sc: SparkContext,
279-
@volatile var numPartitions: Int,
285+
val numPartitions: Int,
280286
value: broadcast.Broadcast[Array[ChunkedByteBuffer]])
281287
extends RDD[ColumnarBatch](sc, Nil) {
282288

@@ -289,12 +295,6 @@ class CometBatchRDD(
289295
partition.value.value.toIterator
290296
.flatMap(Utils.decodeBatches(_, this.getClass.getSimpleName))
291297
}
292-
293-
def withNumPartitions(numPartitions: Int): CometBatchRDD = {
294-
this.numPartitions = numPartitions
295-
this
296-
}
297-
298298
}
299299

300300
class CometBatchPartition(

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -274,28 +274,16 @@ abstract class CometNativeExec extends CometExec {
274274
sparkPlans.zipWithIndex.foreach { case (plan, idx) =>
275275
plan match {
276276
case c: CometBroadcastExchangeExec =>
277-
inputs += c
278-
.executeColumnar()
279-
.asInstanceOf[CometBatchRDD]
280-
.withNumPartitions(firstNonBroadcastPlanNumPartitions)
277+
inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions)
281278
case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) =>
282-
inputs += c
283-
.executeColumnar()
284-
.asInstanceOf[CometBatchRDD]
285-
.withNumPartitions(firstNonBroadcastPlanNumPartitions)
279+
inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions)
286280
case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) =>
287-
inputs += c
288-
.executeColumnar()
289-
.asInstanceOf[CometBatchRDD]
290-
.withNumPartitions(firstNonBroadcastPlanNumPartitions)
281+
inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions)
291282
case BroadcastQueryStageExec(
292283
_,
293284
ReusedExchangeExec(_, c: CometBroadcastExchangeExec),
294285
_) =>
295-
inputs += c
296-
.executeColumnar()
297-
.asInstanceOf[CometBatchRDD]
298-
.withNumPartitions(firstNonBroadcastPlanNumPartitions)
286+
inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions)
299287
case _: CometNativeExec =>
300288
// no-op
301289
case _ if idx == firstNonBroadcastPlan.get._2 =>

0 commit comments

Comments
 (0)