Commit 43e71e2

[SPARK-53148][CONNECT][SQL] Make SqlCommand in SparkConnectPlanner side effect free
1 parent ea8b6fd commit 43e71e2

File tree: 3 files changed, +120 additions, -58 deletions


sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala

Lines changed: 7 additions & 7 deletions
@@ -35,7 +35,7 @@ import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_
 import org.apache.spark.sql.connect.planner.{InvalidInputErrors, SparkConnectPlanner}
 import org.apache.spark.sql.connect.service.ExecuteHolder
 import org.apache.spark.sql.connect.utils.MetricGenerator
-import org.apache.spark.sql.execution.{DoNotCleanup, LocalTableScanExec, QueryExecution, RemoveShuffleFiles, SkipMigration, SQLExecution}
+import org.apache.spark.sql.execution.{DoNotCleanup, LocalTableScanExec, RemoveShuffleFiles, SkipMigration, SQLExecution}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
@@ -83,13 +83,13 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
     val command = request.getPlan.getCommand
     planner.transformCommand(command) match {
       case Some(transformer) =>
-        val qe = new QueryExecution(
-          session,
-          transformer(tracker),
+        val plan = transformer(tracker)
+        planner.runCommand(
+          plan,
           tracker,
-          shuffleCleanupMode = shuffleCleanupMode)
-        qe.assertCommandExecuted()
-        executeHolder.eventsManager.postFinished()
+          responseObserver,
+          command,
+          shuffleCleanupMode = Some(shuffleCleanupMode))
       case None =>
         planner.process(command, responseObserver)
     }
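
Note: the reshaped call site shows the commit's two-phase contract: planner.transformCommand only builds a LogicalPlan, while planner.runCommand owns execution and response emission. A minimal, self-contained sketch of that contract, using simplified stand-in types rather than the real Spark Connect classes:

// Sketch only: all types are hypothetical stand-ins, not Spark's API.
object TwoPhaseCommandFlow {
  final case class QueryPlanningTracker(name: String)
  final case class LogicalPlan(description: String)

  // Phase 1: pure transformation; side effect free, nothing executes here.
  def transformCommand(command: String): Option[QueryPlanningTracker => LogicalPlan] =
    command match {
      case "sql_command" | "write_operation" =>
        Some(tracker => LogicalPlan(s"$command planned under ${tracker.name}"))
      case _ => None // legacy commands keep the eager process() path
    }

  // Phase 2: the only place where execution (the side effect) happens.
  def runCommand(plan: LogicalPlan): Unit =
    println(s"executing ${plan.description}")

  def main(args: Array[String]): Unit = {
    val tracker = QueryPlanningTracker("tracker-1")
    transformCommand("sql_command") match {
      case Some(build) => runCommand(build(tracker)) // plan first, then execute
      case None        => println("fallback: eager process()")
    }
  }
}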

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 73 additions & 30 deletions
@@ -35,7 +35,7 @@ import org.apache.spark.{SparkClassNotFoundException, SparkEnv, SparkException,
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
+import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
 import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
 import org.apache.spark.connect.proto.Parse.ParseFormat
 import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
@@ -69,7 +69,7 @@ import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
 import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder, SparkConnectService}
 import org.apache.spark.sql.connect.utils.MetricGenerator
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.{QueryExecution, ShuffleCleanupMode}
 import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, TypedAggregateExpression}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.execution.command.{CreateViewCommand, ExternalCommandExecutor}
@@ -2651,6 +2651,8 @@ class SparkConnectPlanner(
       Some(transformWriteOperation(command.getWriteOperation))
     case proto.Command.CommandTypeCase.WRITE_OPERATION_V2 =>
       Some(transformWriteOperationV2(command.getWriteOperationV2))
+    case proto.Command.CommandTypeCase.SQL_COMMAND =>
+      Some(transformSqlCommand(command.getSqlCommand))
     case _ =>
       None
   }
@@ -2661,7 +2663,8 @@
       responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     val transformerOpt = transformCommand(command)
     if (transformerOpt.isDefined) {
-      transformAndRunCommand(transformerOpt.get)
+      val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
+      runCommand(transformerOpt.get(tracker), tracker, responseObserver, command)
       return
     }
     command.getCommandTypeCase match {
@@ -2675,8 +2678,6 @@
       handleCreateViewCommand(command.getCreateDataframeView)
     case proto.Command.CommandTypeCase.EXTENSION =>
       handleCommandPlugin(command.getExtension)
-    case proto.Command.CommandTypeCase.SQL_COMMAND =>
-      handleSqlCommand(command.getSqlCommand, responseObserver)
     case proto.Command.CommandTypeCase.WRITE_STREAM_OPERATION_START =>
       handleWriteStreamOperationStart(command.getWriteStreamOperationStart, responseObserver)
     case proto.Command.CommandTypeCase.STREAMING_QUERY_COMMAND =>
@@ -2781,12 +2782,8 @@
         .build())
   }
 
-  private def handleSqlCommand(
-      command: SqlCommand,
-      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
-    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
-
-    val relation = if (command.hasInput) {
+  private def getRelationFromSQLCommand(command: proto.SqlCommand): proto.Relation = {
+    if (command.hasInput) {
       command.getInput
     } else {
       // for backward compatibility
@@ -2803,15 +2800,33 @@
         .build())
         .build()
     }
+  }
+
+  private def transformSqlCommand(command: proto.SqlCommand)(
+      tracker: QueryPlanningTracker): LogicalPlan = {
+    val relation = getRelationFromSQLCommand(command)
 
-    val df = relation.getRelTypeCase match {
+    relation.getRelTypeCase match {
       case proto.Relation.RelTypeCase.SQL =>
-        executeSQL(relation.getSql, tracker)
+        transformSQL(relation.getSql, tracker)
       case proto.Relation.RelTypeCase.WITH_RELATIONS =>
-        executeSQLWithRefs(relation.getWithRelations, tracker)
+        transformSQLWithRefs(relation.getWithRelations, tracker)
       case other =>
         throw InvalidInputErrors.sqlCommandExpectsSqlOrWithRelations(other)
     }
+  }
+
+  private def runSQLCommand(
+      command: LogicalPlan,
+      tracker: QueryPlanningTracker,
+      responseObserver: StreamObserver[ExecutePlanResponse],
+      protoCommand: proto.Command,
+      shuffleCleanupMode: Option[ShuffleCleanupMode]): Unit = {
+    val df = if (shuffleCleanupMode.isDefined) {
+      Dataset.ofRows(session, command, tracker, shuffleCleanupMode.get)
+    } else {
+      Dataset.ofRows(session, command, tracker)
+    }
 
     // Check if command or SQL Script has been executed.
     val isCommand = df.queryExecution.commandExecuted.isInstanceOf[CommandResult]
@@ -2867,7 +2882,7 @@
     } else {
       // No execution triggered for relations. Manually set ready
       tracker.setReadyForExecution()
-      result.setRelation(relation)
+      result.setRelation(getRelationFromSQLCommand(protoCommand.getSqlCommand))
     }
     executeHolder.eventsManager.postFinished(Some(rows.size))
     // Exactly one SQL Command Result Batch
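
Note: runSQLCommand materializes the already-built plan with Dataset.ofRows, threading the shuffle cleanup mode through only when the caller supplied one. A minimal sketch of that Option-based dispatch, with hypothetical stand-in types (not the real Dataset API):

// Sketch of the Option[ShuffleCleanupMode] dispatch; stand-in types only.
object OptionalCleanupDispatch {
  sealed trait ShuffleCleanupMode
  case object RemoveShuffleFiles extends ShuffleCleanupMode

  final case class LogicalPlan(text: String)
  final case class DataFrame(plan: LogicalPlan, cleanup: Option[ShuffleCleanupMode])

  // Pick the overload based on whether a cleanup mode was requested, so the
  // common path (None) keeps the default execution behavior untouched.
  def ofRows(plan: LogicalPlan, cleanup: Option[ShuffleCleanupMode]): DataFrame =
    cleanup match {
      case Some(mode) => DataFrame(plan, Some(mode)) // overload with cleanup mode
      case None       => DataFrame(plan, None)       // default overload
    }

  def main(args: Array[String]): Unit = {
    println(ofRows(LogicalPlan("SELECT 1"), None))
    println(ofRows(LogicalPlan("SELECT 1"), Some(RemoveShuffleFiles)))
  }
}
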
@@ -2909,9 +2924,9 @@
     true
   }
 
-  private def executeSQLWithRefs(
+  private def transformSQLWithRefs(
       query: proto.WithRelations,
-      tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
+      tracker: QueryPlanningTracker = new QueryPlanningTracker): LogicalPlan = {
     if (!isValidSQLWithRefs(query)) {
       throw InvalidInputErrors.invalidSQLWithReferences(query)
     }
@@ -2925,7 +2940,7 @@
           .ofRows(session, transformRelation(ref.getSubqueryAlias.getInput))
           .createOrReplaceTempView(ref.getSubqueryAlias.getAlias)
       }
-      executeSQL(sql, tracker)
+      transformSQL(sql, tracker)
     } finally {
       // drop all temporary views
       query.getReferencesList.asScala.foreach { ref =>
@@ -2935,36 +2950,48 @@
     }
   }
 
-  private def executeSQL(
-      sql: proto.SQL,
+  private def executeSQLWithRefs(
+      query: proto.WithRelations,
       tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
+    Dataset.ofRows(session, transformSQLWithRefs(query, tracker), tracker)
+  }
+
+  private def transformSQL(
+      sql: proto.SQL,
+      tracker: QueryPlanningTracker = new QueryPlanningTracker): LogicalPlan = {
     // Eagerly execute commands of the provided SQL string.
     val args = sql.getArgsMap
     val namedArguments = sql.getNamedArgumentsMap
     val posArgs = sql.getPosArgsList
     val posArguments = sql.getPosArgumentsList
     if (!namedArguments.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
         sql.getQuery,
         namedArguments.asScala.toMap.transform((_, e) => Column(transformExpression(e))),
         tracker)
     } else if (!posArguments.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
         sql.getQuery,
         posArguments.asScala.map(e => Column(transformExpression(e))).toArray,
         tracker)
     } else if (!args.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
         sql.getQuery,
         args.asScala.toMap.transform((_, v) => transformLiteral(v)),
         tracker)
     } else if (!posArgs.isEmpty) {
-      session.sql(sql.getQuery, posArgs.asScala.map(transformLiteral).toArray, tracker)
+      session.sqlParsedPlan(sql.getQuery, posArgs.asScala.map(transformLiteral).toArray, tracker)
     } else {
-      session.sql(sql.getQuery, Map.empty[String, Any], tracker)
+      session.sqlParsedPlan(sql.getQuery, Map.empty[String, Any], tracker)
     }
   }
 
+  private def executeSQL(
+      sql: proto.SQL,
+      tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
+    Dataset.ofRows(session, transformSQL(sql, tracker), tracker)
+  }
+
   private def handleRegisterUserDefinedFunction(
       fun: proto.CommonInlineUserDefinedFunction): Unit = {
     fun.getFunctionCase match {
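
Note: executeSQL and executeSQLWithRefs survive as thin wrappers: each delegates to its new transform* counterpart and only then materializes a Dataset. A simplified sketch of the split, with stand-in types rather than Spark's API:

// Sketch of the transform/execute split; all types are simplified stand-ins.
object TransformThenExecute {
  final case class Sql(query: String)
  final case class LogicalPlan(text: String)
  final case class DataFrame(plan: LogicalPlan)

  // Pure: proto -> LogicalPlan, safe to call without running anything.
  def transformSQL(sql: Sql): LogicalPlan = LogicalPlan(s"Parsed(${sql.query})")

  // Eager wrapper kept for existing callers: delegate, then materialize.
  def executeSQL(sql: Sql): DataFrame = DataFrame(transformSQL(sql))

  def main(args: Array[String]): Unit = {
    val plan = transformSQL(Sql("SELECT 1")) // reusable by the new runCommand path
    println(plan)
    println(executeSQL(Sql("SELECT 1")))     // unchanged eager behavior
  }
}
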
@@ -3157,11 +3184,27 @@
     }
   }
 
-  private def transformAndRunCommand(transformer: QueryPlanningTracker => LogicalPlan): Unit = {
-    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
-    val qe = new QueryExecution(session, transformer(tracker), tracker)
-    qe.assertCommandExecuted()
-    executeHolder.eventsManager.postFinished()
+  private[connect] def runCommand(
+      command: LogicalPlan,
+      tracker: QueryPlanningTracker,
+      responseObserver: StreamObserver[ExecutePlanResponse],
+      protoCommand: proto.Command,
+      shuffleCleanupMode: Option[ShuffleCleanupMode] = None): Unit = {
+    if (protoCommand.getCommandTypeCase == proto.Command.CommandTypeCase.SQL_COMMAND) {
+      runSQLCommand(command, tracker, responseObserver, protoCommand, shuffleCleanupMode)
+    } else {
+      val qe = if (shuffleCleanupMode.isDefined) {
+        new QueryExecution(
+          session,
+          command,
+          tracker = tracker,
+          shuffleCleanupMode = shuffleCleanupMode.get)
+      } else {
+        new QueryExecution(session, command, tracker = tracker)
+      }
+      qe.assertCommandExecuted()
+      executeHolder.eventsManager.postFinished()
+    }
   }
 
   /**
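
Note: runCommand replaces transformAndRunCommand and branches on the proto command type: SQL commands need a SqlCommandResult streamed back to the client, while other commands are executed fire-and-forget. A compact sketch of that dispatch, with hypothetical stand-in types:

// Sketch of runCommand's dispatch; stand-ins, not the real Spark classes.
object RunCommandDispatch {
  sealed trait Command
  case object SqlCommand extends Command
  case object OtherCommand extends Command

  final case class LogicalPlan(text: String)

  def runSQLCommand(plan: LogicalPlan): Unit =
    println(s"SQL path: execute ${plan.text} and stream a SqlCommandResult batch")

  def assertCommandExecuted(plan: LogicalPlan): Unit =
    println(s"generic path: eagerly execute ${plan.text}; nothing streamed back")

  def runCommand(plan: LogicalPlan, command: Command): Unit =
    command match {
      case SqlCommand   => runSQLCommand(plan)         // needs a response batch
      case OtherCommand => assertCommandExecuted(plan) // fire-and-forget
    }

  def main(args: Array[String]): Unit = {
    runCommand(LogicalPlan("CREATE TABLE t"), OtherCommand)
    runCommand(LogicalPlan("SELECT 1"), SqlCommand)
  }
}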

sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala

Lines changed: 40 additions & 21 deletions
@@ -430,6 +430,39 @@
     | Everything else |
   * ----------------- */
 
+  private[sql] def sqlParsedPlan(
+      sqlText: String,
+      args: Array[_],
+      tracker: QueryPlanningTracker): LogicalPlan = {
+    tracker.measurePhase(QueryPlanningTracker.PARSING) {
+      val parsedPlan = sessionState.sqlParser.parsePlan(sqlText)
+      if (args.nonEmpty) {
+        if (parsedPlan.isInstanceOf[CompoundBody]) {
+          // Positional parameters are not supported for SQL scripting.
+          throw SqlScriptingErrors.positionalParametersAreNotSupportedWithSqlScripting()
+        }
+        PosParameterizedQuery(parsedPlan, args.map(lit(_).expr).toImmutableArraySeq)
+      } else {
+        parsedPlan
+      }
+    }
+  }
+
+  private[sql] def sqlParsedPlan(
+      sqlText: String,
+      args: Map[String, Any],
+      tracker: QueryPlanningTracker): LogicalPlan = {
+    tracker.measurePhase(QueryPlanningTracker.PARSING) {
+      val parsedPlan = sessionState.sqlParser.parsePlan(sqlText)
+      if (args.nonEmpty) {
+        NameParameterizedQuery(parsedPlan, args.transform((_, v) => lit(v).expr))
+      } else {
+        parsedPlan
+      }
+    }
+  }
+
+
 /**
  * Executes a SQL query substituting positional parameters by the given arguments,
  * returning the result as a `DataFrame`.
@@ -445,22 +478,15 @@
   * such as `map()`, `array()`, `struct()`, in that case it is taken as is.
   * @param tracker A tracker that can notify when query is ready for execution
   */
-  private[sql] def sql(sqlText: String, args: Array[_], tracker: QueryPlanningTracker): DataFrame =
+  private[sql] def sql(
+      sqlText: String,
+      args: Array[_],
+      tracker: QueryPlanningTracker): DataFrame = {
     withActive {
-      val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
-        val parsedPlan = sessionState.sqlParser.parsePlan(sqlText)
-        if (args.nonEmpty) {
-          if (parsedPlan.isInstanceOf[CompoundBody]) {
-            // Positional parameters are not supported for SQL scripting.
-            throw SqlScriptingErrors.positionalParametersAreNotSupportedWithSqlScripting()
-          }
-          PosParameterizedQuery(parsedPlan, args.map(lit(_).expr).toImmutableArraySeq)
-        } else {
-          parsedPlan
-        }
-      }
+      val plan = sqlParsedPlan(sqlText, args, tracker)
       Dataset.ofRows(self, plan, tracker)
     }
+  }
 
   /** @inheritdoc */
   def sql(sqlText: String, args: Array[_]): DataFrame = {
@@ -488,14 +514,7 @@
       args: Map[String, Any],
       tracker: QueryPlanningTracker): DataFrame =
     withActive {
-      val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
-        val parsedPlan = sessionState.sqlParser.parsePlan(sqlText)
-        if (args.nonEmpty) {
-          NameParameterizedQuery(parsedPlan, args.transform((_, v) => lit(v).expr))
-        } else {
-          parsedPlan
-        }
-      }
+      val plan = sqlParsedPlan(sqlText, args, tracker)
       Dataset.ofRows(self, plan, tracker)
     }

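Note: both parameterized sql(...) overloads now obtain their plan from sqlParsedPlan, so the Connect planner can reuse the measured parse phase without triggering Dataset.ofRows. A minimal sketch of factoring a parse phase out of an eager API, with hypothetical stand-ins for the real SparkSession members:

// Sketch of the shared parse phase; names and types are stand-ins.
object SharedParsePhase {
  final class Tracker {
    var phases: List[String] = Nil
    def measurePhase[T](phase: String)(body: => T): T = {
      phases = phase :: phases // record that this phase ran
      body
    }
  }
  final case class LogicalPlan(text: String)
  final case class DataFrame(plan: LogicalPlan)

  // Shared and side effect free: parse (and bind args) under the PARSING phase.
  def sqlParsedPlan(sqlText: String, args: Map[String, Any], tracker: Tracker): LogicalPlan =
    tracker.measurePhase("PARSING") {
      val parsed = LogicalPlan(s"Parsed($sqlText)")
      if (args.nonEmpty) LogicalPlan(s"NamedParams(${parsed.text}, $args)") else parsed
    }

  // Eager API: behavior unchanged, now a thin wrapper over the shared phase.
  def sql(sqlText: String, args: Map[String, Any], tracker: Tracker): DataFrame =
    DataFrame(sqlParsedPlan(sqlText, args, tracker))

  def main(args: Array[String]): Unit = {
    val tracker = new Tracker
    println(sql("SELECT :x", Map("x" -> 1), tracker))
    println(tracker.phases) // List(PARSING)
  }
}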