
Commit cb66f9a

andygrove and claude committed
perf: cache and broadcast serialized plans across partitions
Serialize native query plans once and broadcast them to all executors, avoiding repeated protobuf serialization for each partition. This optimization:

- Adds a serializePlan() method to serialize an Operator once
- Adds a getCometIterator() overload accepting pre-serialized bytes
- Updates getNativeLimitRDD to broadcast the serialized plan
- Updates CometTakeOrderedAndProjectExec to broadcast the topK plan

For a query with 1000 partitions across 10 executors, this reduces plan serialization from 1000x to 1x, and plan transfer from 1000x to 10x (once per executor via broadcast).

Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 86f6eb6 commit cb66f9a
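
The pattern in miniature, as a self-contained sketch (plain Spark with a stand-in byte payload; the object name and payload contents are illustrative, not Comet code): serialize on the driver once, broadcast, and read the executor-local copy inside each partition.

import org.apache.spark.sql.SparkSession

// Sketch: serialize once on the driver, broadcast, and reuse the
// executor-local copy in every partition.
object BroadcastPlanSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
    val sc = spark.sparkContext

    // One serialization on the driver (a stand-in for the protobuf plan bytes).
    val serializedPlan: Array[Byte] = "LIMIT 10".getBytes("UTF-8")
    // One transfer per executor, via the broadcast mechanism.
    val broadcastPlan = sc.broadcast(serializedPlan)

    val perPartition = sc
      .parallelize(1 to 1000, numSlices = 1000)
      .mapPartitionsWithIndex { case (idx, iter) =>
        // No per-partition serialization: just a read of the cached local copy.
        val planBytes = broadcastPlan.value
        Iterator((idx, planBytes.length, iter.size))
      }
    println(perPartition.count()) // 1000
    spark.stop()
  }
}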

File tree

3 files changed: +56 -12 lines changed


spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala

Lines changed: 6 additions & 2 deletions
@@ -53,9 +53,13 @@ object CometExecUtils {
       limit: Int,
       offset: Int = 0): RDD[ColumnarBatch] = {
     val numParts = childPlan.getNumPartitions
+    val numOutputCols = outputAttribute.length
+    // Serialize the plan once and broadcast to all executors to avoid repeated serialization
+    val serializedPlan = CometExec.serializePlan(
+      CometExecUtils.getLimitNativePlan(outputAttribute, limit, offset).get)
+    val broadcastPlan = childPlan.sparkContext.broadcast(serializedPlan)
     childPlan.mapPartitionsWithIndexInternal { case (idx, iter) =>
-      val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit, offset).get
-      CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp, numParts, idx)
+      CometExec.getCometIterator(Seq(iter), numOutputCols, broadcastPlan.value, numParts, idx)
     }
   }

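Worth noting why the closure reads broadcastPlan.value instead of capturing serializedPlan directly: a captured array is serialized into every task closure (once per partition), while a broadcast ships only a small handle and the bytes move once per executor. A rough sketch of the contrast, with illustrative names (not Comet code):

import org.apache.spark.SparkContext

object CaptureVsBroadcast {
  // The captured `planBytes` array travels inside all 1000 task closures.
  def withClosureCapture(sc: SparkContext, planBytes: Array[Byte]): Long =
    sc.parallelize(1 to 1000, 1000)
      .mapPartitions(iter => Iterator(planBytes.length + iter.size))
      .count()

  // The broadcast handle is tiny; executors fetch the bytes once and cache them.
  def withBroadcast(sc: SparkContext, planBytes: Array[Byte]): Long = {
    val bc = sc.broadcast(planBytes)
    sc.parallelize(1 to 1000, 1000)
      .mapPartitions(iter => Iterator(bc.value.length + iter.size))
      .count()
  }
}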

spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala

Lines changed: 13 additions & 5 deletions
@@ -133,12 +133,20 @@ case class CometTakeOrderedAndProjectExec(
         CometExecUtils.getNativeLimitRDD(childRDD, child.output, limit)
       } else {
         val numParts = childRDD.getNumPartitions
+        val numOutputCols = child.output.length
+        // Serialize the plan once and broadcast to avoid repeated serialization
+        val serializedTopK = CometExec.serializePlan(
+          CometExecUtils
+            .getTopKNativePlan(child.output, sortOrder, child, limit)
+            .get)
+        val broadcastTopK = sparkContext.broadcast(serializedTopK)
         childRDD.mapPartitionsWithIndexInternal { case (idx, iter) =>
-          val topK =
-            CometExecUtils
-              .getTopKNativePlan(child.output, sortOrder, child, limit)
-              .get
-          CometExec.getCometIterator(Seq(iter), child.output.length, topK, numParts, idx)
+          CometExec.getCometIterator(
+            Seq(iter),
+            numOutputCols,
+            broadcastTopK.value,
+            numParts,
+            idx)
         }
       }

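Hoisting child.output.length into the local numOutputCols likely doubles as closure hygiene: the new lambda references only local vals and the broadcast handle, so the enclosing plan node no longer has to ride along in task serialization via `child`. The same idiom in miniature (Stage is an illustrative class, not Comet code):

import org.apache.spark.rdd.RDD

class Stage(val width: Int) {
  // `width` here is really `this.width`, so the closure captures the whole
  // Stage; if Stage is not serializable this fails at task-serialization time.
  def runCapturingThis(rdd: RDD[Int]): RDD[Int] =
    rdd.map(_ * width)

  // Hoist the field into a local val so only the Int is captured.
  def runHoisted(rdd: RDD[Int]): RDD[Int] = {
    val w = width
    rdd.map(_ * w)
  }
}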

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 37 additions & 5 deletions
@@ -114,6 +114,20 @@ object CometExec {

   def newIterId: Long = curId.getAndIncrement()

+  /**
+   * Serializes a native plan operator to a byte array. This method should be called once outside
+   * of partition iteration, and the resulting bytes can be reused across all partitions to avoid
+   * repeated serialization overhead.
+   */
+  def serializePlan(nativePlan: Operator): Array[Byte] = {
+    val size = nativePlan.getSerializedSize
+    val bytes = new Array[Byte](size)
+    val codedOutput = CodedOutputStream.newInstance(bytes)
+    nativePlan.writeTo(codedOutput)
+    codedOutput.checkNoSpaceLeft()
+    bytes
+  }
+
   def getCometIterator(
       inputs: Seq[Iterator[ColumnarBatch]],
       numOutputCols: Int,

@@ -131,6 +145,28 @@ object CometExec {
       encryptedFilePaths = Seq.empty)
   }

+  /**
+   * Creates a CometExecIterator from pre-serialized plan bytes. Use this overload when the same
+   * plan is used across multiple partitions to avoid serializing the plan repeatedly.
+   */
+  def getCometIterator(
+      inputs: Seq[Iterator[ColumnarBatch]],
+      numOutputCols: Int,
+      serializedPlan: Array[Byte],
+      numParts: Int,
+      partitionIdx: Int): CometExecIterator = {
+    new CometExecIterator(
+      newIterId,
+      inputs,
+      numOutputCols,
+      serializedPlan,
+      CometMetricNode(Map.empty),
+      numParts,
+      partitionIdx,
+      broadcastedHadoopConfForEncryption = None,
+      encryptedFilePaths = Seq.empty)
+  }
+
   def getCometIterator(
       inputs: Seq[Iterator[ColumnarBatch]],
       numOutputCols: Int,

@@ -140,11 +176,7 @@ object CometExec {
       partitionIdx: Int,
       broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]],
       encryptedFilePaths: Seq[String]): CometExecIterator = {
-    val size = nativePlan.getSerializedSize
-    val bytes = new Array[Byte](size)
-    val codedOutput = CodedOutputStream.newInstance(bytes)
-    nativePlan.writeTo(codedOutput)
-    codedOutput.checkNoSpaceLeft()
+    val bytes = serializePlan(nativePlan)
     new CometExecIterator(
       newIterId,
       inputs,
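
The extracted serializePlan follows the standard protobuf-java idiom: allocate a buffer of exactly getSerializedSize bytes, write through a CodedOutputStream, and assert the buffer was filled. A generic sketch over any MessageLite (toBytes is an illustrative name):

import com.google.protobuf.{CodedOutputStream, MessageLite}

object ProtoBytes {
  def toBytes(msg: MessageLite): Array[Byte] = {
    val bytes = new Array[Byte](msg.getSerializedSize) // exact size, no resizing
    val out = CodedOutputStream.newInstance(bytes)
    msg.writeTo(out)
    out.checkNoSpaceLeft() // throws if the message under-filled the buffer
    bytes
  }
}

protobuf-java's built-in msg.toByteArray() performs these same steps internally; keeping an explicit helper lets call sites name the operation and reuse the resulting bytes across partitions.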
