
Commit 0ae6515

Avoid duplicated writer nodes when AQE enabled (#2982)
* feat: Avoid duplicated write nodes for AQE execution
Parent: a951da9

File tree: 2 files changed (+52, -12 lines)


spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 9 additions & 0 deletions
@@ -33,6 +33,7 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
 import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
+import org.apache.spark.sql.execution.datasources.WriteFilesExec
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
@@ -197,6 +198,14 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       case op if shouldApplySparkToColumnar(conf, op) =>
         convertToComet(op, CometSparkToColumnarExec).getOrElse(op)

+      // AQE reoptimization looks for a `DataWritingCommandExec` or `WriteFilesExec`;
+      // if it finds neither, it reinserts write nodes, and since Comet remaps those
+      // nodes to their Comet counterparts, the write nodes would appear twice in the
+      // plan. Check whether AQE inserted another write command on top of an existing one.
+      case _ @DataWritingCommandExec(_, w: WriteFilesExec)
+          if w.child.isInstanceOf[CometNativeWriteExec] =>
+        w.child
+
       case op: DataWritingCommandExec =>
         convertToComet(op, CometDataWritingCommand).getOrElse(op)
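
To make the bug shape concrete, here is a minimal, self-contained sketch of the unwrap. The case classes are hypothetical stand-ins for Spark's DataWritingCommandExec/WriteFilesExec and Comet's CometNativeWriteExec, not the real operators; the point is only that AQE re-planning can wrap an already-converted write in a second writer shell, which a single pattern match can strip.

// Hypothetical stand-ins for the real operators; for illustration only.
sealed trait Plan
case object Scan extends Plan
case class WriteFiles(child: Plan) extends Plan
case class DataWritingCommand(child: Plan) extends Plan
case class CometNativeWrite(child: Plan) extends Plan

object DedupWriteSketch extends App {
  // Comet has already converted the write to CometNativeWrite, but AQE
  // re-planning has wrapped it in a fresh DataWritingCommand(WriteFiles(...))
  // shell, so the plan now contains two writer nodes.
  val duplicated: Plan = DataWritingCommand(WriteFiles(CometNativeWrite(Scan)))

  // Mirror of the new CometExecRule case: if the wrapped child is already a
  // Comet native write, drop the redundant shell and keep the child.
  def dedup(plan: Plan): Plan = plan match {
    case DataWritingCommand(WriteFiles(c: CometNativeWrite)) => c
    case other => other
  }

  assert(dedup(duplicated) == CometNativeWrite(Scan))
  println(dedup(duplicated)) // CometNativeWrite(Scan)
}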

spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala

Lines changed: 43 additions & 12 deletions
@@ -54,7 +54,8 @@ class CometParquetWriterSuite extends CometTestBase {

   private def writeWithCometNativeWriteExec(
       inputPath: String,
-      outputPath: String): Option[QueryExecution] = {
+      outputPath: String,
+      num_partitions: Option[Int] = None): Option[QueryExecution] = {
     val df = spark.read.parquet(inputPath)

     // Use a listener to capture the execution plan during write
@@ -77,8 +78,8 @@ class CometParquetWriterSuite extends CometTestBase {
     spark.listenerManager.register(listener)

     try {
-      // Perform native write
-      df.write.parquet(outputPath)
+      // Perform native write with optional partitioning
+      num_partitions.fold(df)(n => df.repartition(n)).write.parquet(outputPath)

       // Wait for listener to be called with timeout
       val maxWaitTimeMs = 15000
@@ -97,20 +98,25 @@ class CometParquetWriterSuite extends CometTestBase {
         s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")

       capturedPlan.foreach { qe =>
-        val executedPlan = qe.executedPlan
-        val hasNativeWrite = executedPlan.exists {
-          case _: CometNativeWriteExec => true
+        val executedPlan = stripAQEPlan(qe.executedPlan)
+
+        // Count CometNativeWriteExec instances in the plan
+        var nativeWriteCount = 0
+        executedPlan.foreach {
+          case _: CometNativeWriteExec =>
+            nativeWriteCount += 1
           case d: DataWritingCommandExec =>
-            d.child.exists {
-              case _: CometNativeWriteExec => true
-              case _ => false
+            d.child.foreach {
+              case _: CometNativeWriteExec =>
+                nativeWriteCount += 1
+              case _ =>
             }
-          case _ => false
+          case _ =>
         }

         assert(
-          hasNativeWrite,
-          s"Expected CometNativeWriteExec in the plan, but got:\n${executedPlan.treeString}")
+          nativeWriteCount == 1,
+          s"Expected exactly one CometNativeWriteExec in the plan, but found $nativeWriteCount:\n${executedPlan.treeString}")
       }
     } finally {
       spark.listenerManager.unregister(listener)
@@ -197,4 +203,29 @@ class CometParquetWriterSuite extends CometTestBase {
       }
     }
   }
+
+  test("basic parquet write with repartition") {
+    withTempPath { dir =>
+      // Create test data and write it to a temp parquet file first
+      withTempPath { inputDir =>
+        val inputPath = createTestData(inputDir)
+        Seq(true, false).foreach(adaptive => {
+          // Create a new output path for each AQE value
+          val outputPath = new File(dir, s"output_aqe_$adaptive.parquet").getAbsolutePath
+
+          withSQLConf(
+            CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+            "spark.sql.adaptive.enabled" -> adaptive.toString,
+            SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
+            CometConf.getOperatorAllowIncompatConfigKey(
+              classOf[DataWritingCommandExec]) -> "true",
+            CometConf.COMET_EXEC_ENABLED.key -> "true") {
+
+            writeWithCometNativeWriteExec(inputPath, outputPath, Some(10))
+            verifyWrittenFile(outputPath)
+          }
+        })
+      }
+    }
+  }
 }
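
The helper writeWithCometNativeWriteExec captures the executed plan through a QueryExecutionListener; only fragments of it appear in the diff. The sketch below is a rough, standalone reconstruction of that capture pattern, with assumed names (PlanCaptureExample, the /tmp output path) and a simplified latch-based wait in place of the suite's timeout loop.

import java.util.concurrent.{CountDownLatch, TimeUnit}
import java.util.concurrent.atomic.AtomicReference

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

// Illustrative only: capture the physical plan of a write so it can be
// inspected for writer nodes, as the suite's helper does.
object PlanCaptureExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()

    val captured = new AtomicReference[Option[QueryExecution]](None)
    val done = new CountDownLatch(1)

    val listener = new QueryExecutionListener {
      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
        captured.set(Some(qe))
        done.countDown()
      }
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
        done.countDown()
    }

    spark.listenerManager.register(listener)
    try {
      // Repartition before writing so AQE has a shuffle stage to re-plan.
      spark.range(100).repartition(10).write.mode("overwrite").parquet("/tmp/plan-capture.parquet")
      done.await(15, TimeUnit.SECONDS)
      // Inspect the captured plan, e.g. to count writer nodes in its tree.
      captured.get().foreach(qe => println(qe.executedPlan.treeString))
    } finally {
      spark.listenerManager.unregister(listener)
      spark.stop()
    }
  }
}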
