
package org.apache.spark.sql.comet

- import org.apache.spark.TaskContext
+ import java.util.UUID
+ import java.util.concurrent.{Future, TimeoutException, TimeUnit}
+
+ import scala.concurrent.Promise
+ import scala.util.control.NonFatal
+
+ import org.apache.spark.{broadcast, SparkException, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
- import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan}
+ import org.apache.spark.sql.comet.util.{Utils => CometUtils}
+ import org.apache.spark.sql.errors.QueryExecutionErrors
+ import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan, SQLExecution}
+ import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
+ import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.types.StructType
- import org.apache.spark.util.Utils
+ import org.apache.spark.util.{SparkFatalException, Utils}

import org.apache.comet.{CometConf, NativeColumnarToRowConverter}

@@ -53,9 +63,6 @@ case class CometNativeColumnarToRowExec(child: SparkPlan)
// supportsColumnar requires to be only called on driver side, see also SPARK-37779.
assert(Utils.isInRunningSparkTask || child.supportsColumnar)

- // Use the same display name as CometColumnarToRowExec for plan compatibility
- override def nodeName: String = "CometColumnarToRow"
-
override def output: Seq[Attribute] = child.output

override def outputPartitioning: Partitioning = child.outputPartitioning
@@ -67,6 +74,105 @@ case class CometNativeColumnarToRowExec(child: SparkPlan)
6774 " numInputBatches" -> SQLMetrics .createMetric(sparkContext, " number of input batches" ),
6875 " convertTime" -> SQLMetrics .createNanoTimingMetric(sparkContext, " time in conversion" ))
6976
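+ // Completed by relationFuture with the broadcast result (or its failure), so the
+ // outcome can be observed without blocking on the Future itself.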
+ @transient
+ private lazy val promise = Promise[broadcast.Broadcast[Any]]()
+
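+ // Maximum time to wait for the broadcast to complete, in seconds
+ // (spark.sql.broadcastTimeout).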
+ @transient
+ private val timeout: Long = conf.broadcastTimeout
+
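+ // Unique id for the broadcast job group, so the job can be cancelled if it times out.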
+ private val runId: UUID = UUID.randomUUID
+
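+ // The CometBroadcastExchangeExec feeding this node, if any; doExecuteBroadcast needs
+ // its broadcast mode to rebuild the row-based relation.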
+ private lazy val cometBroadcastExchange = findCometBroadcastExchange(child)
+
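+ // Asynchronously decodes the child's columnar broadcast, converts it to rows with the
+ // native converter, and re-broadcasts the resulting relation. Mirrors the structure of
+ // Spark's BroadcastExchangeExec.relationFuture.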
+ @transient
+ lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
+   SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
+     session,
+     CometBroadcastExchangeExec.executionContext) {
+     try {
+       // Set up a job group here so it can be cancelled by groupId later if necessary.
+       sparkContext.setJobGroup(
+         runId.toString,
+         s"CometNativeColumnarToRow broadcast exchange (runId $runId)",
+         interruptOnCancel = true)
+
+       val numOutputRows = longMetric("numOutputRows")
+       val numInputBatches = longMetric("numInputBatches")
+       val localSchema = this.schema
+       val batchSize = CometConf.COMET_BATCH_SIZE.get()
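+       // The child's broadcast value is an array of serialized ColumnarBatches; decode
+       // each buffer back into batches before conversion.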
+       val broadcastColumnar = child.executeBroadcast()
+       val serializedBatches =
+         broadcastColumnar.value.asInstanceOf[Array[org.apache.spark.util.io.ChunkedByteBuffer]]
+
+       // Use native converter to convert columnar data to rows
+       val converter = new NativeColumnarToRowConverter(localSchema, batchSize)
+       try {
+         val rows = serializedBatches.iterator
+           .flatMap(CometUtils.decodeBatches(_, this.getClass.getSimpleName))
+           .flatMap { batch =>
+             numInputBatches += 1
+             numOutputRows += batch.numRows()
+             converter.convert(batch)
+           }
+
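+         // Rebuild the broadcast relation (e.g. a HashedRelation for broadcast joins)
+         // from the converted rows using the original exchange's broadcast mode, then
+         // broadcast it in serialized-only form so the driver need not retain a
+         // deserialized copy.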
+         val mode = cometBroadcastExchange.get.mode
+         val relation = mode.transform(rows, Some(numOutputRows.value))
+         val broadcasted = sparkContext.broadcastInternal(relation, serializedOnly = true)
+         val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+         SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+         promise.trySuccess(broadcasted)
+         broadcasted
+       } finally {
+         converter.close()
+       }
+     } catch {
+       // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
+       // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
+       // will catch this exception and re-throw the wrapped fatal throwable.
+       case oe: OutOfMemoryError =>
+         val ex = new SparkFatalException(oe)
+         promise.tryFailure(ex)
+         throw ex
+       case e if !NonFatal(e) =>
+         val ex = new SparkFatalException(e)
+         promise.tryFailure(ex)
+         throw ex
+       case e: Throwable =>
+         promise.tryFailure(e)
+         throw e
+     }
+   }
+ }
+
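+ // Returns the row-based broadcast relation, blocking up to `timeout` seconds for
+ // relationFuture to finish; cancels the broadcast job group on timeout.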
+ override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
+   if (cometBroadcastExchange.isEmpty) {
+     throw new SparkException(
+       "CometNativeColumnarToRowExec only supports doExecuteBroadcast when child contains a " +
+         "CometBroadcastExchange, but got " + child)
+   }
+
+   try {
+     relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]
+   } catch {
+     case ex: TimeoutException =>
+       logError(s"Could not execute broadcast in $timeout secs.", ex)
+       if (!relationFuture.isDone) {
+         sparkContext.cancelJobGroup(runId.toString)
+         relationFuture.cancel(true)
+       }
+       throw QueryExecutionErrors.executeBroadcastTimeoutError(timeout, Some(ex))
+   }
+ }
+
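+ // Locates the CometBroadcastExchangeExec under this node, looking through AQE broadcast
+ // query stages and reused exchanges.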
+ private def findCometBroadcastExchange(op: SparkPlan): Option[CometBroadcastExchangeExec] = {
+   op match {
+     case b: CometBroadcastExchangeExec => Some(b)
+     case b: BroadcastQueryStageExec => findCometBroadcastExchange(b.plan)
+     case b: ReusedExchangeExec => findCometBroadcastExchange(b.child)
+     case _ => op.children.collectFirst(Function.unlift(findCometBroadcastExchange))
+   }
+ }
+
override def doExecute(): RDD[InternalRow] = {
  val numOutputRows = longMetric("numOutputRows")
  val numInputBatches = longMetric("numInputBatches")