@@ -34,7 +34,7 @@ import org.apache.spark.{SparkClassNotFoundException, SparkEnv, SparkException,
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
+import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
 import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
 import org.apache.spark.connect.proto.Parse.ParseFormat
 import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
@@ -68,7 +68,7 @@ import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
 import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder, SparkConnectService}
 import org.apache.spark.sql.connect.utils.MetricGenerator
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.{QueryExecution, ShuffleCleanupMode}
 import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, TypedAggregateExpression}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.execution.command.{CreateViewCommand, ExternalCommandExecutor}
@@ -342,13 +342,6 @@ class SparkConnectPlanner(
     }
   }
 
-  private def transformSqlWithRefs(query: proto.WithRelations): LogicalPlan = {
-    if (!isValidSQLWithRefs(query)) {
-      throw InvalidInputErrors.invalidSQLWithReferences(query)
-    }
-    executeSQLWithRefs(query).logicalPlan
-  }
-
   private def transformSubqueryAlias(alias: proto.SubqueryAlias): LogicalPlan = {
     val aliasIdentifier =
       if (alias.getQualifierCount > 0) {
@@ -2650,6 +2643,8 @@ class SparkConnectPlanner(
         Some(transformWriteOperation(command.getWriteOperation))
       case proto.Command.CommandTypeCase.WRITE_OPERATION_V2 =>
         Some(transformWriteOperationV2(command.getWriteOperationV2))
+      case proto.Command.CommandTypeCase.SQL_COMMAND =>
+        Some(transformSqlCommand(command.getSqlCommand))
       case _ =>
         None
     }
@@ -2660,7 +2655,8 @@ class SparkConnectPlanner(
       responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     val transformerOpt = transformCommand(command)
     if (transformerOpt.isDefined) {
-      transformAndRunCommand(transformerOpt.get)
+      val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
+      runCommand(transformerOpt.get(tracker), tracker, responseObserver, command)
       return
     }
     command.getCommandTypeCase match {
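
Note on the dispatch above: `transformCommand` returns an `Option[QueryPlanningTracker => LogicalPlan]`, a curried transformer, so the caller creates the tracker once and threads it through planning before execution. A minimal sketch of that shape, with hypothetical stand-in types (`Tracker`, `Plan`, and `dispatch` are illustrative, not Spark Connect APIs):

    // Hypothetical stand-ins for the curried-transformer shape used above.
    final case class Tracker(id: Long)
    final case class Plan(description: String)

    // Like transformCommand: decide whether the command can be planned here,
    // but defer the actual transformation until a tracker is available.
    def transform(command: String): Option[Tracker => Plan] =
      if (command.startsWith("sql:")) Some(t => Plan(s"${command.drop(4)} [$t]"))
      else None

    def dispatch(command: String): Unit =
      transform(command) match {
        case Some(transformer) =>
          val tracker = Tracker(id = 1L)  // created once per execution
          println(transformer(tracker))   // tracker threaded through planning
        case None =>
          println(s"falling back to side-effecting handlers for: $command")
      }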
@@ -2674,8 +2670,6 @@ class SparkConnectPlanner(
         handleCreateViewCommand(command.getCreateDataframeView)
       case proto.Command.CommandTypeCase.EXTENSION =>
         handleCommandPlugin(command.getExtension)
-      case proto.Command.CommandTypeCase.SQL_COMMAND =>
-        handleSqlCommand(command.getSqlCommand, responseObserver)
       case proto.Command.CommandTypeCase.WRITE_STREAM_OPERATION_START =>
         handleWriteStreamOperationStart(command.getWriteStreamOperationStart, responseObserver)
       case proto.Command.CommandTypeCase.STREAMING_QUERY_COMMAND =>
@@ -2780,12 +2774,8 @@ class SparkConnectPlanner(
         .build())
   }
 
-  private def handleSqlCommand(
-      command: SqlCommand,
-      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
-    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
-
-    val relation = if (command.hasInput) {
+  private def getRelationFromSQLCommand(command: proto.SqlCommand): proto.Relation = {
+    if (command.hasInput) {
       command.getInput
     } else {
       // for backward compatibility
@@ -2802,19 +2792,47 @@ class SparkConnectPlanner(
         .build())
         .build()
     }
+  }
 
-    val df = relation.getRelTypeCase match {
+  private def transformSqlCommand(command: proto.SqlCommand)(
+      tracker: QueryPlanningTracker): LogicalPlan = {
+    val relation = getRelationFromSQLCommand(command)
+
+    relation.getRelTypeCase match {
       case proto.Relation.RelTypeCase.SQL =>
-        executeSQL(relation.getSql, tracker)
+        transformSQL(relation.getSql, tracker)
       case proto.Relation.RelTypeCase.WITH_RELATIONS =>
-        executeSQLWithRefs(relation.getWithRelations, tracker)
+        transformSQLWithRefs(relation.getWithRelations, tracker)
       case other =>
         throw InvalidInputErrors.sqlCommandExpectsSqlOrWithRelations(other)
     }
+  }
+
+  private def runSQLCommand(
+      command: LogicalPlan,
+      tracker: QueryPlanningTracker,
+      responseObserver: StreamObserver[ExecutePlanResponse],
+      protoSQLCommand: proto.SqlCommand,
+      shuffleCleanupMode: Option[ShuffleCleanupMode]): Unit = {
+    val isSqlScript = command.isInstanceOf[CompoundBody]
+    val refs = if (isSqlScript && protoSQLCommand.getInput.hasWithRelations) {
+      protoSQLCommand.getInput.getWithRelations.getReferencesList.asScala
+        .map(_.getSubqueryAlias)
+        .toSeq
+    } else {
+      Seq.empty
+    }
+
+    val df = runWithRefs(refs) {
+      if (shuffleCleanupMode.isDefined) {
+        Dataset.ofRows(session, command, tracker, shuffleCleanupMode.get)
+      } else {
+        Dataset.ofRows(session, command, tracker)
+      }
+    }
 
     // Check if command or SQL Script has been executed.
     val isCommand = df.queryExecution.commandExecuted.isInstanceOf[CommandResult]
-    val isSqlScript = df.queryExecution.logical.isInstanceOf[CompoundBody]
     val rows = df.logicalPlan match {
       case lr: LocalRelation => lr.data
       case cr: CommandResult => cr.rows
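
The `Dataset.ofRows` call above runs inside `runWithRefs` (defined later in this diff), which makes a SQL script's referenced relations visible as temp views only for the duration of one block. A standalone sketch of that loan pattern, with a hypothetical in-memory `Catalog` standing in for `session.catalog`:

    import scala.collection.mutable

    // Hypothetical in-memory catalog standing in for session.catalog.
    final class Catalog {
      private val views = mutable.Map.empty[String, String]
      def createOrReplaceTempView(alias: String, plan: String): Unit = views(alias) = plan
      def dropTempView(alias: String): Unit = views.remove(alias)
    }

    // Loan pattern: the temp views exist only while `f` runs, and they are
    // dropped in the finally block even if `f` throws.
    def withTempViews[T](catalog: Catalog, refs: Seq[(String, String)])(f: => T): T = {
      if (refs.isEmpty) return f
      try {
        refs.foreach { case (alias, plan) => catalog.createOrReplaceTempView(alias, plan) }
        f
      } finally {
        refs.foreach { case (alias, _) => catalog.dropTempView(alias) }
      }
    }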
@@ -2866,7 +2884,7 @@ class SparkConnectPlanner(
     } else {
       // No execution triggered for relations. Manually set ready
       tracker.setReadyForExecution()
-      result.setRelation(relation)
+      result.setRelation(getRelationFromSQLCommand(protoSQLCommand))
     }
     executeHolder.eventsManager.postFinished(Some(rows.size))
     // Exactly one SQL Command Result Batch
@@ -2908,59 +2926,83 @@ class SparkConnectPlanner(
     true
   }
 
-  private def executeSQLWithRefs(
-      query: proto.WithRelations,
-      tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
-    if (!isValidSQLWithRefs(query)) {
-      throw InvalidInputErrors.invalidSQLWithReferences(query)
+  private def runWithRefs[T](refs: Seq[proto.SubqueryAlias])(f: => T): T = {
+    if (refs.isEmpty) {
+      return f
     }
-
-    // Eagerly execute commands of the provided SQL string, with given references.
-    val sql = query.getRoot.getSql
     this.synchronized {
       try {
-        query.getReferencesList.asScala.foreach { ref =>
+        refs.foreach { ref =>
           Dataset
-            .ofRows(session, transformRelation(ref.getSubqueryAlias.getInput))
-            .createOrReplaceTempView(ref.getSubqueryAlias.getAlias)
+            .ofRows(session, transformRelation(ref.getInput))
+            .createOrReplaceTempView(ref.getAlias)
         }
-        executeSQL(sql, tracker)
+        f
       } finally {
         // drop all temporary views
-        query.getReferencesList.asScala.foreach { ref =>
-          session.catalog.dropTempView(ref.getSubqueryAlias.getAlias)
+        refs.foreach { ref =>
+          session.catalog.dropTempView(ref.getAlias)
         }
       }
     }
   }
 
-  private def executeSQL(
-      sql: proto.SQL,
+  private def transformSQLWithRefs(
+      query: proto.WithRelations,
+      tracker: QueryPlanningTracker): LogicalPlan = {
+    if (!isValidSQLWithRefs(query)) {
+      throw InvalidInputErrors.invalidSQLWithReferences(query)
+    }
+
+    transformSQL(
+      query.getRoot.getSql,
+      tracker,
+      query.getReferencesList.asScala.map(_.getSubqueryAlias).toSeq)
+  }
+
+  private def executeSQLWithRefs(
+      query: proto.WithRelations,
       tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
     // Eagerly execute commands of the provided SQL string.
+    Dataset.ofRows(session, transformSQLWithRefs(query, tracker), tracker)
+  }
+
+  private def transformSQL(
+      sql: proto.SQL,
+      tracker: QueryPlanningTracker,
+      refsToAnalyze: Seq[proto.SubqueryAlias] = Seq.empty): LogicalPlan = {
     val args = sql.getArgsMap
     val namedArguments = sql.getNamedArgumentsMap
     val posArgs = sql.getPosArgsList
     val posArguments = sql.getPosArgumentsList
-    if (!namedArguments.isEmpty) {
-      session.sql(
+    val parsedPlan = if (!namedArguments.isEmpty) {
+      session.sqlParsedPlan(
         sql.getQuery,
         namedArguments.asScala.toMap.transform((_, e) => Column(transformExpression(e))),
         tracker)
     } else if (!posArguments.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
         sql.getQuery,
         posArguments.asScala.map(e => Column(transformExpression(e))).toArray,
         tracker)
     } else if (!args.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
         sql.getQuery,
         args.asScala.toMap.transform((_, v) => transformLiteral(v)),
         tracker)
     } else if (!posArgs.isEmpty) {
-      session.sql(sql.getQuery, posArgs.asScala.map(transformLiteral).toArray, tracker)
+      session.sqlParsedPlan(sql.getQuery, posArgs.asScala.map(transformLiteral).toArray, tracker)
     } else {
-      session.sql(sql.getQuery, Map.empty[String, Any], tracker)
+      session.sqlParsedPlan(sql.getQuery, Map.empty[String, Any], tracker)
+    }
+    if (parsedPlan.isInstanceOf[CompoundBody]) {
+      // If the parsed plan is a CompoundBody, skip analysis and return it.
+      // SQL scripting is a special case as execution occurs during the analysis phase.
+      parsedPlan
+    } else {
+      runWithRefs(refsToAnalyze) {
+        new QueryExecution(session, parsedPlan, tracker).analyzed
+      }
     }
   }
 
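
`transformSQL` above replaces the old eager `session.sql` calls with `sqlParsedPlan` plus an explicit analysis step, and deliberately skips analysis for `CompoundBody`: a SQL script executes during the analysis phase, so analyzing it here would run it too early. A compact sketch of that branch, with placeholder types (none of these are the real Spark classes):

    // Placeholder plan types; SqlScript mirrors CompoundBody.
    sealed trait Plan
    case object SqlScript extends Plan
    final case class Statement(sql: String) extends Plan
    final case class AnalyzedPlan(of: Plan) extends Plan

    // Analyze only plain statements; scripts are returned unanalyzed
    // because analyzing them would trigger execution.
    def finishPlanning(parsed: Plan): Plan = parsed match {
      case SqlScript => parsed
      case other     => AnalyzedPlan(other)
    }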
@@ -3156,11 +3198,32 @@ class SparkConnectPlanner(
     }
   }
 
-  private def transformAndRunCommand(transformer: QueryPlanningTracker => LogicalPlan): Unit = {
-    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
-    val qe = new QueryExecution(session, transformer(tracker), tracker)
-    qe.assertCommandExecuted()
-    executeHolder.eventsManager.postFinished()
+  private[connect] def runCommand(
+      command: LogicalPlan,
+      tracker: QueryPlanningTracker,
+      responseObserver: StreamObserver[ExecutePlanResponse],
+      protoCommand: proto.Command,
+      shuffleCleanupMode: Option[ShuffleCleanupMode] = None): Unit = {
+    if (protoCommand.getCommandTypeCase == proto.Command.CommandTypeCase.SQL_COMMAND) {
+      runSQLCommand(
+        command,
+        tracker,
+        responseObserver,
+        protoCommand.getSqlCommand,
+        shuffleCleanupMode)
+    } else {
+      val qe = if (shuffleCleanupMode.isDefined) {
+        new QueryExecution(
+          session,
+          command,
+          tracker = tracker,
+          shuffleCleanupMode = shuffleCleanupMode.get)
+      } else {
+        new QueryExecution(session, command, tracker = tracker)
+      }
+      qe.assertCommandExecuted()
+      executeHolder.eventsManager.postFinished()
+    }
   }
 
   /**
@@ -4104,7 +4167,7 @@ class SparkConnectPlanner(
 
   private def transformWithRelations(getWithRelations: proto.WithRelations): LogicalPlan = {
     if (isValidSQLWithRefs(getWithRelations)) {
-      transformSqlWithRefs(getWithRelations)
+      executeSQLWithRefs(getWithRelations).logicalPlan
     } else {
       // Wrap the plan to keep the original planId.
       val plan = Project(Seq(UnresolvedStar(None)), transformRelation(getWithRelations.getRoot))
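
Taken together, the new flow is: `transformCommand` produces the plan once, and `runCommand` (two hunks up) either hands SQL commands to `runSQLCommand` or builds a `QueryExecution` directly, passing `shuffleCleanupMode` through only when the caller supplied one. A small sketch of that optional-argument forwarding, with hypothetical `Mode` and `Exec` types standing in for `ShuffleCleanupMode` and `QueryExecution`:

    // Hypothetical stand-ins for ShuffleCleanupMode and QueryExecution.
    sealed trait Mode
    case object KeepFiles extends Mode
    case object RemoveFiles extends Mode

    final class Exec(plan: String, cleanup: Mode = KeepFiles) {
      def assertCommandExecuted(): Unit = println(s"executed $plan, cleanup=$cleanup")
    }

    // Branch on the option instead of unwrapping it, so Exec's own default
    // applies when no mode is given; this mirrors why runCommand constructs
    // QueryExecution in two branches.
    def runCommand(plan: String, mode: Option[Mode] = None): Unit = {
      val exec = mode match {
        case Some(m) => new Exec(plan, cleanup = m)
        case None    => new Exec(plan)
      }
      exec.assertCommandExecuted()
    }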