Skip to content

Commit 56f4d01

Browse files
committed
feat: Avoid duplicated write nodes for AQE execution
1 parent a122a14 commit 56f4d01

File tree

2 files changed

+40
-4
lines changed

2 files changed

+40
-4
lines changed

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import org.apache.spark.sql.execution._
3333
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
3434
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
3535
import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
36+
import org.apache.spark.sql.execution.datasources.WriteFilesExec
3637
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
3738
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
3839
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
@@ -197,6 +198,14 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
197198
case op if shouldApplySparkToColumnar(conf, op) =>
198199
convertToComet(op, CometSparkToColumnarExec).getOrElse(op)
199200

201+
// AQE reoptimization looks for `DataWritingCommandExec` or `WriteFilesExec`
202+
// if there is none it would reinsert write nodes, and since Comet remaps those nodes
203+
// to their Comet counterparts, the write nodes are added to the plan twice.
204+
// Check whether AQE inserted another write command on top of the existing write command.
205+
case _ @DataWritingCommandExec(_, w: WriteFilesExec)
206+
if w.child.isInstanceOf[CometNativeWriteExec] =>
207+
w.child
208+
200209
case op: DataWritingCommandExec =>
201210
convertToComet(op, CometDataWritingCommand).getOrElse(op)
202211

spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ class CometParquetWriterSuite extends CometTestBase {
5454

5555
private def writeWithCometNativeWriteExec(
5656
inputPath: String,
57-
outputPath: String): Option[QueryExecution] = {
57+
outputPath: String,
58+
num_partitions: Option[Int] = None): Option[QueryExecution] = {
5859
val df = spark.read.parquet(inputPath)
5960

6061
// Use a listener to capture the execution plan during write
@@ -77,8 +78,8 @@ class CometParquetWriterSuite extends CometTestBase {
7778
spark.listenerManager.register(listener)
7879

7980
try {
80-
// Perform native write
81-
df.write.parquet(outputPath)
81+
// Perform native write with optional repartitioning
82+
num_partitions.fold(df)(n => df.repartition(n)).write.parquet(outputPath)
8283

8384
// Wait for listener to be called with timeout
8485
val maxWaitTimeMs = 15000
@@ -97,7 +98,7 @@ class CometParquetWriterSuite extends CometTestBase {
9798
s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")
9899

99100
capturedPlan.foreach { qe =>
100-
val executedPlan = qe.executedPlan
101+
val executedPlan = stripAQEPlan(qe.executedPlan)
101102
val hasNativeWrite = executedPlan.exists {
102103
case _: CometNativeWriteExec => true
103104
case d: DataWritingCommandExec =>
@@ -197,4 +198,30 @@ class CometParquetWriterSuite extends CometTestBase {
197198
}
198199
}
199200
}
201+
202+
test("basic parquet write with repartition") {
203+
withTempPath { dir =>
204+
val outputPath = new File(dir, "output.parquet").getAbsolutePath
205+
206+
// Create test data and write it to a temp parquet file first
207+
withTempPath { inputDir =>
208+
val inputPath = createTestData(inputDir)
209+
Seq(true, false).foreach(adaptive => {
210+
withSQLConf(
211+
CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
212+
"spark.sql.adaptive.enabled" -> adaptive.toString,
213+
SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
214+
CometConf.getOperatorAllowIncompatConfigKey(
215+
classOf[DataWritingCommandExec]) -> "true",
216+
CometConf.COMET_EXEC_ENABLED.key -> "true") {
217+
218+
val plan = writeWithCometNativeWriteExec(inputPath, outputPath, Some(10))
219+
println(plan)
220+
221+
verifyWrittenFile(outputPath)
222+
}
223+
})
224+
}
225+
}
226+
}
200227
}

0 commit comments

Comments
 (0)