
Commit 2581ea5

andygrove and claude committed
Add doExecuteBroadcast support to CometNativeColumnarToRowExec
When native columnar-to-row conversion is enabled (now the default), CometNativeColumnarToRowExec is used instead of CometColumnarToRowExec. However, it was missing the doExecuteBroadcast implementation required for broadcast exchange operations, causing test failures.

Changes:

- Add doExecuteBroadcast implementation to CometNativeColumnarToRowExec that uses the native converter for broadcast data transformation
- Update CometExecSuite test to handle both CometColumnarToRowExec and CometNativeColumnarToRowExec
- Fix parent-child relationship check to account for InputAdapter wrapper nodes used by Spark's codegen
- Remove nodeName override from CometNativeColumnarToRowExec

Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent: 5c7da07 · commit: 2581ea5
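For context on why the missing override surfaced as test failures: Spark routes broadcast execution through doExecuteBroadcast, and an operator that does not override it fails at runtime when it ends up on the build side of a broadcast join. A simplified toy model of that contract, in plain Scala (ToyPlan and friends are illustrative stand-ins, not Spark's actual SparkPlan code):

abstract class ToyPlan {
  def nodeName: String = getClass.getSimpleName

  // Public entry point, analogous to SparkPlan.executeBroadcast.
  final def executeBroadcast[T](): T = doExecuteBroadcast[T]()

  // Default behavior: operators must opt in to being broadcast.
  protected def doExecuteBroadcast[T](): T =
    throw new UnsupportedOperationException(
      s"$nodeName does not implement doExecuteBroadcast")
}

// An operator without the override fails as soon as broadcast is requested.
class MissingOverride extends ToyPlan

// The fix in this commit corresponds to adding the override.
class WithOverride extends ToyPlan {
  override protected def doExecuteBroadcast[T](): T =
    "broadcast payload".asInstanceOf[T]
}

object BroadcastContractDemo extends App {
  println(new WithOverride().executeBroadcast[String]()) // broadcast payload
  try new MissingOverride().executeBroadcast[String]()
  catch { case e: UnsupportedOperationException => println(e.getMessage) }
}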

File tree: 2 files changed (+134, -15 lines)

spark/src/main/scala/org/apache/spark/sql/comet/CometNativeColumnarToRowExec.scala

Lines changed: 112 additions & 6 deletions
@@ -19,15 +19,25 @@
 
 package org.apache.spark.sql.comet
 
-import org.apache.spark.TaskContext
+import java.util.UUID
+import java.util.concurrent.{Future, TimeoutException, TimeUnit}
+
+import scala.concurrent.Promise
+import scala.util.control.NonFatal
+
+import org.apache.spark.{broadcast, SparkException, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan}
+import org.apache.spark.sql.comet.util.{Utils => CometUtils}
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
+import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{SparkFatalException, Utils}
 
 import org.apache.comet.{CometConf, NativeColumnarToRowConverter}
 
@@ -53,9 +63,6 @@ case class CometNativeColumnarToRowExec(child: SparkPlan)
   // supportsColumnar requires to be only called on driver side, see also SPARK-37779.
   assert(Utils.isInRunningSparkTask || child.supportsColumnar)
 
-  // Use the same display name as CometColumnarToRowExec for plan compatibility
-  override def nodeName: String = "CometColumnarToRow"
-
   override def output: Seq[Attribute] = child.output
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
@@ -67,6 +74,105 @@ case class CometNativeColumnarToRowExec(child: SparkPlan)
     "numInputBatches" -> SQLMetrics.createMetric(sparkContext, "number of input batches"),
     "convertTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "time in conversion"))
 
+  @transient
+  private lazy val promise = Promise[broadcast.Broadcast[Any]]()
+
+  @transient
+  private val timeout: Long = conf.broadcastTimeout
+
+  private val runId: UUID = UUID.randomUUID
+
+  private lazy val cometBroadcastExchange = findCometBroadcastExchange(child)
+
+  @transient
+  lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
+    SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
+      session,
+      CometBroadcastExchangeExec.executionContext) {
+      try {
+        // Setup a job group here so later it may get cancelled by groupId if necessary.
+        sparkContext.setJobGroup(
+          runId.toString,
+          s"CometNativeColumnarToRow broadcast exchange (runId $runId)",
+          interruptOnCancel = true)
+
+        val numOutputRows = longMetric("numOutputRows")
+        val numInputBatches = longMetric("numInputBatches")
+        val localSchema = this.schema
+        val batchSize = CometConf.COMET_BATCH_SIZE.get()
+        val broadcastColumnar = child.executeBroadcast()
+        val serializedBatches =
+          broadcastColumnar.value.asInstanceOf[Array[org.apache.spark.util.io.ChunkedByteBuffer]]
+
+        // Use native converter to convert columnar data to rows
+        val converter = new NativeColumnarToRowConverter(localSchema, batchSize)
+        try {
+          val rows = serializedBatches.iterator
+            .flatMap(CometUtils.decodeBatches(_, this.getClass.getSimpleName))
+            .flatMap { batch =>
+              numInputBatches += 1
+              numOutputRows += batch.numRows()
+              converter.convert(batch)
+            }
+
+          val mode = cometBroadcastExchange.get.mode
+          val relation = mode.transform(rows, Some(numOutputRows.value))
+          val broadcasted = sparkContext.broadcastInternal(relation, serializedOnly = true)
+          val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+          SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+          promise.trySuccess(broadcasted)
+          broadcasted
+        } finally {
+          converter.close()
+        }
+      } catch {
+        // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
+        // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
+        // will catch this exception and re-throw the wrapped fatal throwable.
+        case oe: OutOfMemoryError =>
+          val ex = new SparkFatalException(oe)
+          promise.tryFailure(ex)
+          throw ex
+        case e if !NonFatal(e) =>
+          val ex = new SparkFatalException(e)
+          promise.tryFailure(ex)
+          throw ex
+        case e: Throwable =>
+          promise.tryFailure(e)
+          throw e
+      }
+    }
+  }
+
+  override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
+    if (cometBroadcastExchange.isEmpty) {
+      throw new SparkException(
+        "CometNativeColumnarToRowExec only supports doExecuteBroadcast when child contains a " +
+          "CometBroadcastExchange, but got " + child)
+    }
+
+    try {
+      relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]
+    } catch {
+      case ex: TimeoutException =>
+        logError(s"Could not execute broadcast in $timeout secs.", ex)
+        if (!relationFuture.isDone) {
+          sparkContext.cancelJobGroup(runId.toString)
+          relationFuture.cancel(true)
+        }
+        throw QueryExecutionErrors.executeBroadcastTimeoutError(timeout, Some(ex))
+    }
+  }
+
+  private def findCometBroadcastExchange(op: SparkPlan): Option[CometBroadcastExchangeExec] = {
+    op match {
+      case b: CometBroadcastExchangeExec => Some(b)
+      case b: BroadcastQueryStageExec => findCometBroadcastExchange(b.plan)
+      case b: ReusedExchangeExec => findCometBroadcastExchange(b.child)
+      case _ => op.children.collectFirst(Function.unlift(findCometBroadcastExchange))
+    }
+  }
+
   override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
     val numInputBatches = longMetric("numInputBatches")
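Two patterns in the new code are worth calling out. First, relationFuture plus the bounded wait in doExecuteBroadcast follows the same shape as Spark's BroadcastExchangeExec: build the relation asynchronously, block on the future up to the broadcast timeout, and cancel the in-flight work if the wait expires. A self-contained sketch of that shape using a plain JVM executor instead of Spark's shared execution context (all names here are illustrative):

import java.util.concurrent.{Callable, Executors, TimeUnit, TimeoutException}

object BroadcastTimeoutDemo extends App {
  val executor = Executors.newSingleThreadExecutor()

  // Stands in for relationFuture: the expensive collect/convert/broadcast
  // work runs off the caller's thread.
  val relationFuture = executor.submit(new Callable[String] {
    override def call(): String = {
      Thread.sleep(100) // pretend to build the broadcast relation
      "broadcast relation"
    }
  })

  val timeoutSecs = 2L
  try {
    // Stands in for doExecuteBroadcast: a bounded wait, like broadcastTimeout.
    println(relationFuture.get(timeoutSecs, TimeUnit.SECONDS))
  } catch {
    case _: TimeoutException =>
      // On timeout, cancel the in-flight work (the real code also cancels the
      // Spark job group) before surfacing the error.
      if (!relationFuture.isDone) relationFuture.cancel(true)
  } finally {
    executor.shutdown()
  }
}

Second, findCometBroadcastExchange recurses with collectFirst(Function.unlift(...)): Function.unlift adapts an A => Option[B] function into a PartialFunction[A, B], so collectFirst returns the result from the first child whose subtree contains a match. A standalone sketch on a toy tree (Node, Wrapper, and Target stand in for plan nodes):

object UnliftDemo extends App {
  sealed trait Node { def children: Seq[Node] }
  case class Target(id: Int) extends Node { def children: Seq[Node] = Nil }
  case class Wrapper(children: Seq[Node]) extends Node

  // Mirrors the shape of findCometBroadcastExchange: match the node itself,
  // otherwise recurse into the first child whose subtree yields a result.
  def findTarget(n: Node): Option[Target] = n match {
    case t: Target => Some(t)
    case _ => n.children.collectFirst(Function.unlift(findTarget))
  }

  println(findTarget(Wrapper(Seq(Wrapper(Nil), Wrapper(Seq(Target(42)))))))
  // Some(Target(42))
}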

spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala

Lines changed: 22 additions & 9 deletions
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, He
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, BloomFilterAggregate}
 import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
-import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
+import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SparkPlan, SQLExecution, UnionExec}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
@@ -864,9 +864,11 @@ class CometExecSuite extends CometTestBase {
       checkSparkAnswerAndOperator(df)
 
       // Before AQE: one CometBroadcastExchange, no CometColumnarToRow
-      var columnarToRowExec = stripAQEPlan(df.queryExecution.executedPlan).collect {
-        case s: CometColumnarToRowExec => s
-      }
+      var columnarToRowExec: Seq[SparkPlan] =
+        stripAQEPlan(df.queryExecution.executedPlan).collect {
+          case s: CometColumnarToRowExec => s
+          case s: CometNativeColumnarToRowExec => s
+        }
       assert(columnarToRowExec.isEmpty)
 
       // Disable CometExecRule after the initial plan is generated. The CometSortMergeJoin and
@@ -880,14 +882,25 @@ class CometExecSuite extends CometTestBase {
       // After AQE: CometBroadcastExchange has to be converted to rows to conform to Spark
       // BroadcastHashJoin.
       val plan = stripAQEPlan(df.queryExecution.executedPlan)
-      columnarToRowExec = plan.collect { case s: CometColumnarToRowExec =>
-        s
+      columnarToRowExec = plan.collect {
+        case s: CometColumnarToRowExec => s
+        case s: CometNativeColumnarToRowExec => s
       }
       assert(columnarToRowExec.length == 1)
 
-      // This ColumnarToRowExec should be the immediate child of BroadcastHashJoinExec
-      val parent = plan.find(_.children.contains(columnarToRowExec.head))
-      assert(parent.get.isInstanceOf[BroadcastHashJoinExec])
+      // This ColumnarToRowExec should be a descendant of BroadcastHashJoinExec (possibly
+      // wrapped by InputAdapter for codegen).
+      val broadcastJoins = plan.collect { case b: BroadcastHashJoinExec => b }
+      assert(broadcastJoins.nonEmpty, s"Expected BroadcastHashJoinExec in plan:\n$plan")
+      val hasC2RDescendant = broadcastJoins.exists { join =>
+        join.find {
+          case _: CometColumnarToRowExec | _: CometNativeColumnarToRowExec => true
+          case _ => false
+        }.isDefined
+      }
+      assert(
+        hasC2RDescendant,
+        "BroadcastHashJoinExec should have a columnar-to-row descendant")
 
       // There should be a CometBroadcastExchangeExec under CometColumnarToRowExec
      val broadcastQueryStage =
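The reworked assertion searches for a descendant rather than an immediate child because whole-stage codegen can interpose an InputAdapter between BroadcastHashJoinExec and the columnar-to-row node. A toy illustration of why the old children.contains check breaks while a subtree search still succeeds (plain Scala stand-ins, not Spark's classes):

object DescendantCheckDemo extends App {
  sealed trait Plan {
    def children: Seq[Plan]
    // Simplified analogue of SparkPlan.find: pre-order search of the subtree.
    def find(p: Plan => Boolean): Option[Plan] =
      if (p(this)) Some(this) else children.view.flatMap(_.find(p)).headOption
  }
  case class Join(children: Seq[Plan]) extends Plan
  case class Adapter(child: Plan) extends Plan { def children: Seq[Plan] = Seq(child) }
  case class ColumnarToRow() extends Plan { def children: Seq[Plan] = Nil }

  val c2r = ColumnarToRow()
  val plan = Join(Seq(Adapter(c2r))) // codegen interposed a wrapper node

  // Old check: immediate child of the join, broken by the wrapper.
  println(plan.children.contains(c2r)) // false

  // New check: any descendant of the join, still finds the node.
  println(plan.find(_.isInstanceOf[ColumnarToRow]).isDefined) // true
}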
