@@ -35,7 +35,7 @@ import org.apache.spark.{SparkClassNotFoundException, SparkEnv, SparkException,
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
+import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
 import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
 import org.apache.spark.connect.proto.Parse.ParseFormat
 import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
@@ -69,7 +69,7 @@ import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
 import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder, SparkConnectService}
 import org.apache.spark.sql.connect.utils.MetricGenerator
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.{QueryExecution, ShuffleCleanupMode}
 import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, TypedAggregateExpression}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.execution.command.{CreateViewCommand, ExternalCommandExecutor}
@@ -343,13 +343,6 @@ class SparkConnectPlanner(
     }
   }
 
-  private def transformSqlWithRefs(query: proto.WithRelations): LogicalPlan = {
-    if (!isValidSQLWithRefs(query)) {
-      throw InvalidInputErrors.invalidSQLWithReferences(query)
-    }
-    executeSQLWithRefs(query).logicalPlan
-  }
-
   private def transformSubqueryAlias(alias: proto.SubqueryAlias): LogicalPlan = {
     val aliasIdentifier =
       if (alias.getQualifierCount > 0) {
@@ -2651,6 +2644,8 @@ class SparkConnectPlanner(
         Some(transformWriteOperation(command.getWriteOperation))
       case proto.Command.CommandTypeCase.WRITE_OPERATION_V2 =>
         Some(transformWriteOperationV2(command.getWriteOperationV2))
+      case proto.Command.CommandTypeCase.SQL_COMMAND =>
+        Some(transformSqlCommand(command.getSqlCommand))
       case _ =>
         None
     }
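Note on the hunk above: `transformSqlCommand` is declared with two parameter lists (see the `+2798` hunk further down), so applying only the `SqlCommand` here eta-expands into the `QueryPlanningTracker => LogicalPlan` that `transformCommand` returns; the caller supplies the tracker later. A minimal, self-contained sketch of that currying shape, using stand-in `Tracker`/`Plan` types rather than Spark's real classes:

```scala
// Self-contained sketch of the curried-transformer pattern; Tracker and
// Plan are stand-ins for QueryPlanningTracker and LogicalPlan.
object CurriedTransformerSketch {
  final case class Tracker(id: String)
  final case class Plan(text: String)

  // Mirrors transformSqlCommand(command: proto.SqlCommand)(tracker: QueryPlanningTracker).
  def transformSqlCommand(sql: String)(tracker: Tracker): Plan =
    Plan(s"plan for [$sql] under tracker ${tracker.id}")

  // Mirrors transformCommand returning Option[QueryPlanningTracker => LogicalPlan]:
  // only the first parameter list is applied here.
  def transformCommand(sql: Option[String]): Option[Tracker => Plan] =
    sql.map(s => (tracker: Tracker) => transformSqlCommand(s)(tracker))

  def main(args: Array[String]): Unit = {
    // The caller creates the tracker and applies the transformer,
    // as process() does in the next hunk.
    transformCommand(Some("SELECT 1")).foreach { transformer =>
      println(transformer(Tracker("t-1")))
    }
  }
}
```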
@@ -2661,7 +2656,8 @@ class SparkConnectPlanner(
       responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     val transformerOpt = transformCommand(command)
     if (transformerOpt.isDefined) {
-      transformAndRunCommand(transformerOpt.get)
+      val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
+      runCommand(transformerOpt.get(tracker), tracker, responseObserver, command)
       return
     }
     command.getCommandTypeCase match {
@@ -2675,8 +2671,6 @@ class SparkConnectPlanner(
         handleCreateViewCommand(command.getCreateDataframeView)
       case proto.Command.CommandTypeCase.EXTENSION =>
         handleCommandPlugin(command.getExtension)
-      case proto.Command.CommandTypeCase.SQL_COMMAND =>
-        handleSqlCommand(command.getSqlCommand, responseObserver)
       case proto.Command.CommandTypeCase.WRITE_STREAM_OPERATION_START =>
         handleWriteStreamOperationStart(command.getWriteStreamOperationStart, responseObserver)
       case proto.Command.CommandTypeCase.STREAMING_QUERY_COMMAND =>
@@ -2781,12 +2775,8 @@ class SparkConnectPlanner(
         .build())
   }
 
-  private def handleSqlCommand(
-      command: SqlCommand,
-      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
-    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
-
-    val relation = if (command.hasInput) {
+  private def getRelationFromSQLCommand(command: proto.SqlCommand): proto.Relation = {
+    if (command.hasInput) {
       command.getInput
     } else {
       // for backward compatibility
@@ -2803,15 +2793,33 @@
           .build())
         .build()
     }
+  }
+
+  private def transformSqlCommand(command: proto.SqlCommand)(
+      tracker: QueryPlanningTracker): LogicalPlan = {
+    val relation = getRelationFromSQLCommand(command)
 
-    val df = relation.getRelTypeCase match {
+    relation.getRelTypeCase match {
       case proto.Relation.RelTypeCase.SQL =>
-        executeSQL(relation.getSql, tracker)
+        transformSQL(relation.getSql, tracker)
       case proto.Relation.RelTypeCase.WITH_RELATIONS =>
-        executeSQLWithRefs(relation.getWithRelations, tracker)
+        transformSQLWithRefs(relation.getWithRelations, tracker)
       case other =>
         throw InvalidInputErrors.sqlCommandExpectsSqlOrWithRelations(other)
     }
+  }
+
+  private def runSQLCommand(
+      command: LogicalPlan,
+      tracker: QueryPlanningTracker,
+      responseObserver: StreamObserver[ExecutePlanResponse],
+      protoCommand: proto.Command,
+      shuffleCleanupMode: Option[ShuffleCleanupMode]): Unit = {
+    val df = if (shuffleCleanupMode.isDefined) {
+      Dataset.ofRows(session, command, tracker, shuffleCleanupMode.get)
+    } else {
+      Dataset.ofRows(session, command, tracker)
+    }
 
     // Check if command or SQL Script has been executed.
     val isCommand = df.queryExecution.commandExecuted.isInstanceOf[CommandResult]
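The `Option[ShuffleCleanupMode]` parameter above selects between two `Dataset.ofRows` overloads rather than forcing every caller to pick a default cleanup mode. A standalone sketch of that dispatch shape, with stand-in types (not Spark's `Dataset` or `ShuffleCleanupMode`):

```scala
// Self-contained sketch of the Option-driven overload dispatch; these
// types are stand-ins, not Spark's Dataset or ShuffleCleanupMode.
object OptionalModeDispatchSketch {
  sealed trait ShuffleCleanupMode
  case object RemoveShuffleFiles extends ShuffleCleanupMode

  final case class DataFrame(desc: String)

  // Two overloads, as with Dataset.ofRows with and without a cleanup mode.
  def ofRows(plan: String): DataFrame = DataFrame(s"$plan / default cleanup")
  def ofRows(plan: String, mode: ShuffleCleanupMode): DataFrame =
    DataFrame(s"$plan / cleanup=$mode")

  // Mirrors the shape of runSQLCommand's df construction above.
  def run(plan: String, mode: Option[ShuffleCleanupMode]): DataFrame =
    if (mode.isDefined) ofRows(plan, mode.get) else ofRows(plan)

  def main(args: Array[String]): Unit = {
    println(run("Project [*]", None))                     // default cleanup
    println(run("Project [*]", Some(RemoveShuffleFiles))) // explicit mode
  }
}
```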
@@ -2867,7 +2875,7 @@ class SparkConnectPlanner(
     } else {
       // No execution triggered for relations. Manually set ready
       tracker.setReadyForExecution()
-      result.setRelation(relation)
+      result.setRelation(getRelationFromSQLCommand(protoCommand.getSqlCommand))
     }
     executeHolder.eventsManager.postFinished(Some(rows.size))
     // Exactly one SQL Command Result Batch
@@ -2909,9 +2917,9 @@
     true
   }
 
-  private def executeSQLWithRefs(
+  private def transformSQLWithRefs(
       query: proto.WithRelations,
-      tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
+      tracker: QueryPlanningTracker): LogicalPlan = {
     if (!isValidSQLWithRefs(query)) {
       throw InvalidInputErrors.invalidSQLWithReferences(query)
     }
@@ -2925,7 +2933,7 @@
           .ofRows(session, transformRelation(ref.getSubqueryAlias.getInput))
           .createOrReplaceTempView(ref.getSubqueryAlias.getAlias)
       }
-      executeSQL(sql, tracker)
+      transformSQL(sql, tracker)
     } finally {
       // drop all temporary views
       query.getReferencesList.asScala.foreach { ref =>
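For context, the surrounding method registers each reference as a temp view, runs the SQL, and drops the views in `finally` so a failing query cannot leak them. A self-contained sketch of that register/run/cleanup pattern (illustrative names only; the map below stands in for the session's temp-view registry):

```scala
// Self-contained sketch of the register/run/cleanup pattern around
// SQL-with-references; "views" stands in for the temp-view registry.
object TempViewScopeSketch {
  import scala.collection.mutable

  private val views = mutable.Map.empty[String, Seq[Int]]

  def createOrReplaceTempView(name: String, rows: Seq[Int]): Unit =
    views(name) = rows

  def dropTempView(name: String): Unit = {
    views.remove(name)
  }

  // Registers the references, runs the query, and always drops the views,
  // mirroring the try/finally around transformSQL above.
  def runWithRefs(refs: Map[String, Seq[Int]])(query: => Int): Int = {
    refs.foreach { case (name, rows) => createOrReplaceTempView(name, rows) }
    try query
    finally refs.keys.foreach(dropTempView)
  }

  def main(args: Array[String]): Unit = {
    val total = runWithRefs(Map("t" -> Seq(1, 2, 3))) {
      views("t").sum // the "query" reads the registered view
    }
    println(s"total = $total, views left = ${views.size}") // total = 6, views left = 0
  }
}
```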
@@ -2935,36 +2943,46 @@
     }
   }
 
-  private def executeSQL(
-      sql: proto.SQL,
+  private def executeSQLWithRefs(
+      query: proto.WithRelations,
       tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
+    Dataset.ofRows(session, transformSQLWithRefs(query, tracker), tracker)
+  }
+
+  private def transformSQL(sql: proto.SQL, tracker: QueryPlanningTracker): LogicalPlan = {
     // Eagerly execute commands of the provided SQL string.
     val args = sql.getArgsMap
     val namedArguments = sql.getNamedArgumentsMap
     val posArgs = sql.getPosArgsList
     val posArguments = sql.getPosArgumentsList
     if (!namedArguments.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
         sql.getQuery,
         namedArguments.asScala.toMap.transform((_, e) => Column(transformExpression(e))),
         tracker)
     } else if (!posArguments.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
         sql.getQuery,
         posArguments.asScala.map(e => Column(transformExpression(e))).toArray,
         tracker)
     } else if (!args.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
         sql.getQuery,
         args.asScala.toMap.transform((_, v) => transformLiteral(v)),
         tracker)
     } else if (!posArgs.isEmpty) {
-      session.sql(sql.getQuery, posArgs.asScala.map(transformLiteral).toArray, tracker)
+      session.sqlParsedPlan(sql.getQuery, posArgs.asScala.map(transformLiteral).toArray, tracker)
     } else {
-      session.sql(sql.getQuery, Map.empty[String, Any], tracker)
+      session.sqlParsedPlan(sql.getQuery, Map.empty[String, Any], tracker)
     }
   }
 
+  private def executeSQL(
+      sql: proto.SQL,
+      tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
+    Dataset.ofRows(session, transformSQL(sql, tracker), tracker)
+  }
+
   private def handleRegisterUserDefinedFunction(
       fun: proto.CommonInlineUserDefinedFunction): Unit = {
     fun.getFunctionCase match {
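The renames above split parsing from execution: `transformSQL`/`transformSQLWithRefs` now return only the parsed `LogicalPlan` (via the `session.sqlParsedPlan` variant used in this diff), while thin `executeSQL`/`executeSQLWithRefs` wrappers still materialize a `DataFrame` for callers that need eager execution. A minimal sketch of that layering with stand-in types:

```scala
// Self-contained sketch of the transform/execute split; LogicalPlan and
// DataFrame here are stand-ins for the Spark classes.
object TransformExecuteSplitSketch {
  final case class LogicalPlan(sql: String)
  final case class DataFrame(plan: LogicalPlan) {
    def run(): String = s"executed: ${plan.sql}"
  }

  // Pure transform step: parse only, no execution (like transformSQL).
  def transformSQL(sql: String): LogicalPlan = LogicalPlan(sql.trim)

  // Thin execution wrapper over the transform step (like executeSQL).
  def executeSQL(sql: String): DataFrame = DataFrame(transformSQL(sql))

  def main(args: Array[String]): Unit = {
    val planOnly = transformSQL(" SELECT 1 ")   // command path: plan, no run
    val executed = executeSQL("SELECT 1").run() // relation path: run now
    println(s"$planOnly; $executed")
  }
}
```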
@@ -3157,11 +3175,27 @@
     }
   }
 
-  private def transformAndRunCommand(transformer: QueryPlanningTracker => LogicalPlan): Unit = {
-    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
-    val qe = new QueryExecution(session, transformer(tracker), tracker)
-    qe.assertCommandExecuted()
-    executeHolder.eventsManager.postFinished()
+  private[connect] def runCommand(
+      command: LogicalPlan,
+      tracker: QueryPlanningTracker,
+      responseObserver: StreamObserver[ExecutePlanResponse],
+      protoCommand: proto.Command,
+      shuffleCleanupMode: Option[ShuffleCleanupMode] = None): Unit = {
+    if (protoCommand.getCommandTypeCase == proto.Command.CommandTypeCase.SQL_COMMAND) {
+      runSQLCommand(command, tracker, responseObserver, protoCommand, shuffleCleanupMode)
+    } else {
+      val qe = if (shuffleCleanupMode.isDefined) {
+        new QueryExecution(
+          session,
+          command,
+          tracker = tracker,
+          shuffleCleanupMode = shuffleCleanupMode.get)
+      } else {
+        new QueryExecution(session, command, tracker = tracker)
+      }
+      qe.assertCommandExecuted()
+      executeHolder.eventsManager.postFinished()
+    }
   }
 
   /**
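`runCommand` above is now the single entry point: SQL commands go to `runSQLCommand`, which streams a `SqlCommandResult` back to the client, while every other command is executed eagerly via `assertCommandExecuted`. A standalone sketch of that dispatch shape (all types below are illustrative stand-ins, not the Spark Connect protos):

```scala
// Self-contained sketch of runCommand's dispatch shape.
object RunCommandDispatchSketch {
  sealed trait Command
  final case class SqlCommand(query: String) extends Command
  final case class WriteCommand(target: String) extends Command

  // SQL commands produce a result batch for the client (like runSQLCommand).
  def runSQLCommand(c: SqlCommand): String =
    s"SqlCommandResult for [${c.query}]"

  // Everything else is executed eagerly for its side effects
  // (like qe.assertCommandExecuted()).
  def assertCommandExecuted(c: Command): String =
    s"eagerly executed $c"

  def runCommand(c: Command): String = c match {
    case sql: SqlCommand => runSQLCommand(sql)
    case other => assertCommandExecuted(other)
  }

  def main(args: Array[String]): Unit = {
    println(runCommand(SqlCommand("SELECT 1")))
    println(runCommand(WriteCommand("table1")))
  }
}
```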
@@ -4105,7 +4139,7 @@ class SparkConnectPlanner(
 
   private def transformWithRelations(getWithRelations: proto.WithRelations): LogicalPlan = {
     if (isValidSQLWithRefs(getWithRelations)) {
-      transformSqlWithRefs(getWithRelations)
+      executeSQLWithRefs(getWithRelations).logicalPlan
     } else {
       // Wrap the plan to keep the original planId.
       val plan = Project(Seq(UnresolvedStar(None)), transformRelation(getWithRelations.getRoot))