Skip to content

Commit 48a0cc0

Browse files
authored
fix: Check reused broadcast plan in non-AQE and make setNumPartitions thread safe (#2398)
* fix: Check broadcast plan of ReusedExchangeExec * thread-safety
1 parent 9994657 commit 48a0cc0

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ object CometBroadcastExchangeExec {
276276
*/
277277
class CometBatchRDD(
278278
sc: SparkContext,
279-
numPartitions: Int,
279+
@volatile var numPartitions: Int,
280280
value: broadcast.Broadcast[Array[ChunkedByteBuffer]])
281281
extends RDD[ColumnarBatch](sc, Nil) {
282282

@@ -289,6 +289,12 @@ class CometBatchRDD(
289289
partition.value.value.toIterator
290290
.flatMap(Utils.decodeBatches(_, this.getClass.getSimpleName))
291291
}
292+
293+
def withNumPartitions(numPartitions: Int): CometBatchRDD = {
294+
this.numPartitions = numPartitions
295+
this
296+
}
297+
292298
}
293299

294300
class CometBatchPartition(

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -238,13 +238,15 @@ abstract class CometNativeExec extends CometExec {
238238
case (_: CometBroadcastExchangeExec, _) => false
239239
case (BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _), _) => false
240240
case (BroadcastQueryStageExec(_, _: ReusedExchangeExec, _), _) => false
241+
case (ReusedExchangeExec(_, _: CometBroadcastExchangeExec), _) => false
241242
case _ => true
242243
}
243244

244245
val containsBroadcastInput = sparkPlans.exists {
245246
case _: CometBroadcastExchangeExec => true
246247
case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true
247248
case BroadcastQueryStageExec(_, _: ReusedExchangeExec, _) => true
249+
case ReusedExchangeExec(_, _: CometBroadcastExchangeExec) => true
248250
case _ => false
249251
}
250252

@@ -272,16 +274,28 @@ abstract class CometNativeExec extends CometExec {
272274
sparkPlans.zipWithIndex.foreach { case (plan, idx) =>
273275
plan match {
274276
case c: CometBroadcastExchangeExec =>
275-
inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
277+
inputs += c
278+
.executeColumnar()
279+
.asInstanceOf[CometBatchRDD]
280+
.withNumPartitions(firstNonBroadcastPlanNumPartitions)
276281
case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) =>
277-
inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
282+
inputs += c
283+
.executeColumnar()
284+
.asInstanceOf[CometBatchRDD]
285+
.withNumPartitions(firstNonBroadcastPlanNumPartitions)
278286
case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) =>
279-
inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
287+
inputs += c
288+
.executeColumnar()
289+
.asInstanceOf[CometBatchRDD]
290+
.withNumPartitions(firstNonBroadcastPlanNumPartitions)
280291
case BroadcastQueryStageExec(
281292
_,
282293
ReusedExchangeExec(_, c: CometBroadcastExchangeExec),
283294
_) =>
284-
inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
295+
inputs += c
296+
.executeColumnar()
297+
.asInstanceOf[CometBatchRDD]
298+
.withNumPartitions(firstNonBroadcastPlanNumPartitions)
285299
case _: CometNativeExec =>
286300
// no-op
287301
case _ if idx == firstNonBroadcastPlan.get._2 =>

0 commit comments

Comments
 (0)