Commit c313fb7

[SPARK-53148][CONNECT][SQL] Make SqlCommand in SparkConnectPlanner side effect free

1 parent: ea8b6fd

File tree: 3 files changed, +120 −67 lines
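At a glance: the commit makes the SQL command path in SparkConnectPlanner side effect free by splitting it into a pure transform step (transformSqlCommand, which only builds a LogicalPlan) and an explicit execution step (runCommand / runSQLCommand), and it adds private[sql] sqlParsedPlan helpers to SparkSession so SQL text can be parsed and parameter-bound without triggering execution.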

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala

Lines changed: 7 additions & 7 deletions
```diff
@@ -35,7 +35,7 @@ import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_
 import org.apache.spark.sql.connect.planner.{InvalidInputErrors, SparkConnectPlanner}
 import org.apache.spark.sql.connect.service.ExecuteHolder
 import org.apache.spark.sql.connect.utils.MetricGenerator
-import org.apache.spark.sql.execution.{DoNotCleanup, LocalTableScanExec, QueryExecution, RemoveShuffleFiles, SkipMigration, SQLExecution}
+import org.apache.spark.sql.execution.{DoNotCleanup, LocalTableScanExec, RemoveShuffleFiles, SkipMigration, SQLExecution}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
@@ -83,13 +83,13 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
     val command = request.getPlan.getCommand
     planner.transformCommand(command) match {
       case Some(transformer) =>
-        val qe = new QueryExecution(
-          session,
-          transformer(tracker),
+        val plan = transformer(tracker)
+        planner.runCommand(
+          plan,
           tracker,
-          shuffleCleanupMode = shuffleCleanupMode)
-        qe.assertCommandExecuted()
-        executeHolder.eventsManager.postFinished()
+          responseObserver,
+          command,
+          shuffleCleanupMode = Some(shuffleCleanupMode))
       case None =>
         planner.process(command, responseObserver)
     }
```
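To make the new contract concrete, here is a minimal sketch of the caller's view (signatures as they appear in this diff; the surrounding values `planner`, `command`, `tracker`, `responseObserver`, and `shuffleCleanupMode` are assumed to exist as they do in `SparkConnectPlanExecution`):

```scala
// transformCommand is now side effect free: it returns an optional
// transformer that builds a LogicalPlan from a tracker without executing.
val transformerOpt: Option[QueryPlanningTracker => LogicalPlan] =
  planner.transformCommand(command)

transformerOpt match {
  case Some(transformer) =>
    val plan = transformer(tracker) // plan construction only, no side effects
    planner.runCommand( // execution is deferred to this single entry point
      plan,
      tracker,
      responseObserver,
      command,
      shuffleCleanupMode = Some(shuffleCleanupMode))
  case None =>
    planner.process(command, responseObserver) // remaining eager paths
}
```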

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 72 additions & 38 deletions
```diff
@@ -35,7 +35,7 @@ import org.apache.spark.{SparkClassNotFoundException, SparkEnv, SparkException,
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
+import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
 import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
 import org.apache.spark.connect.proto.Parse.ParseFormat
 import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
@@ -69,7 +69,7 @@ import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
 import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder, SparkConnectService}
 import org.apache.spark.sql.connect.utils.MetricGenerator
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.{QueryExecution, ShuffleCleanupMode}
 import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, TypedAggregateExpression}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.execution.command.{CreateViewCommand, ExternalCommandExecutor}
@@ -343,13 +343,6 @@ class SparkConnectPlanner(
     }
   }
 
-  private def transformSqlWithRefs(query: proto.WithRelations): LogicalPlan = {
-    if (!isValidSQLWithRefs(query)) {
-      throw InvalidInputErrors.invalidSQLWithReferences(query)
-    }
-    executeSQLWithRefs(query).logicalPlan
-  }
-
   private def transformSubqueryAlias(alias: proto.SubqueryAlias): LogicalPlan = {
     val aliasIdentifier =
       if (alias.getQualifierCount > 0) {
@@ -2651,6 +2644,8 @@ class SparkConnectPlanner(
         Some(transformWriteOperation(command.getWriteOperation))
       case proto.Command.CommandTypeCase.WRITE_OPERATION_V2 =>
         Some(transformWriteOperationV2(command.getWriteOperationV2))
+      case proto.Command.CommandTypeCase.SQL_COMMAND =>
+        Some(transformSqlCommand(command.getSqlCommand))
       case _ =>
         None
     }
@@ -2661,7 +2656,8 @@ class SparkConnectPlanner(
       responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     val transformerOpt = transformCommand(command)
     if (transformerOpt.isDefined) {
-      transformAndRunCommand(transformerOpt.get)
+      val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
+      runCommand(transformerOpt.get(tracker), tracker, responseObserver, command)
       return
     }
     command.getCommandTypeCase match {
@@ -2675,8 +2671,6 @@ class SparkConnectPlanner(
         handleCreateViewCommand(command.getCreateDataframeView)
       case proto.Command.CommandTypeCase.EXTENSION =>
         handleCommandPlugin(command.getExtension)
-      case proto.Command.CommandTypeCase.SQL_COMMAND =>
-        handleSqlCommand(command.getSqlCommand, responseObserver)
      case proto.Command.CommandTypeCase.WRITE_STREAM_OPERATION_START =>
         handleWriteStreamOperationStart(command.getWriteStreamOperationStart, responseObserver)
       case proto.Command.CommandTypeCase.STREAMING_QUERY_COMMAND =>
@@ -2781,12 +2775,8 @@ class SparkConnectPlanner(
         .build())
   }
 
-  private def handleSqlCommand(
-      command: SqlCommand,
-      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
-    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
-
-    val relation = if (command.hasInput) {
+  private def getRelationFromSQLCommand(command: proto.SqlCommand): proto.Relation = {
+    if (command.hasInput) {
       command.getInput
     } else {
       // for backward compatibility
@@ -2803,15 +2793,33 @@ class SparkConnectPlanner(
           .build())
         .build()
     }
+  }
+
+  private def transformSqlCommand(command: proto.SqlCommand)(
+      tracker: QueryPlanningTracker): LogicalPlan = {
+    val relation = getRelationFromSQLCommand(command)
 
-    val df = relation.getRelTypeCase match {
+    relation.getRelTypeCase match {
       case proto.Relation.RelTypeCase.SQL =>
-        executeSQL(relation.getSql, tracker)
+        transformSQL(relation.getSql, tracker)
       case proto.Relation.RelTypeCase.WITH_RELATIONS =>
-        executeSQLWithRefs(relation.getWithRelations, tracker)
+        transformSQLWithRefs(relation.getWithRelations, tracker)
       case other =>
         throw InvalidInputErrors.sqlCommandExpectsSqlOrWithRelations(other)
     }
+  }
+
+  private def runSQLCommand(
+      command: LogicalPlan,
+      tracker: QueryPlanningTracker,
+      responseObserver: StreamObserver[ExecutePlanResponse],
+      protoCommand: proto.Command,
+      shuffleCleanupMode: Option[ShuffleCleanupMode]): Unit = {
+    val df = if (shuffleCleanupMode.isDefined) {
+      Dataset.ofRows(session, command, tracker, shuffleCleanupMode.get)
+    } else {
+      Dataset.ofRows(session, command, tracker)
+    }
 
     // Check if command or SQL Script has been executed.
     val isCommand = df.queryExecution.commandExecuted.isInstanceOf[CommandResult]
@@ -2867,7 +2875,7 @@ class SparkConnectPlanner(
     } else {
       // No execution triggered for relations. Manually set ready
       tracker.setReadyForExecution()
-      result.setRelation(relation)
+      result.setRelation(getRelationFromSQLCommand(protoCommand.getSqlCommand))
     }
     executeHolder.eventsManager.postFinished(Some(rows.size))
     // Exactly one SQL Command Result Batch
@@ -2909,9 +2917,9 @@ class SparkConnectPlanner(
     true
   }
 
-  private def executeSQLWithRefs(
+  private def transformSQLWithRefs(
       query: proto.WithRelations,
-      tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
+      tracker: QueryPlanningTracker): LogicalPlan = {
     if (!isValidSQLWithRefs(query)) {
       throw InvalidInputErrors.invalidSQLWithReferences(query)
     }
@@ -2925,7 +2933,7 @@ class SparkConnectPlanner(
           .ofRows(session, transformRelation(ref.getSubqueryAlias.getInput))
           .createOrReplaceTempView(ref.getSubqueryAlias.getAlias)
       }
-      executeSQL(sql, tracker)
+      transformSQL(sql, tracker)
     } finally {
       // drop all temporary views
       query.getReferencesList.asScala.foreach { ref =>
@@ -2935,36 +2943,46 @@ class SparkConnectPlanner(
     }
   }
 
-  private def executeSQL(
-      sql: proto.SQL,
+  private def executeSQLWithRefs(
+      query: proto.WithRelations,
       tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
+    Dataset.ofRows(session, transformSQLWithRefs(query, tracker), tracker)
+  }
+
+  private def transformSQL(sql: proto.SQL, tracker: QueryPlanningTracker): LogicalPlan = {
     // Eagerly execute commands of the provided SQL string.
     val args = sql.getArgsMap
     val namedArguments = sql.getNamedArgumentsMap
     val posArgs = sql.getPosArgsList
     val posArguments = sql.getPosArgumentsList
     if (!namedArguments.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
        sql.getQuery,
         namedArguments.asScala.toMap.transform((_, e) => Column(transformExpression(e))),
         tracker)
     } else if (!posArguments.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
         sql.getQuery,
         posArguments.asScala.map(e => Column(transformExpression(e))).toArray,
         tracker)
     } else if (!args.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
         sql.getQuery,
         args.asScala.toMap.transform((_, v) => transformLiteral(v)),
         tracker)
     } else if (!posArgs.isEmpty) {
-      session.sql(sql.getQuery, posArgs.asScala.map(transformLiteral).toArray, tracker)
+      session.sqlParsedPlan(sql.getQuery, posArgs.asScala.map(transformLiteral).toArray, tracker)
     } else {
-      session.sql(sql.getQuery, Map.empty[String, Any], tracker)
+      session.sqlParsedPlan(sql.getQuery, Map.empty[String, Any], tracker)
     }
   }
 
+  private def executeSQL(
+      sql: proto.SQL,
+      tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
+    Dataset.ofRows(session, transformSQL(sql, tracker), tracker)
+  }
+
   private def handleRegisterUserDefinedFunction(
       fun: proto.CommonInlineUserDefinedFunction): Unit = {
     fun.getFunctionCase match {
@@ -3157,11 +3175,27 @@ class SparkConnectPlanner(
     }
   }
 
-  private def transformAndRunCommand(transformer: QueryPlanningTracker => LogicalPlan): Unit = {
-    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
-    val qe = new QueryExecution(session, transformer(tracker), tracker)
-    qe.assertCommandExecuted()
-    executeHolder.eventsManager.postFinished()
+  private[connect] def runCommand(
+      command: LogicalPlan,
+      tracker: QueryPlanningTracker,
+      responseObserver: StreamObserver[ExecutePlanResponse],
+      protoCommand: proto.Command,
+      shuffleCleanupMode: Option[ShuffleCleanupMode] = None): Unit = {
+    if (protoCommand.getCommandTypeCase == proto.Command.CommandTypeCase.SQL_COMMAND) {
+      runSQLCommand(command, tracker, responseObserver, protoCommand, shuffleCleanupMode)
+    } else {
+      val qe = if (shuffleCleanupMode.isDefined) {
+        new QueryExecution(
+          session,
+          command,
+          tracker = tracker,
+          shuffleCleanupMode = shuffleCleanupMode.get)
+      } else {
+        new QueryExecution(session, command, tracker = tracker)
+      }
+      qe.assertCommandExecuted()
+      executeHolder.eventsManager.postFinished()
+    }
   }
 
   /**
@@ -4105,7 +4139,7 @@ class SparkConnectPlanner(
 
   private def transformWithRelations(getWithRelations: proto.WithRelations): LogicalPlan = {
     if (isValidSQLWithRefs(getWithRelations)) {
-      transformSqlWithRefs(getWithRelations)
+      executeSQLWithRefs(getWithRelations).logicalPlan
     } else {
       // Wrap the plan to keep the original planId.
       val plan = Project(Seq(UnresolvedStar(None)), transformRelation(getWithRelations.getRoot))
```
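Stripped of Spark's machinery, the pattern the planner now follows looks like this. A self-contained, illustrative sketch (the types below are stand-ins invented for the example, not Spark classes):

```scala
object TransformThenRun {
  // Stand-ins for proto.SqlCommand and LogicalPlan, for illustration only.
  final case class SqlCommand(query: String)
  final case class Plan(parsed: String)

  // "Transform": pure and side effect free; builds a plan, executes nothing.
  def transformSqlCommand(cmd: SqlCommand): Plan =
    Plan(s"Parsed[${cmd.query}]")

  // "Run": the one place where execution (the side effect) happens.
  def runSqlCommand(plan: Plan): Seq[String] =
    Seq(s"executed ${plan.parsed}")

  def main(args: Array[String]): Unit = {
    val plan = transformSqlCommand(SqlCommand("SELECT 1")) // nothing executed yet
    runSqlCommand(plan).foreach(println) // execution happens here
  }
}
```

Keeping transformSqlCommand free of side effects is what lets SparkConnectPlanExecution reuse the same transform while choosing its own execution options, such as the shuffle cleanup mode.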

sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala

Lines changed: 41 additions & 22 deletions
```diff
@@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, PosParame
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.parser.ParserInterface
-import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, LocalRelation, Range}
+import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, LocalRelation, LogicalPlan, Range}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.classic.SparkSession.applyAndLoadExtensions
@@ -430,6 +430,39 @@ class SparkSession private(
   | Everything else |
   * ----------------- */
 
+  private[sql] def sqlParsedPlan(
+      sqlText: String,
+      args: Array[_],
+      tracker: QueryPlanningTracker): LogicalPlan = {
+    tracker.measurePhase(QueryPlanningTracker.PARSING) {
+      val parsedPlan = sessionState.sqlParser.parsePlan(sqlText)
+      if (args.nonEmpty) {
+        if (parsedPlan.isInstanceOf[CompoundBody]) {
+          // Positional parameters are not supported for SQL scripting.
+          throw SqlScriptingErrors.positionalParametersAreNotSupportedWithSqlScripting()
+        }
+        PosParameterizedQuery(parsedPlan, args.map(lit(_).expr).toImmutableArraySeq)
+      } else {
+        parsedPlan
+      }
+    }
+  }
+
+  private[sql] def sqlParsedPlan(
+      sqlText: String,
+      args: Map[String, Any],
+      tracker: QueryPlanningTracker): LogicalPlan = {
+    tracker.measurePhase(QueryPlanningTracker.PARSING) {
+      val parsedPlan = sessionState.sqlParser.parsePlan(sqlText)
+      if (args.nonEmpty) {
+        NameParameterizedQuery(parsedPlan, args.transform((_, v) => lit(v).expr))
+      } else {
+        parsedPlan
+      }
+    }
+  }
+
+
   /**
    * Executes a SQL query substituting positional parameters by the given arguments,
    * returning the result as a `DataFrame`.
@@ -445,22 +478,15 @@
    * such as `map()`, `array()`, `struct()`, in that case it is taken as is.
    * @param tracker A tracker that can notify when query is ready for execution
    */
-  private[sql] def sql(sqlText: String, args: Array[_], tracker: QueryPlanningTracker): DataFrame =
+  private[sql] def sql(
+      sqlText: String,
+      args: Array[_],
+      tracker: QueryPlanningTracker): DataFrame = {
     withActive {
-      val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
-        val parsedPlan = sessionState.sqlParser.parsePlan(sqlText)
-        if (args.nonEmpty) {
-          if (parsedPlan.isInstanceOf[CompoundBody]) {
-            // Positional parameters are not supported for SQL scripting.
-            throw SqlScriptingErrors.positionalParametersAreNotSupportedWithSqlScripting()
-          }
-          PosParameterizedQuery(parsedPlan, args.map(lit(_).expr).toImmutableArraySeq)
-        } else {
-          parsedPlan
-        }
-      }
+      val plan = sqlParsedPlan(sqlText, args, tracker)
       Dataset.ofRows(self, plan, tracker)
     }
+  }
 
   /** @inheritdoc */
   def sql(sqlText: String, args: Array[_]): DataFrame = {
@@ -488,14 +514,7 @@
       args: Map[String, Any],
       tracker: QueryPlanningTracker): DataFrame =
     withActive {
-      val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
-        val parsedPlan = sessionState.sqlParser.parsePlan(sqlText)
-        if (args.nonEmpty) {
-          NameParameterizedQuery(parsedPlan, args.transform((_, v) => lit(v).expr))
-        } else {
-          parsedPlan
-        }
-      }
+      val plan = sqlParsedPlan(sqlText, args, tracker)
       Dataset.ofRows(self, plan, tracker)
     }
```
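For orientation, a sketch of how the new helper separates parsing from execution. Note that `sqlParsedPlan` and `Dataset.ofRows` are `private[sql]`, so code like this would have to live inside the `org.apache.spark.sql` package tree (for example, in a test); the session setup here is an assumption of the example:

```scala
import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.classic.{Dataset, SparkSession}

val spark: SparkSession = SparkSession.builder().master("local[1]").getOrCreate()
val tracker = new QueryPlanningTracker

// Parsing and named-parameter binding only; nothing is executed here.
val plan = spark.sqlParsedPlan("SELECT :x AS v", Map("x" -> 1), tracker)

// Execution is now a separate, explicit step.
val df = Dataset.ofRows(spark, plan, tracker)
df.show()
```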
