@@ -35,7 +35,7 @@ import org.apache.spark.{SparkClassNotFoundException, SparkEnv, SparkException,
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
+import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
 import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
 import org.apache.spark.connect.proto.Parse.ParseFormat
 import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
@@ -69,7 +69,7 @@ import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
 import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder, SparkConnectService}
 import org.apache.spark.sql.connect.utils.MetricGenerator
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.{QueryExecution, ShuffleCleanupMode}
 import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, TypedAggregateExpression}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.execution.command.{CreateViewCommand, ExternalCommandExecutor}
@@ -2651,6 +2651,8 @@ class SparkConnectPlanner(
         Some(transformWriteOperation(command.getWriteOperation))
       case proto.Command.CommandTypeCase.WRITE_OPERATION_V2 =>
         Some(transformWriteOperationV2(command.getWriteOperationV2))
+      case proto.Command.CommandTypeCase.SQL_COMMAND =>
+        Some(transformSqlCommand(command.getSqlCommand))
       case _ =>
         None
     }
@@ -2661,7 +2663,8 @@ class SparkConnectPlanner(
       responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     val transformerOpt = transformCommand(command)
     if (transformerOpt.isDefined) {
-      transformAndRunCommand(transformerOpt.get)
+      val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
+      runCommand(transformerOpt.get(tracker), tracker, responseObserver, command)
       return
     }
     command.getCommandTypeCase match {
@@ -2675,8 +2678,6 @@ class SparkConnectPlanner(
         handleCreateViewCommand(command.getCreateDataframeView)
       case proto.Command.CommandTypeCase.EXTENSION =>
         handleCommandPlugin(command.getExtension)
-      case proto.Command.CommandTypeCase.SQL_COMMAND =>
-        handleSqlCommand(command.getSqlCommand, responseObserver)
       case proto.Command.CommandTypeCase.WRITE_STREAM_OPERATION_START =>
         handleWriteStreamOperationStart(command.getWriteStreamOperationStart, responseObserver)
      case proto.Command.CommandTypeCase.STREAMING_QUERY_COMMAND =>
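
Note: with the two hunks above, SQL commands leave the side-effecting `handleCommand` match and go through `transformCommand`, which returns a plan builder for every command it understands. A minimal, self-contained sketch of that dispatch shape (all names below are illustrative stand-ins, not the Spark Connect API):

```scala
object DispatchSketch {
  sealed trait Command
  case object SqlCommand extends Command
  case object StreamingCommand extends Command

  final case class Tracker(id: Int)
  final case class LogicalPlan(desc: String)

  // Commands with a pure "transform" path return Some(plan builder).
  def transformCommand(c: Command): Option[Tracker => LogicalPlan] = c match {
    case SqlCommand => Some(t => LogicalPlan(s"sql plan (tracker ${t.id})"))
    case _          => None
  }

  // The handler only falls through to the legacy side-effecting handlers
  // when no transformer is available for the command type.
  def handleCommand(c: Command): String =
    transformCommand(c) match {
      case Some(builder) =>
        val tracker = Tracker(1) // created once, threaded into the builder
        s"ran: ${builder(tracker).desc}"
      case None =>
        "fell through to a legacy handler"
    }

  def main(args: Array[String]): Unit = {
    println(handleCommand(SqlCommand))       // transform path
    println(handleCommand(StreamingCommand)) // legacy path
  }
}
```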
@@ -2781,12 +2782,8 @@ class SparkConnectPlanner(
         .build())
   }

-  private def handleSqlCommand(
-      command: SqlCommand,
-      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
-    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
-
-    val relation = if (command.hasInput) {
+  private def getRelationFromSQLCommand(command: proto.SqlCommand): proto.Relation = {
+    if (command.hasInput) {
       command.getInput
     } else {
       // for backward compatibility
@@ -2803,15 +2800,33 @@ class SparkConnectPlanner(
             .build())
         .build()
     }
+  }
+
+  private def transformSqlCommand(command: proto.SqlCommand)(
+      tracker: QueryPlanningTracker): LogicalPlan = {
+    val relation = getRelationFromSQLCommand(command)

-    val df = relation.getRelTypeCase match {
+    relation.getRelTypeCase match {
       case proto.Relation.RelTypeCase.SQL =>
-        executeSQL(relation.getSql, tracker)
+        transformSQL(relation.getSql, tracker)
       case proto.Relation.RelTypeCase.WITH_RELATIONS =>
-        executeSQLWithRefs(relation.getWithRelations, tracker)
+        transformSQLWithRefs(relation.getWithRelations, tracker)
       case other =>
         throw InvalidInputErrors.sqlCommandExpectsSqlOrWithRelations(other)
     }
+  }
+
+  private def runSQLCommand(
+      command: LogicalPlan,
+      tracker: QueryPlanningTracker,
+      responseObserver: StreamObserver[ExecutePlanResponse],
+      protoCommand: proto.Command,
+      shuffleCleanupMode: Option[ShuffleCleanupMode]): Unit = {
+    val df = if (shuffleCleanupMode.isDefined) {
+      Dataset.ofRows(session, command, tracker, shuffleCleanupMode.get)
+    } else {
+      Dataset.ofRows(session, command, tracker)
+    }

     // Check if command or SQL Script has been executed.
     val isCommand = df.queryExecution.commandExecuted.isInstanceOf[CommandResult]
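
Note: the hunk above splits the old `handleSqlCommand` into a pure `transformSqlCommand` (plan building) and `runSQLCommand` (execution), threading an optional `ShuffleCleanupMode` through to `Dataset.ofRows`. A minimal sketch of that transform/run split, assuming illustrative local types rather than the Spark ones:

```scala
object TransformRunSketch {
  sealed trait ShuffleCleanupMode
  case object RemoveShuffleFiles extends ShuffleCleanupMode

  final case class LogicalPlan(sql: String)
  final case class DataFrame(plan: LogicalPlan, cleanup: Option[ShuffleCleanupMode])

  // "transform": build the logical plan; nothing is executed here.
  def transformSqlCommand(sql: String): LogicalPlan = LogicalPlan(sql)

  // "run": execute the plan, picking the overload that matches whether a
  // cleanup mode was supplied (mirroring the two Dataset.ofRows call shapes).
  def runSqlCommand(plan: LogicalPlan, mode: Option[ShuffleCleanupMode]): DataFrame =
    mode match {
      case Some(m) => ofRows(plan, m)
      case None    => ofRows(plan)
    }

  private def ofRows(plan: LogicalPlan): DataFrame = DataFrame(plan, None)
  private def ofRows(plan: LogicalPlan, m: ShuffleCleanupMode): DataFrame =
    DataFrame(plan, Some(m))

  def main(args: Array[String]): Unit = {
    println(runSqlCommand(transformSqlCommand("SELECT 1"), None))
    println(runSqlCommand(transformSqlCommand("SELECT 1"), Some(RemoveShuffleFiles)))
  }
}
```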
@@ -2867,7 +2882,7 @@ class SparkConnectPlanner(
     } else {
       // No execution triggered for relations. Manually set ready
       tracker.setReadyForExecution()
-      result.setRelation(relation)
+      result.setRelation(getRelationFromSQLCommand(protoCommand.getSqlCommand))
     }
     executeHolder.eventsManager.postFinished(Some(rows.size))
     // Exactly one SQL Command Result Batch
@@ -2909,9 +2924,9 @@ class SparkConnectPlanner(
         true
   }

-  private def executeSQLWithRefs(
+  private def transformSQLWithRefs(
       query: proto.WithRelations,
-      tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
+      tracker: QueryPlanningTracker = new QueryPlanningTracker): LogicalPlan = {
     if (!isValidSQLWithRefs(query)) {
       throw InvalidInputErrors.invalidSQLWithReferences(query)
     }
@@ -2925,7 +2940,7 @@ class SparkConnectPlanner(
           .ofRows(session, transformRelation(ref.getSubqueryAlias.getInput))
           .createOrReplaceTempView(ref.getSubqueryAlias.getAlias)
       }
-      executeSQL(sql, tracker)
+      transformSQL(sql, tracker)
     } finally {
       // drop all temporary views
       query.getReferencesList.asScala.foreach { ref =>
@@ -2935,36 +2950,48 @@ class SparkConnectPlanner(
     }
   }

-  private def executeSQL(
-      sql: proto.SQL,
+  private def executeSQLWithRefs(
+      query: proto.WithRelations,
       tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
+    Dataset.ofRows(session, transformSQLWithRefs(query, tracker), tracker)
+  }
+
+  private def transformSQL(
+      sql: proto.SQL,
+      tracker: QueryPlanningTracker = new QueryPlanningTracker): LogicalPlan = {
     // Eagerly execute commands of the provided SQL string.
     val args = sql.getArgsMap
     val namedArguments = sql.getNamedArgumentsMap
     val posArgs = sql.getPosArgsList
     val posArguments = sql.getPosArgumentsList
     if (!namedArguments.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
         sql.getQuery,
         namedArguments.asScala.toMap.transform((_, e) => Column(transformExpression(e))),
         tracker)
     } else if (!posArguments.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
         sql.getQuery,
         posArguments.asScala.map(e => Column(transformExpression(e))).toArray,
         tracker)
     } else if (!args.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
         sql.getQuery,
         args.asScala.toMap.transform((_, v) => transformLiteral(v)),
         tracker)
     } else if (!posArgs.isEmpty) {
-      session.sql(sql.getQuery, posArgs.asScala.map(transformLiteral).toArray, tracker)
+      session.sqlParsedPlan(sql.getQuery, posArgs.asScala.map(transformLiteral).toArray, tracker)
     } else {
-      session.sql(sql.getQuery, Map.empty[String, Any], tracker)
+      session.sqlParsedPlan(sql.getQuery, Map.empty[String, Any], tracker)
     }
   }

+  private def executeSQL(
+      sql: proto.SQL,
+      tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
+    Dataset.ofRows(session, transformSQL(sql, tracker), tracker)
+  }
+
   private def handleRegisterUserDefinedFunction(
       fun: proto.CommonInlineUserDefinedFunction): Unit = {
     fun.getFunctionCase match {
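
Note: in the hunk above, `session.sql` (which executes eagerly) is replaced by `session.sqlParsedPlan`, so the helpers now return a `LogicalPlan`, while `executeSQL` and `executeSQLWithRefs` survive as thin wrappers that materialize the parsed plan. A small sketch of that parse-versus-execute layering, using local stand-in types rather than the real `SparkSession` API:

```scala
object ParseVsExecuteSketch {
  final case class QueryPlanningTracker()
  final case class LogicalPlan(sql: String)
  final case class DataFrame(plan: LogicalPlan)

  // transformSQL only parses: it returns the plan and triggers no execution.
  def transformSQL(
      sql: String,
      tracker: QueryPlanningTracker = QueryPlanningTracker()): LogicalPlan =
    LogicalPlan(sql)

  // executeSQL is kept as a thin wrapper for callers that still want the
  // old eager behaviour: parse, then materialize the plan.
  def executeSQL(
      sql: String,
      tracker: QueryPlanningTracker = QueryPlanningTracker()): DataFrame =
    DataFrame(transformSQL(sql, tracker))

  def main(args: Array[String]): Unit = {
    println(transformSQL("SELECT 1")) // plan only
    println(executeSQL("SELECT 1"))   // plan wrapped in an executed result
  }
}
```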
@@ -3157,11 +3184,27 @@ class SparkConnectPlanner(
     }
   }

-  private def transformAndRunCommand(transformer: QueryPlanningTracker => LogicalPlan): Unit = {
-    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
-    val qe = new QueryExecution(session, transformer(tracker), tracker)
-    qe.assertCommandExecuted()
-    executeHolder.eventsManager.postFinished()
+  private[connect] def runCommand(
+      command: LogicalPlan,
+      tracker: QueryPlanningTracker,
+      responseObserver: StreamObserver[ExecutePlanResponse],
+      protoCommand: proto.Command,
+      shuffleCleanupMode: Option[ShuffleCleanupMode] = None): Unit = {
+    if (protoCommand.getCommandTypeCase == proto.Command.CommandTypeCase.SQL_COMMAND) {
+      runSQLCommand(command, tracker, responseObserver, protoCommand, shuffleCleanupMode)
+    } else {
+      val qe = if (shuffleCleanupMode.isDefined) {
+        new QueryExecution(
+          session,
+          command,
+          tracker = tracker,
+          shuffleCleanupMode = shuffleCleanupMode.get)
+      } else {
+        new QueryExecution(session, command, tracker = tracker)
+      }
+      qe.assertCommandExecuted()
+      executeHolder.eventsManager.postFinished()
+    }
   }

   /**
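
Note: `runCommand` replaces `transformAndRunCommand`: non-SQL commands keep the old eager `assertCommandExecuted` path, SQL commands are routed to the result-streaming `runSQLCommand`, and `shuffleCleanupMode` defaults to `None` so existing call sites are unaffected. A compact sketch of that branch structure, with hypothetical stand-in names:

```scala
object RunCommandSketch {
  sealed trait ProtoCommand
  case object SqlCommand extends ProtoCommand
  case object WriteCommand extends ProtoCommand

  // cleanup defaults to None so pre-existing call sites need no changes.
  def runCommand(cmd: ProtoCommand, cleanup: Option[String] = None): String =
    cmd match {
      case SqlCommand =>
        // SQL commands stream a result batch back to the client.
        s"SQL result streamed (cleanup=$cleanup)"
      case _ =>
        // All other commands execute eagerly; only completion is reported.
        s"command executed (cleanup=$cleanup)"
    }

  def main(args: Array[String]): Unit = {
    println(runCommand(SqlCommand))
    println(runCommand(WriteCommand, Some("removeShuffleFiles")))
  }
}
```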