
Commit 4771a57

[SPARK-53148][CONNECT][SQL] Make SqlCommand in SparkConnectPlanner side effect free
1 parent eba5381 commit 4771a57
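In short: handleSqlCommand previously parsed, analyzed, and eagerly executed a SqlCommand in one step inside SparkConnectPlanner. After this change, transformCommand also covers SQL_COMMAND and returns a side effect free QueryPlanningTracker => LogicalPlan, while the new runCommand performs the actual execution. Condensed from the SparkConnectPlanExecution.scala hunk below, the resulting call flow is:

    planner.transformCommand(command) match {
      case Some(transformer) =>
        // Pure transformation: builds the LogicalPlan without executing it.
        val plan = transformer(tracker)
        planner.runCommand(plan, tracker, responseObserver, command,
          shuffleCleanupMode = Some(shuffleCleanupMode))
      case None =>
        planner.process(command, responseObserver)
    }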

File tree

4 files changed, +173 -80 lines changed


python/pyspark/sql/tests/test_sql.py
Lines changed: 11 additions & 0 deletions

@@ -168,6 +168,17 @@ def test_nested_dataframe(self):
         self.assertEqual(df3.take(1), [Row(id=4)])
         self.assertEqual(df3.tail(1), [Row(id=9)])
 
+    def test_nested_dataframe_for_sql_scripting(self):
+        with self.sql_conf({"spark.sql.scripting.enabled": True}):
+            df0 = self.spark.range(10)
+            df1 = self.spark.sql(
+                "BEGIN SELECT * FROM {df} WHERE id > 1; END;",
+                df=df0,
+            )
+            self.assertEqual(df1.count(), 8)
+            self.assertEqual(df1.take(1), [Row(id=2)])
+            self.assertEqual(df1.tail(1), [Row(id=9)])
+
     def test_lit_time(self):
         import datetime
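The new test exercises the server-side change: with spark.sql.scripting.enabled set, the {df} reference is shipped as a with-relations input of the SqlCommand, and the script body must still be able to resolve df when the server executes it. The expected count of 8 follows from range(10) filtered to id > 1, i.e. ids 2 through 9. On the server, runSQLCommand (introduced further down in this commit) extracts those references only for scripts:

    val refs = if (isSqlScript && protoSQLCommand.getInput.hasWithRelations) {
      protoSQLCommand.getInput.getWithRelations.getReferencesList.asScala
        .map(_.getSubqueryAlias)
        .toSeq
    } else {
      Seq.empty
    }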

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
Lines changed: 7 additions & 7 deletions

@@ -35,7 +35,7 @@ import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_
 import org.apache.spark.sql.connect.planner.{InvalidInputErrors, SparkConnectPlanner}
 import org.apache.spark.sql.connect.service.ExecuteHolder
 import org.apache.spark.sql.connect.utils.MetricGenerator
-import org.apache.spark.sql.execution.{DoNotCleanup, LocalTableScanExec, QueryExecution, RemoveShuffleFiles, SkipMigration, SQLExecution}
+import org.apache.spark.sql.execution.{DoNotCleanup, LocalTableScanExec, RemoveShuffleFiles, SkipMigration, SQLExecution}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
@@ -83,13 +83,13 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
     val command = request.getPlan.getCommand
     planner.transformCommand(command) match {
       case Some(transformer) =>
-        val qe = new QueryExecution(
-          session,
-          transformer(tracker),
+        val plan = transformer(tracker)
+        planner.runCommand(
+          plan,
           tracker,
-          shuffleCleanupMode = shuffleCleanupMode)
-        qe.assertCommandExecuted()
-        executeHolder.eventsManager.postFinished()
+          responseObserver,
+          command,
+          shuffleCleanupMode = Some(shuffleCleanupMode))
       case None =>
         planner.process(command, responseObserver)
     }
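shuffleCleanupMode is threaded through as an Option because runCommand now serves two callers: this Arrow execution path, which supplies a mode, and SparkConnectPlanner.process, which leaves it at the default of None. The two call sites, condensed from this commit:

    // SparkConnectPlanExecution: pass the configured cleanup mode through.
    planner.runCommand(plan, tracker, responseObserver, command,
      shuffleCleanupMode = Some(shuffleCleanupMode))

    // SparkConnectPlanner.process: default None, plain QueryExecution.
    runCommand(transformerOpt.get(tracker), tracker, responseObserver, command)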

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
Lines changed: 114 additions & 51 deletions

@@ -34,7 +34,7 @@ import org.apache.spark.{SparkClassNotFoundException, SparkEnv, SparkException,
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
+import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
 import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
 import org.apache.spark.connect.proto.Parse.ParseFormat
 import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
@@ -68,7 +68,7 @@ import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
 import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder, SparkConnectService}
 import org.apache.spark.sql.connect.utils.MetricGenerator
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.{QueryExecution, ShuffleCleanupMode}
 import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, TypedAggregateExpression}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.execution.command.{CreateViewCommand, ExternalCommandExecutor}
@@ -342,13 +342,6 @@ class SparkConnectPlanner(
     }
   }
 
-  private def transformSqlWithRefs(query: proto.WithRelations): LogicalPlan = {
-    if (!isValidSQLWithRefs(query)) {
-      throw InvalidInputErrors.invalidSQLWithReferences(query)
-    }
-    executeSQLWithRefs(query).logicalPlan
-  }
-
   private def transformSubqueryAlias(alias: proto.SubqueryAlias): LogicalPlan = {
     val aliasIdentifier =
       if (alias.getQualifierCount > 0) {
@@ -2650,6 +2643,8 @@ class SparkConnectPlanner(
         Some(transformWriteOperation(command.getWriteOperation))
       case proto.Command.CommandTypeCase.WRITE_OPERATION_V2 =>
         Some(transformWriteOperationV2(command.getWriteOperationV2))
+      case proto.Command.CommandTypeCase.SQL_COMMAND =>
+        Some(transformSqlCommand(command.getSqlCommand))
       case _ =>
         None
     }
@@ -2660,7 +2655,8 @@
       responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     val transformerOpt = transformCommand(command)
     if (transformerOpt.isDefined) {
-      transformAndRunCommand(transformerOpt.get)
+      val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
+      runCommand(transformerOpt.get(tracker), tracker, responseObserver, command)
       return
     }
     command.getCommandTypeCase match {
@@ -2674,8 +2670,6 @@
         handleCreateViewCommand(command.getCreateDataframeView)
       case proto.Command.CommandTypeCase.EXTENSION =>
         handleCommandPlugin(command.getExtension)
-      case proto.Command.CommandTypeCase.SQL_COMMAND =>
-        handleSqlCommand(command.getSqlCommand, responseObserver)
       case proto.Command.CommandTypeCase.WRITE_STREAM_OPERATION_START =>
         handleWriteStreamOperationStart(command.getWriteStreamOperationStart, responseObserver)
       case proto.Command.CommandTypeCase.STREAMING_QUERY_COMMAND =>
@@ -2780,12 +2774,8 @@
         .build())
   }
 
-  private def handleSqlCommand(
-      command: SqlCommand,
-      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
-    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
-
-    val relation = if (command.hasInput) {
+  private def getRelationFromSQLCommand(command: proto.SqlCommand): proto.Relation = {
+    if (command.hasInput) {
       command.getInput
     } else {
       // for backward compatibility
@@ -2802,19 +2792,47 @@
         .build())
         .build()
     }
+  }
 
-    val df = relation.getRelTypeCase match {
+  private def transformSqlCommand(command: proto.SqlCommand)(
+      tracker: QueryPlanningTracker): LogicalPlan = {
+    val relation = getRelationFromSQLCommand(command)
+
+    relation.getRelTypeCase match {
       case proto.Relation.RelTypeCase.SQL =>
-        executeSQL(relation.getSql, tracker)
+        transformSQL(relation.getSql, tracker)
       case proto.Relation.RelTypeCase.WITH_RELATIONS =>
-        executeSQLWithRefs(relation.getWithRelations, tracker)
+        transformSQLWithRefs(relation.getWithRelations, tracker)
       case other =>
         throw InvalidInputErrors.sqlCommandExpectsSqlOrWithRelations(other)
     }
+  }
+
+  private def runSQLCommand(
+      command: LogicalPlan,
+      tracker: QueryPlanningTracker,
+      responseObserver: StreamObserver[ExecutePlanResponse],
+      protoSQLCommand: proto.SqlCommand,
+      shuffleCleanupMode: Option[ShuffleCleanupMode]): Unit = {
+    val isSqlScript = command.isInstanceOf[CompoundBody]
+    val refs = if (isSqlScript && protoSQLCommand.getInput.hasWithRelations) {
+      protoSQLCommand.getInput.getWithRelations.getReferencesList.asScala
+        .map(_.getSubqueryAlias)
+        .toSeq
+    } else {
+      Seq.empty
+    }
+
+    val df = runWithRefs(refs) {
+      if (shuffleCleanupMode.isDefined) {
+        Dataset.ofRows(session, command, tracker, shuffleCleanupMode.get)
+      } else {
+        Dataset.ofRows(session, command, tracker)
+      }
+    }
 
     // Check if command or SQL Script has been executed.
     val isCommand = df.queryExecution.commandExecuted.isInstanceOf[CommandResult]
-    val isSqlScript = df.queryExecution.logical.isInstanceOf[CompoundBody]
     val rows = df.logicalPlan match {
       case lr: LocalRelation => lr.data
       case cr: CommandResult => cr.rows
@@ -2866,7 +2884,7 @@
     } else {
       // No execution triggered for relations. Manually set ready
       tracker.setReadyForExecution()
-      result.setRelation(relation)
+      result.setRelation(getRelationFromSQLCommand(protoSQLCommand))
     }
     executeHolder.eventsManager.postFinished(Some(rows.size))
     // Exactly one SQL Command Result Batch
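One behavioral detail in the hunks above: isSqlScript used to be derived from df.queryExecution.logical after the Dataset had been built; runSQLCommand now decides it on the not yet executed plan, which is what allows the with-relations temp views to be registered before execution. In the same spirit, result.setRelation can no longer reuse a local relation value from handleSqlCommand, so it recomputes it via getRelationFromSQLCommand(protoSQLCommand).

    // Before this commit (inside handleSqlCommand, after execution):
    //   val isSqlScript = df.queryExecution.logical.isInstanceOf[CompoundBody]
    // After this commit (inside runSQLCommand, before execution):
    val isSqlScript = command.isInstanceOf[CompoundBody]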
@@ -2908,59 +2926,83 @@
       true
     }
 
-  private def executeSQLWithRefs(
-      query: proto.WithRelations,
-      tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
-    if (!isValidSQLWithRefs(query)) {
-      throw InvalidInputErrors.invalidSQLWithReferences(query)
+  private def runWithRefs[T](refs: Seq[proto.SubqueryAlias])(f: => T): T = {
+    if (refs.isEmpty) {
+      return f
     }
-
-    // Eagerly execute commands of the provided SQL string, with given references.
-    val sql = query.getRoot.getSql
     this.synchronized {
       try {
-        query.getReferencesList.asScala.foreach { ref =>
+        refs.foreach { ref =>
           Dataset
-            .ofRows(session, transformRelation(ref.getSubqueryAlias.getInput))
-            .createOrReplaceTempView(ref.getSubqueryAlias.getAlias)
+            .ofRows(session, transformRelation(ref.getInput))
+            .createOrReplaceTempView(ref.getAlias)
         }
-        executeSQL(sql, tracker)
+        f
       } finally {
         // drop all temporary views
-        query.getReferencesList.asScala.foreach { ref =>
-          session.catalog.dropTempView(ref.getSubqueryAlias.getAlias)
+        refs.foreach { ref =>
+          session.catalog.dropTempView(ref.getAlias)
         }
       }
     }
   }
 
-  private def executeSQL(
-      sql: proto.SQL,
+  private def transformSQLWithRefs(
+      query: proto.WithRelations,
+      tracker: QueryPlanningTracker): LogicalPlan = {
+    if (!isValidSQLWithRefs(query)) {
+      throw InvalidInputErrors.invalidSQLWithReferences(query)
+    }
+
+    transformSQL(
+      query.getRoot.getSql,
+      tracker,
+      query.getReferencesList.asScala.map(_.getSubqueryAlias).toSeq)
+  }
+
+  private def executeSQLWithRefs(
+      query: proto.WithRelations,
       tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
     // Eagerly execute commands of the provided SQL string.
+    Dataset.ofRows(session, transformSQLWithRefs(query, tracker), tracker)
+  }
+
+  private def transformSQL(
+      sql: proto.SQL,
+      tracker: QueryPlanningTracker,
+      refsToAnalyze: Seq[proto.SubqueryAlias] = Seq.empty): LogicalPlan = {
     val args = sql.getArgsMap
     val namedArguments = sql.getNamedArgumentsMap
    val posArgs = sql.getPosArgsList
     val posArguments = sql.getPosArgumentsList
-    if (!namedArguments.isEmpty) {
-      session.sql(
+    val parsedPlan = if (!namedArguments.isEmpty) {
+      session.sqlParsedPlan(
         sql.getQuery,
         namedArguments.asScala.toMap.transform((_, e) => Column(transformExpression(e))),
         tracker)
     } else if (!posArguments.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
         sql.getQuery,
         posArguments.asScala.map(e => Column(transformExpression(e))).toArray,
         tracker)
     } else if (!args.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
         sql.getQuery,
         args.asScala.toMap.transform((_, v) => transformLiteral(v)),
         tracker)
     } else if (!posArgs.isEmpty) {
-      session.sql(sql.getQuery, posArgs.asScala.map(transformLiteral).toArray, tracker)
+      session.sqlParsedPlan(sql.getQuery, posArgs.asScala.map(transformLiteral).toArray, tracker)
     } else {
-      session.sql(sql.getQuery, Map.empty[String, Any], tracker)
+      session.sqlParsedPlan(sql.getQuery, Map.empty[String, Any], tracker)
+    }
+    if (parsedPlan.isInstanceOf[CompoundBody]) {
+      // If the parsed plan is a CompoundBody, skip analysis and return it.
+      // SQL scripting is a special case as execution occurs during the analysis phase.
+      parsedPlan
+    } else {
+      runWithRefs(refsToAnalyze) {
+        new QueryExecution(session, parsedPlan, tracker).analyzed
+      }
     }
   }
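runWithRefs generalizes the body of the old executeSQLWithRefs: it registers each referenced dataframe as a temp view, evaluates an arbitrary thunk under this.synchronized, and always drops the views in the finally block. Both the analysis path (transformSQL) and the execution path (runSQLCommand) now go through it; condensed from the diff:

    // Run a block while the proto references are visible as temp views.
    val df = runWithRefs(refs) {
      Dataset.ofRows(session, command, tracker)
    }

The CompoundBody early return in transformSQL is also why refs must be threaded into runSQLCommand separately: a SQL script skips analysis here, so its references can only be resolved at execution time.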

@@ -3156,11 +3198,32 @@
     }
   }
 
-  private def transformAndRunCommand(transformer: QueryPlanningTracker => LogicalPlan): Unit = {
-    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
-    val qe = new QueryExecution(session, transformer(tracker), tracker)
-    qe.assertCommandExecuted()
-    executeHolder.eventsManager.postFinished()
+  private[connect] def runCommand(
+      command: LogicalPlan,
+      tracker: QueryPlanningTracker,
+      responseObserver: StreamObserver[ExecutePlanResponse],
+      protoCommand: proto.Command,
+      shuffleCleanupMode: Option[ShuffleCleanupMode] = None): Unit = {
+    if (protoCommand.getCommandTypeCase == proto.Command.CommandTypeCase.SQL_COMMAND) {
+      runSQLCommand(
+        command,
+        tracker,
+        responseObserver,
+        protoCommand.getSqlCommand,
+        shuffleCleanupMode)
+    } else {
+      val qe = if (shuffleCleanupMode.isDefined) {
+        new QueryExecution(
+          session,
+          command,
+          tracker = tracker,
+          shuffleCleanupMode = shuffleCleanupMode.get)
+      } else {
+        new QueryExecution(session, command, tracker = tracker)
+      }
+      qe.assertCommandExecuted()
+      executeHolder.eventsManager.postFinished()
+    }
   }
 
   /**
@@ -4104,7 +4167,7 @@
 
   private def transformWithRelations(getWithRelations: proto.WithRelations): LogicalPlan = {
     if (isValidSQLWithRefs(getWithRelations)) {
-      transformSqlWithRefs(getWithRelations)
+      executeSQLWithRefs(getWithRelations).logicalPlan
     } else {
       // Wrap the plan to keep the original planId.
       val plan = Project(Seq(UnresolvedStar(None)), transformRelation(getWithRelations.getRoot))
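Note that the relation path is intentionally unchanged in behavior: transformWithRelations still executes eagerly via executeSQLWithRefs, now a thin wrapper over transformSQLWithRefs plus Dataset.ofRows, matching the removed transformSqlWithRefs helper. Only the command path becomes lazy. A rough sketch of the asymmetry, using the method names from this diff, where sqlCommand and withRelations stand for the incoming proto messages:

    // Command path (new): side effect free transform, executed later by runCommand.
    val commandPlan = transformSqlCommand(sqlCommand)(tracker)
    // Relation path (unchanged): still eagerly executed.
    val relationPlan = executeSQLWithRefs(withRelations).logicalPlan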
