
Commit 76c9516

cloud-fan authored and zhengruifeng committed
[SPARK-54835][SQL] Avoid unnecessary temp QueryExecution for nested command execution
### What changes were proposed in this pull request? This PR is a small refactor. In DS v2 CRAS/RTAS command, we run a nested `AppendData`/`OverwriteByExpression` command by creating a `QueryExecution`. This `QueryExecution` will create another temp `QueryExecution` to eagerly execute commands. This PR avoids the unnecessary temp `QueryExecution` by using `CommandExecutionMode.SKIP` to create `QueryExecution`. ### Why are the changes needed? Remove useless temp `QueryExecution` objects. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? existing tests ### Was this patch authored or co-authored using generative AI tooling? cursor 2.2.43 Closes #53596 from cloud-fan/command. Lead-authored-by: Wenchen Fan <[email protected]> Co-authored-by: Wenchen Fan <[email protected]> Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent eaaf3ca · commit 76c9516
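
For context on where the temp object comes from: whether a `QueryExecution` eagerly executes the commands in a plan is controlled by its `CommandExecutionMode`. A simplified sketch of that dispatch (paraphrased from `QueryExecution.scala`; not guaranteed to match the source verbatim):

```scala
// Sketch of QueryExecution's command handling (paraphrased, not verbatim).
// With the eager modes, each command in the plan is executed by wrapping it
// in its own temporary QueryExecution; SKIP returns the analyzed plan
// untouched, so the caller can execute the physical plan directly.
lazy val commandExecuted: LogicalPlan = mode match {
  case CommandExecutionMode.NON_ROOT => analyzed.mapChildren(eagerlyExecuteCommands)
  case CommandExecutionMode.ALL => eagerlyExecuteCommands(analyzed)
  case CommandExecutionMode.SKIP => analyzed
}
```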

File tree

4 files changed: +68 −11 lines


sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 26 additions & 7 deletions
```diff
@@ -178,13 +178,8 @@ class QueryExecution(
         // with the rest of processing of the root plan being just outputting command results,
         // for eagerly executed commands we mark this place as beginning of execution.
         tracker.setReadyForExecution()
-        val qe = new QueryExecution(sparkSession, p, mode = mode,
-          shuffleCleanupMode = shuffleCleanupMode, refreshPhaseEnabled = refreshPhaseEnabled)
-        val result = QueryExecution.withInternalError(s"Eagerly executed $name failed.") {
-          SQLExecution.withNewExecutionId(qe, Some(name)) {
-            qe.executedPlan.executeCollect()
-          }
-        }
+        val (qe, result) = QueryExecution.runCommand(
+          sparkSession, p, name, refreshPhaseEnabled, mode, Some(shuffleCleanupMode))
         CommandResult(
           qe.analyzed.output,
           qe.commandExecuted,
@@ -763,4 +758,28 @@ object QueryExecution {
       case _ => false
     }
   }
+
+  def runCommand(
+      sparkSession: SparkSession,
+      command: LogicalPlan,
+      name: String,
+      refreshPhaseEnabled: Boolean = true,
+      mode: CommandExecutionMode.Value = CommandExecutionMode.SKIP,
+      shuffleCleanupModeOpt: Option[ShuffleCleanupMode] = None)
+    : (QueryExecution, Array[InternalRow]) = {
+    val shuffleCleanupMode = shuffleCleanupModeOpt.getOrElse(
+      determineShuffleCleanupMode(sparkSession.sessionState.conf))
+    val qe = new QueryExecution(
+      sparkSession,
+      command,
+      mode = mode,
+      shuffleCleanupMode = shuffleCleanupMode,
+      refreshPhaseEnabled = refreshPhaseEnabled)
+    val result = QueryExecution.withInternalError(s"Executed $name failed.") {
+      SQLExecution.withNewExecutionId(qe, Some(name)) {
+        qe.executedPlan.executeCollect()
+      }
+    }
+    (qe, result)
+  }
 }
```
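
A minimal usage sketch of the new helper, assuming the signature in the diff above; `spark` and `writePlan` are illustrative names for an active `SparkSession` and a resolved logical write command such as `AppendData`:

```scala
import org.apache.spark.sql.execution.QueryExecution

// Run the nested command once, under its own named execution id.
val (qe, rows) = QueryExecution.runCommand(spark, writePlan, name = "example write")
// Defaults: mode = CommandExecutionMode.SKIP, so no temporary QueryExecution
// is spawned for eager command execution, and the shuffle cleanup mode is
// derived from the session conf. The collected InternalRows land in `rows`.
```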

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -39,6 +39,7 @@ import org.apache.spark.sql.classic.ClassicConversions.castToImpl
 import org.apache.spark.sql.classic.Dataset
 import org.apache.spark.sql.connector.catalog.TableProvider
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
+import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.command.DataWritingCommand
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
 import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider
@@ -531,8 +532,7 @@ case class DataSource(
       disallowWritingIntervals(
         outputColumns.toStructType.asNullable, format.toString, forbidAnsiIntervals = false)
       val cmd = planForWritingFileFormat(format, mode, data)
-      val qe = sessionState(sparkSession).executePlan(cmd)
-      qe.assertCommandExecuted()
+      QueryExecution.runCommand(sparkSession, cmd, "file source write")
       // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring
       copy(userSpecifiedSchema = Some(outputColumns.toStructType.asNullable)).resolveRelation()
     case _ => throw SparkException.internalError(
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -737,8 +737,11 @@ private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec {
       } else {
         AppendData.byPosition(relation, query, writeOptions)
       }
-      val qe = QueryExecution.create(session, writeCommand, refreshPhaseEnabled)
-      qe.assertCommandExecuted()
+      QueryExecution.runCommand(
+        session,
+        writeCommand,
+        "inner data writing for CTAS/RTAS",
+        refreshPhaseEnabled)
       DataSourceV2Utils.commitStagedChanges(sparkContext, table, metrics)
       Nil
     })(catchBlock = {
```
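
For comparison, a sketch of the old and new paths at this call site (assuming, per the PR description, that the `QueryExecution` built by `QueryExecution.create` used an eager command mode):

```scala
// Before: assertCommandExecuted forced commandExecuted, which re-ran the
// nested AppendData/OverwriteByExpression inside a temporary QueryExecution.
val qe = QueryExecution.create(session, writeCommand, refreshPhaseEnabled)
qe.assertCommandExecuted()

// After: a single QueryExecution with CommandExecutionMode.SKIP, executed
// directly and reported under an explicit execution name. This is also why
// the new test below can count exactly two executions per CTAS/RTAS.
QueryExecution.runCommand(
  session, writeCommand, "inner data writing for CTAS/RTAS", refreshPhaseEnabled)
```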

sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala

Lines changed: 35 additions & 0 deletions
```diff
@@ -2053,4 +2053,39 @@ class DataSourceV2DataFrameSuite
       case _ => fail(s"can't pin $ident in $catalogName")
     }
   }
+
+  test("CTAS/RTAS should trigger two query executions") {
+    // CTAS/RTAS triggers 2 query executions:
+    // 1. The outer CTAS/RTAS command execution
+    // 2. The inner AppendData/OverwriteByExpression execution
+    var executionCount = 0
+    val listener = new QueryExecutionListener {
+      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+        executionCount += 1
+      }
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+    }
+
+    try {
+      spark.listenerManager.register(listener)
+      val t = "testcat.ns1.ns2.tbl"
+      withTable(t) {
+        // Test CTAS (CreateTableAsSelect)
+        executionCount = 0
+        sql(s"CREATE TABLE $t USING foo AS SELECT 1 as id, 'a' as data")
+        sparkContext.listenerBus.waitUntilEmpty()
+        assert(executionCount == 2,
+          s"CTAS should trigger 2 executions, got $executionCount")
+
+        // Test RTAS (ReplaceTableAsSelect)
+        executionCount = 0
+        sql(s"CREATE OR REPLACE TABLE $t USING foo AS SELECT 2 as id, 'b' as data")
+        sparkContext.listenerBus.waitUntilEmpty()
+        assert(executionCount == 2,
+          s"RTAS should trigger 2 executions, got $executionCount")
+      }
+    } finally {
+      spark.listenerManager.unregister(listener)
+    }
+  }
 }
```
