python/pyspark/sql/tests/test_sql.py (11 additions, 0 deletions)

@@ -168,6 +168,17 @@ def test_nested_dataframe(self):
         self.assertEqual(df3.take(1), [Row(id=4)])
         self.assertEqual(df3.tail(1), [Row(id=9)])
 
+    def test_nested_dataframe_for_sql_scripting(self):
+        with self.sql_conf({"spark.sql.scripting.enabled": True}):
+            df0 = self.spark.range(10)
+            df1 = self.spark.sql(
+                "BEGIN SELECT * FROM {df} WHERE id > 1; END;",
+                df=df0,
+            )
+            self.assertEqual(df1.count(), 8)
+            self.assertEqual(df1.take(1), [Row(id=2)])
+            self.assertEqual(df1.tail(1), [Row(id=9)])
+
     def test_lit_time(self):
         import datetime
 
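For orientation before the server-side changes below: the new test drives the user-facing API end to end. A minimal standalone sketch of the same usage, assuming an existing SparkSession named `spark` (the conf key and the `df=` reference binding are exactly those in the test above):

```python
# Standalone sketch of the usage covered by the new test; assumes a running
# SparkSession `spark` (e.g. from the pyspark shell).
spark.conf.set("spark.sql.scripting.enabled", "true")

df0 = spark.range(10)
# {df} inside the script is bound to df0 via the keyword argument.
df1 = spark.sql("BEGIN SELECT * FROM {df} WHERE id > 1; END;", df=df0)
assert df1.count() == 8  # ids 2 through 9
```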
SparkConnectPlanExecution.scala

@@ -35,7 +35,7 @@ import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_
 import org.apache.spark.sql.connect.planner.{InvalidInputErrors, SparkConnectPlanner}
 import org.apache.spark.sql.connect.service.ExecuteHolder
 import org.apache.spark.sql.connect.utils.MetricGenerator
-import org.apache.spark.sql.execution.{DoNotCleanup, LocalTableScanExec, QueryExecution, RemoveShuffleFiles, SkipMigration, SQLExecution}
+import org.apache.spark.sql.execution.{DoNotCleanup, LocalTableScanExec, RemoveShuffleFiles, SkipMigration, SQLExecution}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
@@ -83,13 +83,13 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
     val command = request.getPlan.getCommand
     planner.transformCommand(command) match {
       case Some(transformer) =>
-        val qe = new QueryExecution(
-          session,
-          transformer(tracker),
-          shuffleCleanupMode = shuffleCleanupMode)
-        qe.assertCommandExecuted()
-        executeHolder.eventsManager.postFinished()
+        val plan = transformer(tracker)
+        planner.runCommand(
+          plan,
+          tracker,
+          responseObserver,
+          command,
+          shuffleCleanupMode = Some(shuffleCleanupMode))
       case None =>
         planner.process(command, responseObserver)
     }
SparkConnectPlanner.scala

@@ -34,7 +34,7 @@ import org.apache.spark.{SparkClassNotFoundException, SparkEnv, SparkException,
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
+import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
 import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
 import org.apache.spark.connect.proto.Parse.ParseFormat
 import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
@@ -68,7 +68,7 @@ import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
 import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder, SparkConnectService}
 import org.apache.spark.sql.connect.utils.MetricGenerator
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.{QueryExecution, ShuffleCleanupMode}
 import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, TypedAggregateExpression}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.execution.command.{CreateViewCommand, ExternalCommandExecutor}
@@ -342,13 +342,6 @@ class SparkConnectPlanner(
     }
   }
 
-  private def transformSqlWithRefs(query: proto.WithRelations): LogicalPlan = {
-    if (!isValidSQLWithRefs(query)) {
-      throw InvalidInputErrors.invalidSQLWithReferences(query)
-    }
-    executeSQLWithRefs(query).logicalPlan
-  }
-
   private def transformSubqueryAlias(alias: proto.SubqueryAlias): LogicalPlan = {
     val aliasIdentifier =
       if (alias.getQualifierCount > 0) {
@@ -2650,6 +2643,8 @@ class SparkConnectPlanner(
         Some(transformWriteOperation(command.getWriteOperation))
       case proto.Command.CommandTypeCase.WRITE_OPERATION_V2 =>
         Some(transformWriteOperationV2(command.getWriteOperationV2))
+      case proto.Command.CommandTypeCase.SQL_COMMAND =>
+        Some(transformSqlCommand(command.getSqlCommand))
       case _ =>
         None
     }
@@ -2660,7 +2655,8 @@
       responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     val transformerOpt = transformCommand(command)
     if (transformerOpt.isDefined) {
-      transformAndRunCommand(transformerOpt.get)
+      val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
+      runCommand(transformerOpt.get(tracker), tracker, responseObserver, command)
       return
     }
     command.getCommandTypeCase match {
@@ -2674,8 +2670,6 @@
         handleCreateViewCommand(command.getCreateDataframeView)
       case proto.Command.CommandTypeCase.EXTENSION =>
         handleCommandPlugin(command.getExtension)
-      case proto.Command.CommandTypeCase.SQL_COMMAND =>
-        handleSqlCommand(command.getSqlCommand, responseObserver)
       case proto.Command.CommandTypeCase.WRITE_STREAM_OPERATION_START =>
         handleWriteStreamOperationStart(command.getWriteStreamOperationStart, responseObserver)
       case proto.Command.CommandTypeCase.STREAMING_QUERY_COMMAND =>
@@ -2780,12 +2774,8 @@
         .build())
   }
 
-  private def handleSqlCommand(
-      command: SqlCommand,
-      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
-    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
-
-    val relation = if (command.hasInput) {
+  private def getRelationFromSQLCommand(command: proto.SqlCommand): proto.Relation = {
+    if (command.hasInput) {
       command.getInput
     } else {
       // for backward compatibility
@@ -2802,19 +2792,47 @@
           .build())
         .build()
     }
+  }
 
-    val df = relation.getRelTypeCase match {
+  private def transformSqlCommand(command: proto.SqlCommand)(
+      tracker: QueryPlanningTracker): LogicalPlan = {
+    val relation = getRelationFromSQLCommand(command)
+
+    relation.getRelTypeCase match {
       case proto.Relation.RelTypeCase.SQL =>
-        executeSQL(relation.getSql, tracker)
+        transformSQL(relation.getSql, tracker)
       case proto.Relation.RelTypeCase.WITH_RELATIONS =>
-        executeSQLWithRefs(relation.getWithRelations, tracker)
+        transformSQLWithRefs(relation.getWithRelations, tracker)
       case other =>
         throw InvalidInputErrors.sqlCommandExpectsSqlOrWithRelations(other)
     }
+  }
 
+  private def runSQLCommand(
+      command: LogicalPlan,
+      tracker: QueryPlanningTracker,
+      responseObserver: StreamObserver[ExecutePlanResponse],
+      protoSQLCommand: proto.SqlCommand,
+      shuffleCleanupMode: Option[ShuffleCleanupMode]): Unit = {
+    val isSqlScript = command.isInstanceOf[CompoundBody]
+    val refs = if (isSqlScript && protoSQLCommand.getInput.hasWithRelations) {
+      protoSQLCommand.getInput.getWithRelations.getReferencesList.asScala
+        .map(_.getSubqueryAlias)
+        .toSeq
+    } else {
+      Seq.empty
+    }
+
+    val df = runWithRefs(refs) {
+      if (shuffleCleanupMode.isDefined) {
+        Dataset.ofRows(session, command, tracker, shuffleCleanupMode.get)
+      } else {
+        Dataset.ofRows(session, command, tracker)
+      }
+    }
+
     // Check if command or SQL Script has been executed.
     val isCommand = df.queryExecution.commandExecuted.isInstanceOf[CommandResult]
-    val isSqlScript = df.queryExecution.logical.isInstanceOf[CompoundBody]
     val rows = df.logicalPlan match {
       case lr: LocalRelation => lr.data
       case cr: CommandResult => cr.rows
@@ -2866,7 +2884,7 @@
     } else {
       // No execution triggered for relations. Manually set ready
       tracker.setReadyForExecution()
-      result.setRelation(relation)
+      result.setRelation(getRelationFromSQLCommand(protoSQLCommand))
     }
     executeHolder.eventsManager.postFinished(Some(rows.size))
     // Exactly one SQL Command Result Batch
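The hunk above shows the decision runSQLCommand makes when building the SqlCommandResult: results that were already materialized server side are inlined as rows, while plain relations are returned unevaluated so the client can execute them lazily. A schematic sketch of that decision, with our own names rather than Spark's API:

```python
# Schematic sketch (our names, not Spark's API) of the SqlCommandResult
# decision above: eagerly executed commands/scripts return rows inline,
# plain relations come back unevaluated for lazy client-side execution.
def build_sql_command_result(was_executed_eagerly, rows, relation):
    if was_executed_eagerly:
        return {"rows": rows}        # LocalRelation / CommandResult data
    return {"relation": relation}    # client triggers execution later
```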
@@ -2908,59 +2926,83 @@
     true
   }
 
-  private def executeSQLWithRefs(
-      query: proto.WithRelations,
-      tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
-    if (!isValidSQLWithRefs(query)) {
-      throw InvalidInputErrors.invalidSQLWithReferences(query)
-    }
-
-    // Eagerly execute commands of the provided SQL string, with given references.
-    val sql = query.getRoot.getSql
-    this.synchronized {
-      try {
-        query.getReferencesList.asScala.foreach { ref =>
-          Dataset
-            .ofRows(session, transformRelation(ref.getSubqueryAlias.getInput))
-            .createOrReplaceTempView(ref.getSubqueryAlias.getAlias)
-        }
-        executeSQL(sql, tracker)
-      } finally {
-        // drop all temporary views
-        query.getReferencesList.asScala.foreach { ref =>
-          session.catalog.dropTempView(ref.getSubqueryAlias.getAlias)
-        }
-      }
-    }
-  }
-
-  private def executeSQL(
-      sql: proto.SQL,
-      tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
+  private def runWithRefs[T](refs: Seq[proto.SubqueryAlias])(f: => T): T = {
+    if (refs.isEmpty) {
+      return f
+    }
+
+    this.synchronized {
+      try {
+        refs.foreach { ref =>
+          Dataset
+            .ofRows(session, transformRelation(ref.getInput))
+            .createOrReplaceTempView(ref.getAlias)
+        }
+        f
+      } finally {
+        // drop all temporary views
+        refs.foreach { ref =>
+          session.catalog.dropTempView(ref.getAlias)
+        }
+      }
+    }
+  }
+
+  private def transformSQLWithRefs(
+      query: proto.WithRelations,
+      tracker: QueryPlanningTracker): LogicalPlan = {
+    if (!isValidSQLWithRefs(query)) {
+      throw InvalidInputErrors.invalidSQLWithReferences(query)
+    }
+
+    transformSQL(
+      query.getRoot.getSql,
+      tracker,
+      query.getReferencesList.asScala.map(_.getSubqueryAlias).toSeq)
+  }
+
+  private def executeSQLWithRefs(
+      query: proto.WithRelations,
+      tracker: QueryPlanningTracker = new QueryPlanningTracker) = {
+    // Eagerly execute commands of the provided SQL string.
+    Dataset.ofRows(session, transformSQLWithRefs(query, tracker), tracker)
+  }
+
+  private def transformSQL(
+      sql: proto.SQL,
+      tracker: QueryPlanningTracker,
+      refsToAnalyze: Seq[proto.SubqueryAlias] = Seq.empty): LogicalPlan = {
     val args = sql.getArgsMap
     val namedArguments = sql.getNamedArgumentsMap
     val posArgs = sql.getPosArgsList
     val posArguments = sql.getPosArgumentsList
-    if (!namedArguments.isEmpty) {
-      session.sql(
+    val parsedPlan = if (!namedArguments.isEmpty) {
+      session.sqlParsedPlan(
         sql.getQuery,
         namedArguments.asScala.toMap.transform((_, e) => Column(transformExpression(e))),
         tracker)
     } else if (!posArguments.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
         sql.getQuery,
         posArguments.asScala.map(e => Column(transformExpression(e))).toArray,
         tracker)
     } else if (!args.isEmpty) {
-      session.sql(
+      session.sqlParsedPlan(
         sql.getQuery,
         args.asScala.toMap.transform((_, v) => transformLiteral(v)),
         tracker)
     } else if (!posArgs.isEmpty) {
-      session.sql(sql.getQuery, posArgs.asScala.map(transformLiteral).toArray, tracker)
+      session.sqlParsedPlan(sql.getQuery, posArgs.asScala.map(transformLiteral).toArray, tracker)
     } else {
-      session.sql(sql.getQuery, Map.empty[String, Any], tracker)
+      session.sqlParsedPlan(sql.getQuery, Map.empty[String, Any], tracker)
     }
+    if (parsedPlan.isInstanceOf[CompoundBody]) {
+      // If the parsed plan is a CompoundBody, skip analysis and return it.
+      // SQL scripting is a special case as execution occurs during the analysis phase.
+      parsedPlan
+    } else {
+      runWithRefs(refsToAnalyze) {
+        new QueryExecution(session, parsedPlan, tracker).analyzed
+      }
+    }
   }
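The CompoundBody branch at the end of transformSQL is the crux of the scripting support: a parsed script is returned unanalyzed, because for scripts execution happens during the analysis phase. A schematic Python sketch of that branch, using stand-in types of our own rather than Spark classes:

```python
# Schematic sketch of transformSQL's branch, with stand-in types of our own
# (not Spark classes): scripts bypass analysis, everything else is analyzed
# with the referenced views temporarily in scope.
from dataclasses import dataclass

@dataclass
class CompoundBody:          # stand-in for the parsed SQL-script node
    statements: list

def transform_sql(parsed, analyze, run_with_refs, refs):
    if isinstance(parsed, CompoundBody):
        return parsed                      # executed later, during "analysis"
    return run_with_refs(refs, lambda: analyze(parsed))
```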

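runWithRefs itself is a loan pattern: register each referenced plan as a temporary view, run the closure, and always drop the views afterwards. A rough analogue built only on public PySpark APIs (the with_refs helper name is ours, not Spark's):

```python
# Rough analogue of runWithRefs using public PySpark APIs; the helper name
# `with_refs` is illustrative. refs maps a view alias to a DataFrame.
from contextlib import contextmanager

@contextmanager
def with_refs(spark, refs):
    try:
        for alias, df in refs.items():
            df.createOrReplaceTempView(alias)   # register references
        yield
    finally:
        for alias in refs:
            spark.catalog.dropTempView(alias)   # drop even on failure
```

The server-side version additionally wraps the body in this.synchronized, since the temporary views it registers live in session-wide state.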
@@ -3156,11 +3198,32 @@
     }
   }
 
-  private def transformAndRunCommand(transformer: QueryPlanningTracker => LogicalPlan): Unit = {
-    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
-    val qe = new QueryExecution(session, transformer(tracker), tracker)
-    qe.assertCommandExecuted()
-    executeHolder.eventsManager.postFinished()
+  private[connect] def runCommand(
+      command: LogicalPlan,
+      tracker: QueryPlanningTracker,
+      responseObserver: StreamObserver[ExecutePlanResponse],
+      protoCommand: proto.Command,
+      shuffleCleanupMode: Option[ShuffleCleanupMode] = None): Unit = {
+    if (protoCommand.getCommandTypeCase == proto.Command.CommandTypeCase.SQL_COMMAND) {
+      runSQLCommand(
+        command,
+        tracker,
+        responseObserver,
+        protoCommand.getSqlCommand,
+        shuffleCleanupMode)
+    } else {
+      val qe = if (shuffleCleanupMode.isDefined) {
+        new QueryExecution(
+          session,
+          command,
+          tracker = tracker,
+          shuffleCleanupMode = shuffleCleanupMode.get)
+      } else {
+        new QueryExecution(session, command, tracker = tracker)
+      }
+      qe.assertCommandExecuted()
+      executeHolder.eventsManager.postFinished()
+    }
   }
 
   /**
@@ -4104,7 +4167,7 @@

   private def transformWithRelations(getWithRelations: proto.WithRelations): LogicalPlan = {
     if (isValidSQLWithRefs(getWithRelations)) {
-      transformSqlWithRefs(getWithRelations)
+      executeSQLWithRefs(getWithRelations).logicalPlan
     } else {
       // Wrap the plan to keep the original planId.
       val plan = Project(Seq(UnresolvedStar(None)), transformRelation(getWithRelations.getRoot))