
Commit 70c2008

heyihong authored and cloud-fan committed
[SPARK-53097][CONNECT][SQL] Make WriteOperationV2 in SparkConnectPlanner side effect free
### What changes were proposed in this pull request?

This PR refactors the `WriteOperationV2` handling in `SparkConnectPlanner` to make it side-effect free by separating the transformation and execution phases.

### Why are the changes needed?

To make `WriteOperationV2` handling side-effect free: transforming the command no longer executes it, so transformation and execution can be managed separately.

### Does this PR introduce _any_ user-facing change?

**No**. This is a purely internal refactoring.

### How was this patch tested?

Existing tests (e.g. [ReadwriterV2ParityTests](https://github.com/apache/spark/blob/master/python/pyspark/sql/tests/connect/test_parity_readwriter.py)).

### Was this patch authored or co-authored using generative AI tooling?

Generated-by: Cursor 1.3.9

Closes #51813 from heyihong/SPARK-53097.

Authored-by: Yihong He <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent cf3ab76 · commit 70c2008

3 files changed, +73 −63 lines changed
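The shape of the refactoring, sketched below with toy stand-ins (`Tracker` and `Plan` substitute for Spark's `QueryPlanningTracker` and `LogicalPlan`; none of this is the actual Spark code): instead of a handler that creates a tracker, builds the plan, and executes it in one call, the planner now returns a function `Tracker => Plan`, so transformation stays pure and execution happens at a single call site that owns the tracker.

```scala
// Minimal sketch of the "side-effect free transform" pattern, with toy
// stand-ins for QueryPlanningTracker and LogicalPlan (not Spark's types).
object TransformExecuteSplit {
  final case class Tracker(owner: String)
  final case class Plan(description: String)

  // Before: transformation and execution fused; calling this always writes.
  def handleWrite(table: String): Unit = {
    val tracker = Tracker("handler") // side effects start inside the handler
    execute(Plan(s"write to $table"), tracker)
  }

  // After: transformation returns a deferred plan builder; nothing runs yet.
  def transformWrite(table: String): Tracker => Plan =
    tracker => Plan(s"write to $table (planned under ${tracker.owner})")

  def execute(plan: Plan, tracker: Tracker): Unit =
    println(s"[${tracker.owner}] executing ${plan.description}")

  def main(args: Array[String]): Unit = {
    val transformer = transformWrite("t1") // pure: no write has happened yet
    val tracker = Tracker("executor")      // the single execution site owns it
    execute(transformer(tracker), tracker)
  }
}
```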

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala

Lines changed: 7 additions & 4 deletions

```diff
@@ -81,10 +81,13 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
           dataframe).foreach(responseObserver.onNext)
       case proto.Plan.OpTypeCase.COMMAND =>
         val command = request.getPlan.getCommand
-        planner.transformCommand(command, tracker) match {
-          case Some(plan) =>
-            val qe =
-              new QueryExecution(session, plan, tracker, shuffleCleanupMode = shuffleCleanupMode)
+        planner.transformCommand(command) match {
+          case Some(transformer) =>
+            val qe = new QueryExecution(
+              session,
+              transformer(tracker),
+              tracker,
+              shuffleCleanupMode = shuffleCleanupMode)
             qe.assertCommandExecuted()
             executeHolder.eventsManager.postFinished()
           case None =>
```
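A note on the tracker threading above, as a hedged toy sketch (again, not Spark's real types): the same `tracker` instance is passed both to the transformer and to the `QueryExecution`, so planning events recorded while the plan is built stay attached to the execution that follows.

```scala
// Toy illustration of threading one tracker through both plan construction
// and execution; Tracker and Plan are stand-ins, not Spark classes.
object TrackerThreading {
  import scala.collection.mutable.ListBuffer

  final class Tracker { val phases: ListBuffer[String] = ListBuffer.empty }
  final case class Plan(op: String)

  def transformer(t: Tracker): Plan = { t.phases += "analysis"; Plan("write") }

  def main(args: Array[String]): Unit = {
    val tracker = new Tracker
    val plan = transformer(tracker) // records "analysis" on this tracker
    tracker.phases += "execution"   // the execution phase lands on the same one
    println(s"${plan.op}: ${tracker.phases.mkString(" -> ")}")
  }
}
```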

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 34 additions & 43 deletions

```diff
@@ -2645,12 +2645,12 @@ class SparkConnectPlanner(
     process(command, new MockObserver())
   }
 
-  def transformCommand(
-      command: proto.Command,
-      tracker: QueryPlanningTracker): Option[LogicalPlan] = {
+  def transformCommand(command: proto.Command): Option[QueryPlanningTracker => LogicalPlan] = {
     command.getCommandTypeCase match {
       case proto.Command.CommandTypeCase.WRITE_OPERATION =>
-        Some(transformWriteOperation(command.getWriteOperation, tracker))
+        Some(transformWriteOperation(command.getWriteOperation))
+      case proto.Command.CommandTypeCase.WRITE_OPERATION_V2 =>
+        Some(transformWriteOperationV2(command.getWriteOperationV2))
       case _ =>
         None
     }
```
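Why `Some(transformWriteOperation(command.getWriteOperation))` type-checks: `transformWriteOperation` now has two parameter lists, so applying only the first eta-expands to a `QueryPlanningTracker => LogicalPlan`. A hedged, self-contained sketch of that Scala mechanic (toy types, not the planner's):

```scala
// Partial application of a curried method yields a function value.
object CurriedTransform {
  final case class Tracker(id: Int)
  final case class Plan(text: String)

  // Two parameter lists, mirroring transformWriteOperation(op)(tracker).
  def transform(op: String)(tracker: Tracker): Plan =
    Plan(s"$op under tracker ${tracker.id}")

  def main(args: Array[String]): Unit = {
    // Applying the first list alone produces Tracker => Plan (eta-expansion,
    // driven here by the expected type, as in Option[Tracker => Plan]).
    val deferred: Option[Tracker => Plan] = Some(transform("writeOp"))
    deferred.foreach(f => println(f(Tracker(7))))
  }
}
```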
```diff
@@ -2659,19 +2659,20 @@ class SparkConnectPlanner(
   def process(
       command: proto.Command,
       responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+    val transformerOpt = transformCommand(command)
+    if (transformerOpt.isDefined) {
+      transformAndRunCommand(transformerOpt.get)
+      return
+    }
     command.getCommandTypeCase match {
       case proto.Command.CommandTypeCase.REGISTER_FUNCTION =>
         handleRegisterUserDefinedFunction(command.getRegisterFunction)
       case proto.Command.CommandTypeCase.REGISTER_TABLE_FUNCTION =>
         handleRegisterUserDefinedTableFunction(command.getRegisterTableFunction)
       case proto.Command.CommandTypeCase.REGISTER_DATA_SOURCE =>
         handleRegisterUserDefinedDataSource(command.getRegisterDataSource)
-      case proto.Command.CommandTypeCase.WRITE_OPERATION =>
-        handleWriteOperation(command.getWriteOperation)
       case proto.Command.CommandTypeCase.CREATE_DATAFRAME_VIEW =>
         handleCreateViewCommand(command.getCreateDataframeView)
-      case proto.Command.CommandTypeCase.WRITE_OPERATION_V2 =>
-        handleWriteOperationV2(command.getWriteOperationV2)
       case proto.Command.CommandTypeCase.EXTENSION =>
         handleCommandPlugin(command.getExtension)
       case proto.Command.CommandTypeCase.SQL_COMMAND =>
```
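The early `return` above keeps the large legacy `match` untouched. For readers who prefer a single expression, an equivalent dispatch shape would be the following hedged sketch (toy `Command` type and handlers, not the planner's real ones):

```scala
// Equivalent Option-based dispatch written as one match expression.
object CommandDispatch {
  sealed trait Command
  case object WriteV2 extends Command
  case object RegisterFunction extends Command

  def transformCommand(c: Command): Option[Int => String] = c match {
    case WriteV2 => Some(tracker => s"write plan under tracker $tracker")
    case _       => None
  }

  def transformAndRun(transformer: Int => String): Unit = println(transformer(1))
  def legacyDispatch(c: Command): Unit = println(s"legacy handler for $c")

  def process(c: Command): Unit = transformCommand(c) match {
    case Some(transformer) => transformAndRun(transformer)
    case None              => legacyDispatch(c)
  }

  def main(args: Array[String]): Unit = {
    process(WriteV2)          // goes through the transform path
    process(RegisterFunction) // falls back to the legacy handlers
  }
}
```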
```diff
@@ -3088,8 +3089,16 @@ class SparkConnectPlanner(
     executeHolder.eventsManager.postFinished()
   }
 
-  private def transformWriteOperation(
-      writeOperation: proto.WriteOperation,
+  /**
+   * Transforms the write operation.
+   *
+   * The input write operation contains a reference to the input plan and transforms it to the
+   * corresponding logical plan. Afterwards, creates the DataFrameWriter and translates the
+   * parameters of the WriteOperation into the corresponding methods calls.
+   *
+   * @param writeOperation
+   */
+  private def transformWriteOperation(writeOperation: proto.WriteOperation)(
       tracker: QueryPlanningTracker): LogicalPlan = {
     // Transform the input plan into the logical plan.
     val plan = transformRelation(writeOperation.getInput)
@@ -3148,41 +3157,27 @@ class SparkConnectPlanner(
     }
   }
 
-  private def runCommand(command: LogicalPlan, tracker: QueryPlanningTracker): Unit = {
-    val qe = new QueryExecution(session, command, tracker)
-    qe.assertCommandExecuted()
-  }
-
-  /**
-   * Transforms the write operation and executes it.
-   *
-   * The input write operation contains a reference to the input plan and transforms it to the
-   * corresponding logical plan. Afterwards, creates the DataFrameWriter and translates the
-   * parameters of the WriteOperation into the corresponding methods calls.
-   *
-   * @param writeOperation
-   */
-  private def handleWriteOperation(writeOperation: proto.WriteOperation): Unit = {
+  private def transformAndRunCommand(transformer: QueryPlanningTracker => LogicalPlan): Unit = {
     val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
-    runCommand(transformWriteOperation(writeOperation, tracker), tracker)
-
+    val qe = new QueryExecution(session, transformer(tracker), tracker)
+    qe.assertCommandExecuted()
     executeHolder.eventsManager.postFinished()
   }
 
   /**
-   * Transforms the write operation and executes it.
+   * Transforms the write operation.
    *
    * The input write operation contains a reference to the input plan and transforms it to the
    * corresponding logical plan. Afterwards, creates the DataFrameWriter and translates the
    * parameters of the WriteOperation into the corresponding methods calls.
    *
    * @param writeOperation
    */
-  private def handleWriteOperationV2(writeOperation: proto.WriteOperationV2): Unit = {
+  private def transformWriteOperationV2(writeOperation: proto.WriteOperationV2)(
+      tracker: QueryPlanningTracker): LogicalPlan = {
     // Transform the input plan into the logical plan.
     val plan = transformRelation(writeOperation.getInput)
     // And create a Dataset from the plan.
-    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
     val dataset = Dataset.ofRows(session, plan, tracker)
 
     val w = dataset.writeTo(table = writeOperation.getTableName)
```
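With this change there is exactly one place that creates the tracker, applies the transformer, and posts the finished event. A hedged sketch of that single execution site (toy types; the `println` calls stand in for `QueryExecution` and the events manager):

```scala
// Toy sketch of transformAndRunCommand: one site creates the tracker,
// forces the deferred plan, executes it, and reports completion.
object SingleExecutionSite {
  final case class Tracker(id: Long)
  final case class Plan(op: String)

  private var nextId = 0L
  def createTracker(): Tracker = { nextId += 1; Tracker(nextId) }

  def transformAndRunCommand(transformer: Tracker => Plan): Unit = {
    val tracker = createTracker()   // created once, here
    val plan = transformer(tracker) // transformation forced only at this point
    println(s"executing ${plan.op} [tracker ${tracker.id}]") // execute
    println("postFinished")         // completion event, also only here
  }

  def main(args: Array[String]): Unit =
    transformAndRunCommand(t => Plan(s"writeV2 planned on tracker ${t.id}"))
}
```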
```diff
@@ -3213,32 +3208,28 @@ class SparkConnectPlanner(
     writeOperation.getMode match {
       case proto.WriteOperationV2.Mode.MODE_CREATE =>
         if (writeOperation.hasProvider) {
-          w.using(writeOperation.getProvider).create()
-        } else {
-          w.create()
+          w.using(writeOperation.getProvider)
         }
+        w.createCommand()
       case proto.WriteOperationV2.Mode.MODE_OVERWRITE =>
-        w.overwrite(Column(transformExpression(writeOperation.getOverwriteCondition)))
+        w.overwriteCommand(Column(transformExpression(writeOperation.getOverwriteCondition)))
       case proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS =>
-        w.overwritePartitions()
+        w.overwritePartitionsCommand()
       case proto.WriteOperationV2.Mode.MODE_APPEND =>
-        w.append()
+        w.appendCommand()
       case proto.WriteOperationV2.Mode.MODE_REPLACE =>
         if (writeOperation.hasProvider) {
-          w.using(writeOperation.getProvider).replace()
-        } else {
-          w.replace()
+          w.using(writeOperation.getProvider)
         }
+        w.replaceCommand(orCreate = false)
       case proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE =>
         if (writeOperation.hasProvider) {
-          w.using(writeOperation.getProvider).createOrReplace()
-        } else {
-          w.createOrReplace()
+          w.using(writeOperation.getProvider)
         }
+        w.replaceCommand(orCreate = true)
       case other =>
         throw InvalidInputErrors.invalidEnum(other)
     }
-    executeHolder.eventsManager.postFinished()
   }
 
   private def handleWriteStreamOperationStart(
```
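After this hunk the `MODE_*` match is an expression that evaluates to a `LogicalPlan`: `w.using(...)` only configures the builder in place, while the final `*Command()` call is the pure, plan-producing step. A hedged toy sketch of that shape (not the real `DataFrameWriterV2` API):

```scala
// Toy builder showing a mode match that returns a plan instead of writing.
object ModeMatch {
  final case class Plan(op: String)

  final class WriterV2(table: String) {
    private var provider: Option[String] = None
    def using(p: String): this.type = { provider = Some(p); this } // mutation only
    def createCommand(): Plan =
      Plan(s"CREATE $table USING ${provider.getOrElse("<default>")}")
    def appendCommand(): Plan = Plan(s"APPEND $table")
  }

  def toPlan(mode: String, hasProvider: Boolean, w: WriterV2): Plan = mode match {
    case "CREATE" =>
      if (hasProvider) {
        w.using("parquet") // statement: configures the builder in place
      }
      w.createCommand()    // expression value of this branch
    case "APPEND" => w.appendCommand()
    case other    => throw new IllegalArgumentException(s"unknown mode: $other")
  }

  def main(args: Array[String]): Unit =
    println(toPlan("CREATE", hasProvider = true, new WriterV2("t1")))
}
```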

sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala

Lines changed: 32 additions & 16 deletions

```diff
@@ -148,14 +148,17 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
 
   /** @inheritdoc */
   override def create(): Unit = {
-    runCommand(
-      CreateTableAsSelect(
-        UnresolvedIdentifier(tableName),
-        partitioning.getOrElse(Seq.empty) ++ clustering,
-        logicalPlan,
-        buildTableSpec(),
-        options.toMap,
-        false))
+    runCommand(createCommand())
+  }
+
+  private[sql] def createCommand(): LogicalPlan = {
+    CreateTableAsSelect(
+      UnresolvedIdentifier(tableName),
+      partitioning.getOrElse(Seq.empty) ++ clustering,
+      logicalPlan,
+      buildTableSpec(),
+      options.toMap,
+      false)
   }
 
   private def buildTableSpec(): UnresolvedTableSpec = {
```
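The pattern introduced here, and repeated for `append`, `overwrite`, `overwritePartitions`, and `replace` below: each public action keeps its side-effecting contract but delegates plan construction to a `private[sql]` builder that callers inside `org.apache.spark.sql` (such as the Connect planner) can use without executing anything. A hedged toy sketch:

```scala
// Toy version of the action/builder split in DataFrameWriterV2.
object ActionBuilderSplit {
  final case class LogicalPlan(op: String)

  final class Writer(table: String) {
    // Public action: unchanged behavior, still executes.
    def create(): Unit = runCommand(createCommand())

    // Pure builder: exposes the plan without running it.
    def createCommand(): LogicalPlan = LogicalPlan(s"CTAS $table")

    private def runCommand(plan: LogicalPlan): Unit =
      println(s"running ${plan.op}")
  }

  def main(args: Array[String]): Unit = {
    val w = new Writer("t1")
    val planOnly = w.createCommand() // no side effect
    println(s"got ${planOnly.op} without executing")
    w.create()                       // executes as before
  }
}
```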
```diff
@@ -186,28 +189,37 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
   /** @inheritdoc */
   @throws(classOf[NoSuchTableException])
   def append(): Unit = {
-    val append = AppendData.byName(
+    runCommand(appendCommand())
+  }
+
+  private[sql] def appendCommand(): LogicalPlan = {
+    AppendData.byName(
       UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT)),
       logicalPlan, options.toMap)
-    runCommand(append)
   }
 
   /** @inheritdoc */
   @throws(classOf[NoSuchTableException])
   def overwrite(condition: Column): Unit = {
-    val overwrite = OverwriteByExpression.byName(
+    runCommand(overwriteCommand(condition))
+  }
+
+  private[sql] def overwriteCommand(condition: Column): LogicalPlan = {
+    OverwriteByExpression.byName(
       UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)),
       logicalPlan, expression(condition), options.toMap)
-    runCommand(overwrite)
   }
 
   /** @inheritdoc */
   @throws(classOf[NoSuchTableException])
   def overwritePartitions(): Unit = {
-    val dynamicOverwrite = OverwritePartitionsDynamic.byName(
+    runCommand(overwritePartitionsCommand())
+  }
+
+  private[sql] def overwritePartitionsCommand(): LogicalPlan = {
+    OverwritePartitionsDynamic.byName(
       UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)),
       logicalPlan, options.toMap)
-    runCommand(dynamicOverwrite)
   }
 
   /**
```
```diff
@@ -220,13 +232,17 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
   }
 
   private def internalReplace(orCreate: Boolean): Unit = {
-    runCommand(ReplaceTableAsSelect(
+    runCommand(replaceCommand(orCreate))
+  }
+
+  private[sql] def replaceCommand(orCreate: Boolean): LogicalPlan = {
+    ReplaceTableAsSelect(
       UnresolvedIdentifier(tableName),
       partitioning.getOrElse(Seq.empty) ++ clustering,
       logicalPlan,
       buildTableSpec(),
       writeOptions = options.toMap,
-      orCreate = orCreate))
+      orCreate = orCreate)
   }
 }
```
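One consequence worth noting, shown as a hedged toy sketch: because the public action and the Connect path now share one builder, the plan executed by `replace()` and the plan handed to the planner come from the same code and stay in lockstep by construction.

```scala
// Toy parity check: the action path executes exactly the plan the
// builder path produces, since the action delegates to the builder.
object BuilderParity {
  final case class Plan(op: String)

  final class Writer(table: String) {
    var lastExecuted: Option[Plan] = None
    def replace(): Unit = runCommand(replaceCommand(orCreate = false))
    def replaceCommand(orCreate: Boolean): Plan =
      Plan(s"REPLACE $table orCreate=$orCreate")
    private def runCommand(p: Plan): Unit = lastExecuted = Some(p)
  }

  def main(args: Array[String]): Unit = {
    val w = new Writer("t1")
    w.replace()
    assert(w.lastExecuted.contains(w.replaceCommand(orCreate = false)))
    println("action path and builder path produce the same plan")
  }
}
```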
