@@ -2645,12 +2645,12 @@ class SparkConnectPlanner(
2645
2645
process(command, new MockObserver ())
2646
2646
}
2647
2647
2648
- def transformCommand (
2649
- command : proto.Command ,
2650
- tracker : QueryPlanningTracker ): Option [LogicalPlan ] = {
2648
+ def transformCommand (command : proto.Command ): Option [QueryPlanningTracker => LogicalPlan ] = {
2651
2649
command.getCommandTypeCase match {
2652
2650
case proto.Command .CommandTypeCase .WRITE_OPERATION =>
2653
- Some (transformWriteOperation(command.getWriteOperation, tracker))
2651
+ Some (transformWriteOperation(command.getWriteOperation))
2652
+ case proto.Command .CommandTypeCase .WRITE_OPERATION_V2 =>
2653
+ Some (transformWriteOperationV2(command.getWriteOperationV2))
2654
2654
case _ =>
2655
2655
None
2656
2656
}
@@ -2659,19 +2659,20 @@ class SparkConnectPlanner(
2659
2659
def process (
2660
2660
command : proto.Command ,
2661
2661
responseObserver : StreamObserver [ExecutePlanResponse ]): Unit = {
2662
+ val transformerOpt = transformCommand(command)
2663
+ if (transformerOpt.isDefined) {
2664
+ transformAndRunCommand(transformerOpt.get)
2665
+ return
2666
+ }
2662
2667
command.getCommandTypeCase match {
2663
2668
case proto.Command .CommandTypeCase .REGISTER_FUNCTION =>
2664
2669
handleRegisterUserDefinedFunction(command.getRegisterFunction)
2665
2670
case proto.Command .CommandTypeCase .REGISTER_TABLE_FUNCTION =>
2666
2671
handleRegisterUserDefinedTableFunction(command.getRegisterTableFunction)
2667
2672
case proto.Command .CommandTypeCase .REGISTER_DATA_SOURCE =>
2668
2673
handleRegisterUserDefinedDataSource(command.getRegisterDataSource)
2669
- case proto.Command .CommandTypeCase .WRITE_OPERATION =>
2670
- handleWriteOperation(command.getWriteOperation)
2671
2674
case proto.Command .CommandTypeCase .CREATE_DATAFRAME_VIEW =>
2672
2675
handleCreateViewCommand(command.getCreateDataframeView)
2673
- case proto.Command .CommandTypeCase .WRITE_OPERATION_V2 =>
2674
- handleWriteOperationV2(command.getWriteOperationV2)
2675
2676
case proto.Command .CommandTypeCase .EXTENSION =>
2676
2677
handleCommandPlugin(command.getExtension)
2677
2678
case proto.Command .CommandTypeCase .SQL_COMMAND =>
@@ -3088,8 +3089,16 @@ class SparkConnectPlanner(
3088
3089
executeHolder.eventsManager.postFinished()
3089
3090
}
3090
3091
3091
- private def transformWriteOperation (
3092
- writeOperation : proto.WriteOperation ,
3092
+ /**
3093
+ * Transforms the write operation.
3094
+ *
3095
+ * The input write operation contains a reference to the input plan and transforms it to the
3096
+ * corresponding logical plan. Afterwards, creates the DataFrameWriter and translates the
3097
+ * parameters of the WriteOperation into the corresponding methods calls.
3098
+ *
3099
+ * @param writeOperation
3100
+ */
3101
+ private def transformWriteOperation (writeOperation : proto.WriteOperation )(
3093
3102
tracker : QueryPlanningTracker ): LogicalPlan = {
3094
3103
// Transform the input plan into the logical plan.
3095
3104
val plan = transformRelation(writeOperation.getInput)
@@ -3148,41 +3157,27 @@ class SparkConnectPlanner(
3148
3157
}
3149
3158
}
3150
3159
3151
- private def runCommand (command : LogicalPlan , tracker : QueryPlanningTracker ): Unit = {
3152
- val qe = new QueryExecution (session, command, tracker)
3153
- qe.assertCommandExecuted()
3154
- }
3155
-
3156
- /**
3157
- * Transforms the write operation and executes it.
3158
- *
3159
- * The input write operation contains a reference to the input plan and transforms it to the
3160
- * corresponding logical plan. Afterwards, creates the DataFrameWriter and translates the
3161
- * parameters of the WriteOperation into the corresponding methods calls.
3162
- *
3163
- * @param writeOperation
3164
- */
3165
- private def handleWriteOperation (writeOperation : proto.WriteOperation ): Unit = {
3160
+ private def transformAndRunCommand (transformer : QueryPlanningTracker => LogicalPlan ): Unit = {
3166
3161
val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
3167
- runCommand(transformWriteOperation(writeOperation, tracker), tracker)
3168
-
3162
+ val qe = new QueryExecution (session, transformer( tracker), tracker)
3163
+ qe.assertCommandExecuted()
3169
3164
executeHolder.eventsManager.postFinished()
3170
3165
}
3171
3166
3172
3167
/**
3173
- * Transforms the write operation and executes it .
3168
+ * Transforms the write operation.
3174
3169
*
3175
3170
* The input write operation contains a reference to the input plan and transforms it to the
3176
3171
* corresponding logical plan. Afterwards, creates the DataFrameWriter and translates the
3177
3172
* parameters of the WriteOperation into the corresponding methods calls.
3178
3173
*
3179
3174
* @param writeOperation
3180
3175
*/
3181
- private def handleWriteOperationV2 (writeOperation : proto.WriteOperationV2 ): Unit = {
3176
+ private def transformWriteOperationV2 (writeOperation : proto.WriteOperationV2 )(
3177
+ tracker : QueryPlanningTracker ): LogicalPlan = {
3182
3178
// Transform the input plan into the logical plan.
3183
3179
val plan = transformRelation(writeOperation.getInput)
3184
3180
// And create a Dataset from the plan.
3185
- val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
3186
3181
val dataset = Dataset .ofRows(session, plan, tracker)
3187
3182
3188
3183
val w = dataset.writeTo(table = writeOperation.getTableName)
@@ -3213,32 +3208,28 @@ class SparkConnectPlanner(
3213
3208
writeOperation.getMode match {
3214
3209
case proto.WriteOperationV2 .Mode .MODE_CREATE =>
3215
3210
if (writeOperation.hasProvider) {
3216
- w.using(writeOperation.getProvider).create()
3217
- } else {
3218
- w.create()
3211
+ w.using(writeOperation.getProvider)
3219
3212
}
3213
+ w.createCommand()
3220
3214
case proto.WriteOperationV2 .Mode .MODE_OVERWRITE =>
3221
- w.overwrite (Column (transformExpression(writeOperation.getOverwriteCondition)))
3215
+ w.overwriteCommand (Column (transformExpression(writeOperation.getOverwriteCondition)))
3222
3216
case proto.WriteOperationV2 .Mode .MODE_OVERWRITE_PARTITIONS =>
3223
- w.overwritePartitions ()
3217
+ w.overwritePartitionsCommand ()
3224
3218
case proto.WriteOperationV2 .Mode .MODE_APPEND =>
3225
- w.append ()
3219
+ w.appendCommand ()
3226
3220
case proto.WriteOperationV2 .Mode .MODE_REPLACE =>
3227
3221
if (writeOperation.hasProvider) {
3228
- w.using(writeOperation.getProvider).replace()
3229
- } else {
3230
- w.replace()
3222
+ w.using(writeOperation.getProvider)
3231
3223
}
3224
+ w.replaceCommand(orCreate = false )
3232
3225
case proto.WriteOperationV2 .Mode .MODE_CREATE_OR_REPLACE =>
3233
3226
if (writeOperation.hasProvider) {
3234
- w.using(writeOperation.getProvider).createOrReplace()
3235
- } else {
3236
- w.createOrReplace()
3227
+ w.using(writeOperation.getProvider)
3237
3228
}
3229
+ w.replaceCommand(orCreate = true )
3238
3230
case other =>
3239
3231
throw InvalidInputErrors .invalidEnum(other)
3240
3232
}
3241
- executeHolder.eventsManager.postFinished()
3242
3233
}
3243
3234
3244
3235
private def handleWriteStreamOperationStart (
0 commit comments